Source code for lightning_pose.data.extractor

"""Helper class to extract labeled data from a data module."""

from typing import Literal

import imgaug.augmenters as iaa
import torch
from jaxtyping import Float
from lightning.pytorch.utilities import CombinedLoader

from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.datasets import (
    BaseTrackingDataset,
    HeatmapDataset,
    MultiviewHeatmapDataset,
)

# declare the public API so sphinx-autoapidoc ignores the imported names
__all__ = ["DataExtractor"]


class DataExtractor:
    """Helper class to extract all data from a data module.

    Collects every keypoint tensor (and optionally every image tensor) from one
    split (``"train"``, ``"val"`` or ``"test"``) of a data module. When
    ``remove_augmentations`` is True the data module is rebuilt around a
    resize-only imgaug pipeline before extraction.
    """

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule,
        cond: Literal["train", "test", "val"] = "train",
        extract_images: bool = False,
        remove_augmentations: bool = True,
    ) -> None:
        """Initialize DataExtractor.

        Args:
            data_module: data module containing the labeled dataset and splits.
            cond: which data split to extract (``"train"``, ``"val"``, or ``"test"``).
            extract_images: if True, also extract and return image tensors.
            remove_augmentations: if True, rebuild the dataset with only resize
                augmentation before extracting, to avoid contaminating PCA fits
                with augmented data.

        Raises:
            TypeError: if ``remove_augmentations`` is True and the data module's
                dataset is not a ``BaseTrackingDataset``, ``HeatmapDataset`` or
                ``MultiviewHeatmapDataset``, or its ``imgaug_transform`` is not
                an ``iaa.Sequential``.
            NotImplementedError: if ``remove_augmentations`` is True and the data
                module type cannot be rebuilt.
        """
        self.cond = cond
        self.extract_images = extract_images
        self.remove_augmentations = remove_augmentations

        if not self.remove_augmentations:
            self.data_module = data_module
            return

        dataset_old = data_module.dataset
        # explicit checks instead of `assert`, which is stripped under `python -O`
        if not isinstance(
            dataset_old, (BaseTrackingDataset, HeatmapDataset, MultiviewHeatmapDataset)
        ):
            raise TypeError(
                "remove_augmentations=True requires a BaseTrackingDataset, "
                "HeatmapDataset or MultiviewHeatmapDataset, got "
                f"{type(dataset_old).__name__}"
            )
        imgaug_curr = dataset_old.imgaug_transform
        if not isinstance(imgaug_curr, iaa.Sequential):
            raise TypeError(
                "expected dataset.imgaug_transform to be an iaa.Sequential, got "
                f"{type(imgaug_curr).__name__}"
            )

        if self._is_resize_only(imgaug_curr):
            # current augmentation just resizes; keep the data module as-is
            self.data_module = data_module
            return

        dataset_new = self._build_resize_only_dataset(dataset_old)
        self.data_module = self._rebuild_data_module(data_module, dataset_new)

    @staticmethod
    def _is_resize_only(pipeline: iaa.Sequential) -> bool:
        """Return True if ``pipeline`` consists of exactly one ``iaa.Resize`` step."""
        return len(pipeline) == 1 and isinstance(pipeline[0], iaa.Resize)

    @staticmethod
    def _build_resize_only_dataset(
        dataset_old: BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset,
    ) -> BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset:
        """Recreate ``dataset_old`` with an imgaug pipeline that only resizes.

        The resize dimensions, data paths and ``do_context`` flag are copied from
        ``dataset_old``.

        Raises:
            NotImplementedError: if ``dataset_old`` is not a supported dataset type.
        """
        # use the same resize dimensions as the original dataset
        image_resize_height = dataset_old.image_resize_height
        image_resize_width = dataset_old.image_resize_width
        imgaug_new = iaa.Sequential([
            iaa.Resize({'height': image_resize_height, 'width': image_resize_width})
        ])

        # NOTE(review): HeatmapDataset is tested before BaseTrackingDataset so that,
        # if it subclasses BaseTrackingDataset, it is not rebuilt as the base class.
        # Keep this order.
        if isinstance(dataset_old, HeatmapDataset):
            return HeatmapDataset(
                root_directory=dataset_old.root_directory,
                csv_path=dataset_old.csv_path,
                image_resize_height=image_resize_height,
                image_resize_width=image_resize_width,
                imgaug_transform=imgaug_new,
                downsample_factor=dataset_old.downsample_factor,
                do_context=dataset_old.do_context,
            )
        if isinstance(dataset_old, BaseTrackingDataset):
            return BaseTrackingDataset(
                root_directory=dataset_old.root_directory,
                csv_path=dataset_old.csv_path,
                image_resize_height=image_resize_height,
                image_resize_width=image_resize_width,
                imgaug_transform=imgaug_new,
                do_context=dataset_old.do_context,
            )
        if isinstance(dataset_old, MultiviewHeatmapDataset):
            # NOTE(review): unlike the HeatmapDataset branch, downsample_factor is not
            # forwarded here, so the constructor's default is used. Confirm this is
            # intended.
            return MultiviewHeatmapDataset(
                root_directory=dataset_old.root_directory,
                csv_paths=dataset_old.csv_paths,
                view_names=dataset_old.view_names,
                image_resize_height=image_resize_height,
                image_resize_width=image_resize_width,
                imgaug_transform=imgaug_new,
                do_context=dataset_old.do_context,
            )
        raise NotImplementedError

    @staticmethod
    def _rebuild_data_module(
        data_module: BaseDataModule | UnlabeledDataModule,
        dataset_new: BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset,
    ) -> BaseDataModule | UnlabeledDataModule:
        """Create a data module of the same type as ``data_module`` around ``dataset_new``.

        Batch sizes, worker count, split probabilities, ``train_frames`` and
        ``torch_seed`` are copied from ``data_module``.

        Raises:
            NotImplementedError: if ``data_module`` is neither an
                ``UnlabeledDataModule`` nor a ``BaseDataModule``.
        """
        common_kwargs = dict(
            dataset=dataset_new,
            train_batch_size=data_module.train_batch_size,
            val_batch_size=data_module.val_batch_size,
            test_batch_size=data_module.test_batch_size,
            num_workers=data_module.num_workers,
            train_probability=data_module.train_probability,
            val_probability=data_module.val_probability,
            train_frames=data_module.train_frames,
            torch_seed=data_module.torch_seed,
        )
        # UnlabeledDataModule is checked first because it carries extra settings
        # (video paths, DALI config) that must also be copied.
        if isinstance(data_module, UnlabeledDataModule):
            # data_module_new.setup() happens internally
            return UnlabeledDataModule(
                video_paths_list=data_module.video_paths_list,
                dali_config=data_module.dali_config,
                **common_kwargs,
            )
        if isinstance(data_module, BaseDataModule):
            return BaseDataModule(**common_kwargs)
        raise NotImplementedError

    @property
    def dataset_length(self) -> int:
        """Number of examples in the selected data split.

        Returns:
            Length of the ``train``, ``val``, or ``test`` dataset depending on
            ``self.cond``.
        """
        name = f'{self.cond}_dataset'
        return len(getattr(self.data_module, name))

    def get_loader(self) -> torch.utils.data.DataLoader | CombinedLoader:
        """Return the dataloader for the selected split.

        Returns:
            DataLoader or ``CombinedLoader`` corresponding to ``self.cond``.

        Raises:
            ValueError: if ``self.cond`` is not ``"train"``, ``"val"``, or ``"test"``.
        """
        if self.cond == 'train':
            return self.data_module.train_dataloader()  # type: ignore[return-value]
        if self.cond == 'val':
            return self.data_module.val_dataloader()
        if self.cond == 'test':
            return self.data_module.test_dataloader()
        raise ValueError(f'cond must be "train", "val", or "test", got {self.cond!r}')

    @staticmethod
    def verify_labeled_loader(
        loader: torch.utils.data.DataLoader | CombinedLoader,
    ) -> torch.utils.data.DataLoader:
        """Extract and return the labeled DataLoader from a potentially combined loader.

        Args:
            loader: either a plain ``DataLoader`` or a ``CombinedLoader`` containing
                labeled and unlabeled sub-loaders.

        Returns:
            The labeled ``DataLoader``.
        """
        if isinstance(loader, torch.utils.data.DataLoader):
            return loader
        # CombinedLoader wraps labeled + unlabeled; extract only the labeled one
        return loader.iterables['labeled']  # type: ignore[index]

    def iterate_over_dataloader(
        self, loader: torch.utils.data.DataLoader
    ) -> tuple[
        torch.Tensor,
        (
            Float[torch.Tensor, "num_examples 3 image_width image_height"]
            | Float[torch.Tensor, "num_examples frames 3 image_width image_height"]
            | None
        ),
    ]:
        """Iterate over a dataloader and collect keypoints (and optionally images).

        Args:
            loader: labeled dataloader to iterate over; each batch must provide a
                ``'keypoints'`` entry (and ``'images'`` if ``self.extract_images``).

        Returns:
            Tuple of:
                - concatenated keypoints tensor of shape ``(num_examples, num_targets)``.
                - concatenated image tensor or ``None`` if ``self.extract_images`` is
                  False.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        keypoints_list: list[torch.Tensor] = []
        images_list: list[torch.Tensor] = []
        for batch in loader:
            keypoints_list.append(batch['keypoints'])
            if self.extract_images:
                images_list.append(batch['images'])

        if not keypoints_list:
            # torch.cat on an empty list gives an unhelpful error; name the split
            raise ValueError(
                f'the {self.cond!r} dataloader yielded no batches; nothing to extract'
            )

        concat_keypoints = torch.cat(keypoints_list, dim=0)
        concat_images = torch.cat(images_list, dim=0) if self.extract_images else None

        # sanity check: the number of extracted rows must match the split's length
        expected_shape = (self.dataset_length, keypoints_list[0].shape[1])
        assert concat_keypoints.shape == expected_shape, (
            f'expected keypoints of shape {expected_shape}, '
            f'got {tuple(concat_keypoints.shape)}'
        )
        return concat_keypoints, concat_images

    def __call__(
        self,
    ) -> tuple[
        torch.Tensor,
        (
            Float[torch.Tensor, "num_examples 3 image_width image_height"]
            | Float[torch.Tensor, "num_examples frames 3 image_width image_height"]
            | None
        ),
    ]:
        """Extract all keypoints (and optionally images) from the selected data split.

        Returns:
            Tuple of:
                - concatenated keypoints tensor of shape ``(num_examples, num_targets)``.
                - concatenated image tensor or ``None`` if ``self.extract_images`` is
                  False.
        """
        loader = self.get_loader()
        loader = self.verify_labeled_loader(loader)
        return self.iterate_over_dataloader(loader)