Source code for lightning_pose.data.utils

"""Dataset/data module utilities."""
import os
from typing import Any, Literal, Tuple, Union

import imgaug.augmenters as iaa
import lightning.pytorch as pl
import numpy as np
import torch
from torchtyping import TensorType
from typeguard import typechecked

from lightning_pose.data.datatypes import SemiSupervisedDataLoaderDict

# explicit public API of this module; listing the names here also keeps the module's
# imports (torch, np, ...) out of the sphinx autoapi documentation
__all__ = [
    "DataExtractor",
    "split_sizes_from_probabilities",
    "clean_any_nans",
    "count_frames",
    "compute_num_train_frames",
    "generate_heatmaps",
    "evaluate_heatmaps_at_location",
    "undo_affine_transform",
    "undo_affine_transform_batch",
]


class DataExtractor(object):
    """Helper class to extract all data from a data module.

    Calling an instance iterates once over one split ("train", "val" or "test") of the
    data module and returns the concatenated keypoints and, optionally, the images.
    By default the dataset and data module are rebuilt so that the only augmentation
    left is the final resize transform of the original pipeline.
    """

    def __init__(
        self,
        data_module: pl.LightningDataModule,
        cond: Literal["train", "test", "val"] = "train",
        extract_images: bool = False,
        remove_augmentations: bool = True,
    ) -> None:
        """Store options and, if requested, strip augmentations from ``data_module``.

        Args:
            data_module: data module to extract from; its ``dataset`` must expose an
                ``imgaug_transform`` pipeline
            cond: which split to extract
            extract_images: if True, also collect ``batch["images"]``
            remove_augmentations: if True, rebuild the dataset and data module so that
                the only augmentation is the last (resize) transform of the pipeline

        Raises:
            NotImplementedError: if augmentations must be removed but the pipeline is
                empty or does not end with ``iaa.Resize``, or if the dataset / data
                module type is not one that can be rebuilt
        """
        self.cond = cond
        self.extract_images = extract_images
        self.remove_augmentations = remove_augmentations

        if not self.remove_augmentations:
            self.data_module = data_module
            return

        imgaug_curr = data_module.dataset.imgaug_transform
        if len(imgaug_curr) == 1 and isinstance(imgaug_curr[0], iaa.Resize):
            # current augmentation just resizes; keep this
            self.data_module = data_module
            return

        # we currently assume the last transform is resizing
        if len(imgaug_curr) == 0 or not isinstance(imgaug_curr[-1], iaa.Resize):
            raise NotImplementedError(
                "cannot remove augmentations: the last transform of the imgaug "
                "pipeline must be iaa.Resize"
            )
        # make new augmentation pipeline that only keeps the resizing aug
        imgaug_new = iaa.Sequential([imgaug_curr[-1]])
        # TODO: is there a cleaner way to do this?
        dataset_new = self._rebuild_dataset(data_module.dataset, imgaug_new)
        self.data_module = self._rebuild_data_module(data_module, dataset_new)

    @staticmethod
    def _rebuild_dataset(dataset_old, imgaug_new):
        """Re-instantiate ``dataset_old``'s class with ``imgaug_new`` as its augmentation.

        Only the constructor arguments listed below are copied from ``dataset_old``; any
        other constructor argument falls back to its default.
        """
        # function-scope import; presumably avoids a circular import with this module
        # (TODO confirm)
        from lightning_pose.data.datasets import (
            BaseTrackingDataset,
            HeatmapDataset,
            MultiviewHeatmapDataset,
        )

        # HeatmapDataset is tested before BaseTrackingDataset so it is matched first in
        # case it derives from BaseTrackingDataset
        if isinstance(dataset_old, HeatmapDataset):
            return HeatmapDataset(
                root_directory=dataset_old.root_directory,
                csv_path=dataset_old.csv_path,
                image_resize_height=dataset_old.image_resize_height,
                image_resize_width=dataset_old.image_resize_width,
                imgaug_transform=imgaug_new,
                downsample_factor=dataset_old.downsample_factor,
                do_context=dataset_old.do_context,
            )
        if isinstance(dataset_old, BaseTrackingDataset):
            return BaseTrackingDataset(
                root_directory=dataset_old.root_directory,
                csv_path=dataset_old.csv_path,
                image_resize_height=dataset_old.image_resize_height,
                image_resize_width=dataset_old.image_resize_width,
                imgaug_transform=imgaug_new,
                do_context=dataset_old.do_context,
            )
        if isinstance(dataset_old, MultiviewHeatmapDataset):
            return MultiviewHeatmapDataset(
                root_directory=dataset_old.root_directory,
                csv_paths=dataset_old.csv_paths,
                view_names=dataset_old.view_names,
                image_resize_height=dataset_old.image_resize_height,
                image_resize_width=dataset_old.image_resize_width,
                imgaug_transform=imgaug_new,
                do_context=dataset_old.do_context,
            )
        raise NotImplementedError(
            f"cannot rebuild dataset of type {type(dataset_old).__name__}"
        )

    @staticmethod
    def _rebuild_data_module(data_module, dataset_new):
        """Re-instantiate ``data_module``'s class around ``dataset_new``.

        Batch sizes, workers, split probabilities, ``train_frames`` and ``torch_seed``
        (plus video paths / DALI config for unlabeled modules) are copied over.
        """
        # function-scope import; presumably avoids a circular import (TODO confirm)
        from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule

        # UnlabeledDataModule is tested first in case it derives from BaseDataModule
        if isinstance(data_module, UnlabeledDataModule):
            # data_module_new.setup() happens internally
            return UnlabeledDataModule(
                dataset=dataset_new,
                video_paths_list=data_module.video_paths_list,
                train_batch_size=data_module.train_batch_size,
                val_batch_size=data_module.val_batch_size,
                test_batch_size=data_module.test_batch_size,
                num_workers=data_module.num_workers,
                train_probability=data_module.train_probability,
                val_probability=data_module.val_probability,
                train_frames=data_module.train_frames,
                dali_config=data_module.dali_config,
                torch_seed=data_module.torch_seed,
            )
        if isinstance(data_module, BaseDataModule):
            return BaseDataModule(
                dataset=dataset_new,
                train_batch_size=data_module.train_batch_size,
                val_batch_size=data_module.val_batch_size,
                test_batch_size=data_module.test_batch_size,
                num_workers=data_module.num_workers,
                train_probability=data_module.train_probability,
                val_probability=data_module.val_probability,
                train_frames=data_module.train_frames,
                torch_seed=data_module.torch_seed,
            )
        raise NotImplementedError(
            f"cannot rebuild data module of type {type(data_module).__name__}"
        )

    @property
    def dataset_length(self) -> int:
        """Length of the data module's ``<cond>_dataset`` attribute."""
        return len(getattr(self.data_module, f"{self.cond}_dataset"))

    def get_loader(
        self,
    ) -> torch.utils.data.DataLoader | SemiSupervisedDataLoaderDict:
        """Return the data module's dataloader for the split named by ``self.cond``.

        Raises:
            ValueError: if ``self.cond`` is not "train", "val" or "test"
        """
        if self.cond == "train":
            return self.data_module.train_dataloader()
        if self.cond == "val":
            return self.data_module.val_dataloader()
        if self.cond == "test":
            return self.data_module.test_dataloader()
        # previously fell through and returned None, failing later with an obscure error
        raise ValueError(
            f"cond must be one of 'train', 'val' or 'test', got {self.cond!r}"
        )

    @staticmethod
    def verify_labeled_loader(
        loader: torch.utils.data.DataLoader | SemiSupervisedDataLoaderDict,
    ) -> torch.utils.data.DataLoader:
        """Return a loader that yields labeled batches.

        A plain DataLoader is returned unchanged; for a combined loader the loader
        stored under ``iterables["labeled"]`` is returned.
        """
        if isinstance(loader, torch.utils.data.DataLoader):
            return loader
        # if we have a dictionary of dataloaders, we take the loader called
        # "labeled" (the loader called "unlabeled" doesn't have keypoints)
        return loader.iterables["labeled"]

    def iterate_over_dataloader(
        self, loader: torch.utils.data.DataLoader
    ) -> Tuple[
        TensorType["num_examples", Any],
        Union[
            TensorType["num_examples", 3, "image_width", "image_height"],
            TensorType["num_examples", "frames", 3, "image_width", "image_height"],
            None,
        ],
    ]:
        """Concatenate ``batch["keypoints"]`` (and ``batch["images"]``) over all batches.

        Returns:
            tuple of (keypoints concatenated along dim 0, images concatenated along
            dim 0 or None when ``extract_images`` is False)
        """
        keypoints_list = []
        images_list = []
        for batch in loader:
            keypoints_list.append(batch["keypoints"])
            if self.extract_images:
                images_list.append(batch["images"])
        concat_keypoints = torch.cat(keypoints_list, dim=0)
        concat_images = torch.cat(images_list, dim=0) if self.extract_images else None
        # assert that indeed the number of columns does not change after concatenation,
        # and that the number of rows is the dataset length.
        assert concat_keypoints.shape == (
            self.dataset_length,
            keypoints_list[0].shape[1],
        )
        return concat_keypoints, concat_images

    def __call__(
        self,
    ) -> Tuple[
        TensorType["num_examples", Any],
        Union[
            TensorType["num_examples", 3, "image_width", "image_height"],
            TensorType["num_examples", "frames", 3, "image_width", "image_height"],
            None,
        ],
    ]:
        """Extract keypoints (and optionally images) from the selected split."""
        loader = self.verify_labeled_loader(self.get_loader())
        return self.iterate_over_dataloader(loader)
@typechecked
def split_sizes_from_probabilities(
    total_number: int,
    train_probability: float,
    val_probability: float | None = None,
    test_probability: float | None = None,
) -> list[int]:
    """Returns the number of examples for train, val and test given split probs.

    If both ``val_probability`` and ``test_probability`` are None, the remaining
    probability ``1 - train_probability`` is split evenly between them. If only one of
    them is None, it is set to whatever remains after the other two.

    Counts for train and val are obtained by flooring ``probability * total_number``.
    All remaining examples form the test split, unless fewer than 5 remain, in which
    case they are added to train and the test split is empty. At least one validation
    example is always returned (taken from train).

    Args:
        total_number: total number of examples in dataset
        train_probability: fraction of examples used for training
        val_probability: fraction of examples used for validation. Defaults to None.
            Can be computed from the other probabilities.
        test_probability: fraction of examples used for test. Defaults to None.
            Can be computed as the remaining examples.

    Returns:
        [num training examples, num validation examples, num test examples]

    Raises:
        AssertionError: if the three probabilities do not sum to one
        ValueError: if moving one example to validation leaves no training example

    """
    if test_probability is None and val_probability is None:
        remaining_probability = 1.0 - train_probability
        # round each to 5 decimal places (issue with floating point precision)
        val_probability = round(remaining_probability / 2, 5)
        test_probability = round(remaining_probability / 2, 5)
    elif test_probability is None:
        test_probability = 1.0 - train_probability - val_probability
    elif val_probability is None:
        # previously crashed with a TypeError (float + None); round for the same reason
        # as above, e.g. 1.0 - 0.8 - 0.1 = 0.0999... would floor one example short
        val_probability = round(1.0 - train_probability - test_probability, 5)

    # probabilities should add to one. Compare with a tolerance: the rounding above
    # alone can shift the sum by up to 1e-5 (e.g. train_probability=1/3 sums to
    # 0.99999...), so an exact == 1.0 check rejects valid inputs
    total_probability = train_probability + val_probability + test_probability
    assert abs(total_probability - 1.0) < 1e-4, (
        f"train, val and test probabilities must sum to 1, got {total_probability}"
    )

    # compute numbers from probabilities
    train_number = int(np.floor(train_probability * total_number))
    val_number = int(np.floor(val_probability * total_number))

    # everything not assigned to train/val goes to test; with very few leftover
    # samples, bulk up train instead
    leftover = total_number - train_number - val_number
    if leftover < 5:
        train_number += leftover
        test_number = 0
    else:
        test_number = leftover

    # make sure that we have at least one validation sample
    if val_number == 0:
        train_number -= 1
        val_number += 1
        if train_number < 1:
            raise ValueError("Must have at least two labeled frames, one train and one validation")

    # assert that we're using all datapoints
    assert train_number + test_number + val_number == total_number

    return [train_number, val_number, test_number]
@typechecked
def clean_any_nans(data: torch.Tensor, dim: int) -> torch.Tensor:
    """Remove samples from a data array that contain nans.

    Currently supports only 2D arrays.

    Args:
        data: array to clean
        dim: axis reduced when looking for nans. ``dim=0`` drops every column that
            contains a nan; ``dim=1`` drops every row that contains a nan.

    Returns:
        ``data`` without the offending columns (``dim=0``) or rows (``dim=1``)

    Raises:
        ValueError: if ``dim`` is not 0 or 1

    """
    # e.g., when dim == 0, those columns (keypoints) that have >0 nans
    nan_bool = torch.isnan(data).any(dim=dim)
    if dim == 0:
        return data[:, ~nan_bool]
    if dim == 1:
        return data[~nan_bool]
    # previously fell through and returned None despite the Tensor return type
    raise ValueError(f"clean_any_nans only supports dim=0 or dim=1, got dim={dim}")
@typechecked
def count_frames(video_file: str) -> int:
    """Return the frame count that OpenCV reports for ``video_file``.

    The value comes from ``cv2.CAP_PROP_FRAME_COUNT``; for some containers this is
    read from metadata and may not match the number of decodable frames (verify for
    your video format).

    Args:
        video_file: path to an existing video file

    Returns:
        number of frames, as an int

    """
    assert os.path.isfile(video_file)
    # cv2 is imported lazily so the module can be imported without it
    import cv2

    capture = cv2.VideoCapture(video_file)
    frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()
    return int(frame_total)
@typechecked
def compute_num_train_frames(
    len_train_dataset: int,
    train_frames: int | float | None = None,
) -> int:
    """Quickly compute number of training frames for a given dataset.

    Args:
        len_train_dataset: total number of frames in training dataset
        train_frames:
            None - use all train frames
            <=1 - fraction of total train frames used for training
            >1 - number of total train frames used for training (capped at
            ``len_train_dataset`` with a printed warning)

    Returns:
        total number of train frames

    Raises:
        ValueError: if ``train_frames`` is not positive (or is nan)

    """
    if train_frames is None:
        return len_train_dataset

    # validate first so an empty dataset cannot mask an invalid request; written as
    # "not > 0" so that nan is rejected as well
    if not train_frames > 0:
        raise ValueError("train_frames must be >0")

    if train_frames == 1:
        # assume this is a fraction; use full dataset
        return len_train_dataset

    if train_frames > 1:
        if train_frames > len_train_dataset:
            # take max number of train frames (only warn when the request truly
            # exceeds the dataset; an exact match needs no warning)
            print("Warning! Requested training frames exceeds training set size; using all")
            return len_train_dataset
        # take this number of train frames
        return int(train_frames)

    # 0 < train_frames < 1: take this fraction of train frames
    return int(train_frames * len_train_dataset)
# @typechecked
def generate_heatmaps(
    keypoints: TensorType["batch", "num_keypoints", 2],
    height: int,
    width: int,
    output_shape: Tuple[int, int],
    sigma: float = 1.25,
    uniform_heatmaps: bool = False,
) -> TensorType["batch", "num_keypoints", "height", "width"]:
    """Generate 2D Gaussian heatmaps from mean and sigma.

    Each heatmap is normalized to sum to one. Keypoints that cannot produce a valid
    heatmap get a filler heatmap instead: keypoints with a nan in either coordinate,
    and keypoints so far outside the heatmap that their Gaussian underflows to all
    zeros (normalizing those would otherwise give 0/0 = nan).

    Args:
        keypoints: coordinates that serve as mean of gaussian bump; last dim is (x, y)
            in pixels of the reshaped image
        height: height of reshaped image (pixels, e.g., 128, 256, 512...)
        width: width of reshaped image (pixels, e.g., 128, 256, 512...)
        output_shape: dimensions of downsampled heatmap, (height, width)
        sigma: control spread of gaussian (in heatmap pixels)
        uniform_heatmaps: output uniform heatmaps if missing ground truth label, rather
            than all-zero heatmaps (skipped)

    Returns:
        batch of 2D heatmaps

    """
    # work on a copy so the caller's tensor is neither modified nor part of the graph
    keypoints = keypoints.detach().clone()
    out_height = output_shape[0]
    out_width = output_shape[1]
    # rescale from image pixels to heatmap pixels
    keypoints[:, :, 1] *= out_height / height
    keypoints[:, :, 0] *= out_width / width
    # a keypoint is missing if either of its coordinates is nan
    nan_idxs = torch.isnan(keypoints).any(dim=-1)

    xv = torch.arange(out_width, device=keypoints.device)
    yv = torch.arange(out_height, device=keypoints.device)
    # with indexing="ij": row_grid[i, j] = i (y coordinate), col_grid[i, j] = j (x)
    row_grid, col_grid = torch.meshgrid(yv, xv, indexing="ij")
    # add batch and num_keypoints dims -> (1, 1, out_height, out_width)
    row_grid = row_grid.unsqueeze(0).unsqueeze(0)
    col_grid = col_grid.unsqueeze(0).unsqueeze(0)
    # (batch, num_keypoints, 1, 2) so each coordinate broadcasts against the grids
    keypoints = keypoints.unsqueeze(2)

    # evaluate 2d gaussian with mean equal to the keypoint and var equal to sigma^2
    heatmaps = (col_grid - keypoints[:, :, :, :1]) ** 2
    heatmaps += (row_grid - keypoints[:, :, :, 1:]) ** 2
    heatmaps *= -1
    heatmaps /= 2 * sigma**2
    heatmaps = torch.exp(heatmaps)

    # normalize all heatmaps to one; nan or all-zero heatmaps become nan here and are
    # overwritten with the filler below
    heatmap_sums = torch.sum(heatmaps, dim=(2, 3), keepdim=True)
    heatmaps = heatmaps / heatmap_sums
    fill_idxs = nan_idxs | (heatmap_sums[:, :, 0, 0] == 0)

    # (all zeros heatmaps are ignored in the supervised heatmap loss)
    if uniform_heatmaps:
        filler_heatmap = torch.ones(
            (out_height, out_width), device=keypoints.device
        ) / (out_height * out_width)
    else:
        filler_heatmap = torch.zeros((out_height, out_width), device=keypoints.device)
    heatmaps[fill_idxs] = filler_heatmap
    return heatmaps
# @typechecked
def evaluate_heatmaps_at_location(
    heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"],
    locs: TensorType["batch", "num_keypoints", 2],
    sigma: float = 1.25,  # sigma used for generating heatmaps
    num_stds: int = 2,  # num standard deviations of pixels to compute confidence
) -> TensorType["batch", "num_keypoints"]:
    """Evaluate 4D heatmaps using a 3D location tensor (last dim is x, y coords).

    Since the model outputs heatmaps with a standard deviation of sigma, confidence
    will be spread across neighboring pixels. To account for this, confidence is
    computed by summing all heatmap values in the square window of half-width
    ``floor(sigma * num_stds)`` pixels centered on the (integer-truncated) location.
    Pixels of the window that fall outside the heatmap count as zero.

    Returns:
        float32 tensor of summed heatmap values, one per (batch, keypoint)

    """
    # get all pixels within num_stds.
    pix_to_consider = int(np.floor(sigma * num_stds))
    num_pad = pix_to_consider
    batch_size, num_keypoints, height, width = heatmaps.shape

    # zero-pad so that windows near the border can be indexed without bounds checks
    heatmaps_padded = torch.zeros(
        (batch_size, num_keypoints, height + num_pad * 2, width + num_pad * 2),
        device=heatmaps.device,
    )
    # explicit end indices: a "num_pad:-num_pad" slice is empty when num_pad == 0
    heatmaps_padded[
        :, :, num_pad:num_pad + height, num_pad:num_pad + width
    ] = heatmaps

    # index tensors that broadcast to (batch, num_keypoints, 1, 1)
    i = torch.arange(batch_size, device=heatmaps_padded.device).reshape(-1, 1, 1, 1)
    j = torch.arange(num_keypoints, device=heatmaps_padded.device).reshape(1, -1, 1, 1)
    # row (y) and column (x) of each location in padded coordinates
    k = locs[:, :, None, 1, None].type(torch.int64) + num_pad
    m = locs[:, :, 0, None, None].type(torch.int64) + num_pad

    offsets = range(-pix_to_consider, pix_to_consider + 1)
    vals_all = [
        # get rid of singleton dims -> (batch, num_keypoints)
        heatmaps_padded[i, j, k + row_offset, m + col_offset].squeeze(-1).squeeze(-1)
        for row_offset in offsets
        for col_offset in offsets
    ]
    return torch.stack(vals_all, 0).sum(0)
# @typechecked
def undo_affine_transform(
    keypoints: TensorType["seq_len", "num_keypoints", 2],
    transform: TensorType["seq_len", 2, 3] | TensorType[2, 3],
) -> TensorType["seq_len", "num_keypoints", 2]:
    """Undo an affine transform given a tensor of keypoints and the transform matrix.

    A transform ``[A | t]`` maps a point ``p`` to ``A @ p + t``. It is undone by
    applying the inverse ``[A^-1 | -A^-1 @ t]``.

    Args:
        keypoints: transformed (x, y) keypoints
        transform: either a single (2, 3) matrix applied to every frame, or one
            (2, 3) matrix per frame (a stack of length 1 is also broadcast)

    Returns:
        keypoints with the transform undone, in ``keypoints``' dtype and device.
        Gradients flow back to ``keypoints`` but not to ``transform``.

    """
    # add 1s to get keypoints in projective geometry coords: (seq_len, num_keypoints, 3)
    ones = torch.ones(
        (keypoints.shape[0], keypoints.shape[1], 1),
        dtype=keypoints.dtype,
        device=keypoints.device,
    )
    kps_aff = torch.cat([keypoints, ones], dim=2)

    mat = transform.detach().clone()
    if mat.dim() == 2:
        # single transform for all frames; add batch dim
        mat = mat.unsqueeze(0)

    # invert all matrices at once instead of looping over frames in Python
    linear_inv = torch.linalg.inv(mat[:, :, :2])
    shift_inv = torch.matmul(-linear_inv, mat[:, :, -1:])
    mat_inv = torch.cat([linear_inv, shift_inv], dim=2)
    # (n, 3, 2) so that row-vector keypoints can be right-multiplied; .to() instead of
    # torch.tensor(tensor), which copies and emits a UserWarning on every call
    mat_inv = mat_inv.transpose(1, 2).to(dtype=keypoints.dtype, device=keypoints.device)

    # matmul broadcasts a single inverse (n == 1) across all frames
    return torch.matmul(kps_aff, mat_inv)
def undo_affine_transform_batch(
    keypoints_augmented: TensorType["seq_len", "num_keypoints x 2"],
    transforms: Union[
        TensorType["seq_len", "h":2, "w":3],
        TensorType["h":2, "w":3],
        TensorType["seq_len", "null":1],
        TensorType["null":1],
        TensorType["num_views", "h":2, "w":3],
        TensorType["num_views", "null":1, "null":1],
    ],
    is_multiview: bool = False,
) -> TensorType["seq_len", "num_keypoints x 2"]:
    """Potentially undo an affine transform given a tensor of keypoints and the transform matrix.

    Transforms whose last dimension is 3 are treated as (2, 3) affine matrices and
    undone. Any other shape means there is nothing to undo, and ``keypoints_augmented``
    is returned as is.

    Args:
        keypoints_augmented: flattened (x, y) keypoints, (seq_len, n_keypoints * 2)
        transforms: affine matrices, or placeholder tensors when no affine
            augmentation was applied
        is_multiview: if True, ``transforms[v]`` applies to the v-th equal-sized
            contiguous group of keypoints (view v); otherwise ``transforms`` is applied
            to all keypoints

    Returns:
        keypoints with augmentation undone, (seq_len, n_keypoints * 2). The input
        tensor is never modified in place.

    """
    if transforms.shape[-1] != 3:
        # no affine transform to undo
        return keypoints_augmented

    # initial shape is (seq_len, n_keypoints * 2); reshape to (seq_len, n_keypoints, 2)
    pred_kps = torch.reshape(keypoints_augmented, (keypoints_augmented.shape[0], -1, 2))

    if not is_multiview:
        # single affine transform for the whole batch
        pred_kps = undo_affine_transform(pred_kps, transforms)
    else:
        # each view has its own affine transform that we need to undo. Build a new
        # tensor rather than writing into pred_kps: torch.reshape may return a view
        # of keypoints_augmented, so in-place writes would mutate the caller's tensor
        num_views = transforms.shape[0]
        kps_per_view = pred_kps.shape[1] // num_views
        pieces = [
            undo_affine_transform(
                pred_kps[:, v * kps_per_view:(v + 1) * kps_per_view], transforms[v]
            )
            for v in range(num_views)
        ]
        # keypoints beyond num_views * kps_per_view (if the count is not divisible)
        # are left untouched
        pieces.append(pred_kps[:, num_views * kps_per_view:])
        pred_kps = torch.cat(pieces, dim=1)

    # reshape to (seq_len, n_keypoints * 2)
    return torch.reshape(pred_kps, (pred_kps.shape[0], -1))