Source code for lightning_pose.data.datamodules

"""Data modules split a dataset into train, val, and test modules."""

import copy
from typing import Literal

import imgaug.augmenters as iaa
import lightning.pytorch as pl
import torch
from lightning.pytorch.utilities import CombinedLoader
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Subset, random_split

from lightning_pose.data.dali import PrepareDALI
from lightning_pose.data.utils import (
    SemiSupervisedDataLoaderDict,
    compute_num_train_frames,
    split_sizes_from_probabilities,
)
from lightning_pose.utils.io import check_video_paths

# to ignore imports for sphinx-autoapidoc
__all__ = [
    "BaseDataModule",
    "UnlabeledDataModule",
]


class BaseDataModule(pl.LightningDataModule):
    """Splits a labeled dataset into train, val, and test data loaders."""

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        train_batch_size: int = 16,
        val_batch_size: int = 16,
        test_batch_size: int = 1,
        num_workers: int = 8,
        train_probability: float = 0.8,
        val_probability: float | None = None,
        test_probability: float | None = None,
        train_frames: float | int | None = None,
        torch_seed: int = 42,
    ) -> None:
        """Data module splits a dataset into train, val, and test data loaders.

        Args:
            dataset: base dataset to be split into train/val/test
            train_batch_size: number of samples of training batches
            val_batch_size: number of samples in validation batches
            test_batch_size: number of samples in test batches
            num_workers: number of workers passed to each labeled DataLoader
            train_probability: fraction of full dataset used for training
            val_probability: fraction of full dataset used for validation
            test_probability: fraction of full dataset used for testing
            train_frames: if integer, select this number of training frames
                from the initially selected train frames (defined by
                `train_probability`); if float, must be between 0 and 1
                (exclusive) and defines the fraction of the initially selected
                train frames
            torch_seed: control data splits

        """
        super().__init__()
        self.dataset = dataset
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.train_probability = train_probability
        self.val_probability = val_probability
        self.test_probability = test_probability
        self.train_frames = train_frames
        self.train_dataset = None  # populated by self.setup()
        self.val_dataset = None  # populated by self.setup()
        self.test_dataset = None  # populated by self.setup()
        self.torch_seed = torch_seed

    def setup(self, stage: str | None = None) -> None:  # stage arg needed for ptl
        """Split `self.dataset` and populate the train/val/test datasets.

        The split is drawn from a generator seeded with `self.torch_seed`. If
        `self.train_frames` is set, the train subset is further truncated to
        the number of frames returned by `compute_num_train_frames`.

        Args:
            stage: unused; accepted because Lightning passes it to `setup`

        """
        datalen = len(self.dataset)
        print(f"Number of labeled images in the full dataset (train+val+test): {datalen}")

        # split data based on provided probabilities
        data_splits_list = split_sizes_from_probabilities(
            datalen,
            train_probability=self.train_probability,
            val_probability=self.val_probability,
            test_probability=self.test_probability,
        )

        # `<= 1` rather than `== 1`: an empty pipeline has no augmentations either, and the
        # else-branch would fail on `imgaug_transform[-1]` for it
        if len(self.dataset.imgaug_transform) <= 1:
            # no augmentations in the pipeline; subsets can share same underlying dataset
            self.train_dataset, self.val_dataset, self.test_dataset = random_split(
                self.dataset,
                data_splits_list,
                generator=torch.Generator().manual_seed(self.torch_seed),
            )
        else:
            # augmentations in the pipeline; we want validation and test datasets that only resize
            # we can't simply change the imgaug pipeline in the datasets after they've been split
            # because the subsets actually point to the same underlying dataset, so we create
            # separate datasets here
            train_idxs, val_idxs, test_idxs = random_split(
                range(len(self.dataset)),
                data_splits_list,
                generator=torch.Generator().manual_seed(self.torch_seed),
            )

            self.train_dataset = Subset(copy.deepcopy(self.dataset), indices=list(train_idxs))
            self.val_dataset = Subset(copy.deepcopy(self.dataset), indices=list(val_idxs))
            self.test_dataset = Subset(copy.deepcopy(self.dataset), indices=list(test_idxs))

            # only use the final resize transform for the validation and test datasets
            resize_transform = iaa.Sequential([self.dataset.imgaug_transform[-1]])
            self.val_dataset.dataset.imgaug_transform = resize_transform
            self.test_dataset.dataset.imgaug_transform = resize_transform

        # further subsample training data if desired
        if self.train_frames is not None:
            n_frames = compute_num_train_frames(len(self.train_dataset), self.train_frames)
            if n_frames < len(self.train_dataset):
                # split the data a second time to reflect further subsampling from
                # train_frames
                self.train_dataset.indices = self.train_dataset.indices[:n_frames]

        print(
            f"Dataset splits -- "
            f"train: {len(self.train_dataset)}, "
            f"val: {len(self.val_dataset)}, "
            f"test: {len(self.test_dataset)}"
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Return a shuffled loader over the train split, seeded with `torch_seed`."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            # persistent workers are only requested when workers are actually used
            persistent_workers=self.num_workers > 0,
            shuffle=True,
            generator=torch.Generator().manual_seed(self.torch_seed),
        )

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Return an unshuffled loader over the validation split."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            # persistent workers are only requested when workers are actually used
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> torch.utils.data.DataLoader:
        """Return an unshuffled loader over the test split."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            num_workers=self.num_workers,
        )

    def full_labeled_dataloader(self) -> torch.utils.data.DataLoader:
        """Return an unshuffled loader over the full, unsplit labeled dataset.

        Batches use `val_batch_size`.

        """
        return DataLoader(
            self.dataset,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
        )
class UnlabeledDataModule(BaseDataModule):
    """Data module that contains labeled and unlabeled data loaders."""

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        video_paths_list: list[str] | str,
        dali_config: dict | DictConfig,
        view_names: list[str] | None = None,
        train_batch_size: int = 16,
        val_batch_size: int = 16,
        test_batch_size: int = 1,
        num_workers: int = 8,
        train_probability: float = 0.8,
        val_probability: float | None = None,
        test_probability: float | None = None,
        train_frames: float | int | None = None,
        torch_seed: int = 42,
        imgaug: Literal["default", "dlc", "dlc-top-down"] = "default",
    ) -> None:
        """Data module that contains labeled and unlabeled data loaders.

        Both the labeled split and the unlabeled loader are built here, at
        construction time.

        Args:
            dataset: pytorch Dataset for labeled data
            video_paths_list: absolute paths of videos ("unlabeled" data)
            dali_config: see `dali` entry of default config file for keys
            view_names: if fitting a non-mirrored multiview model, pass view
                names in order to correctly organize the video paths
            train_batch_size: number of samples of training batches
            val_batch_size: number of samples in validation batches
            test_batch_size: number of samples in test batches
            num_workers: number of workers for the labeled data loaders (the
                unlabeled loader always uses a single thread)
            train_probability: fraction of full dataset used for training
            val_probability: fraction of full dataset used for validation
            test_probability: fraction of full dataset used for testing
            train_frames: if integer, select this number of training frames
                from the initially selected train frames (defined by
                `train_probability`); if float, must be between 0 and 1
                (exclusive) and defines the fraction of the initially selected
                train frames
            torch_seed: control data splits and randomness of labeled data
                loading
            imgaug: type of image augmentation to apply to unlabeled frames

        """
        super().__init__(
            dataset=dataset,
            train_batch_size=train_batch_size,
            val_batch_size=val_batch_size,
            test_batch_size=test_batch_size,
            num_workers=num_workers,
            train_probability=train_probability,
            val_probability=val_probability,
            test_probability=test_probability,
            train_frames=train_frames,
            torch_seed=torch_seed,
        )
        self.video_paths_list = video_paths_list
        self.filenames = check_video_paths(self.video_paths_list, view_names=view_names)
        self.num_workers_for_unlabeled = 1  # WARNING!! do not increase above 1, weird behavior
        self.dali_config = dali_config
        self.unlabeled_dataloader = None  # initialized in setup_unlabeled
        self.imgaug = imgaug
        # TODO: Should these belong in a setup method that called by lightning,
        # rather than __init__? BaseDataModule already follows that pattern.
        super().setup()
        self.setup_unlabeled()

    def setup_unlabeled(self) -> None:
        """Sets up the unlabeled data loader."""
        dali_prep = PrepareDALI(
            train_stage="train",
            model_type="context" if self.dataset.do_context else "base",
            filenames=self.filenames,
            resize_dims=[self.dataset.height, self.dataset.width],
            dali_config=self.dali_config,
            imgaug=self.imgaug,
            num_threads=self.num_workers_for_unlabeled,
        )

        self.unlabeled_dataloader = dali_prep()

    def train_dataloader(self) -> CombinedLoader:
        """Return a combined loader over the labeled train split and the unlabeled data.

        The two loaders are wrapped in a `SemiSupervisedDataLoaderDict` and
        combined with `mode="max_size_cycle"`.

        """
        loader = SemiSupervisedDataLoaderDict(
            labeled=super().train_dataloader(),
            unlabeled=self.unlabeled_dataloader,
        )
        return CombinedLoader(loader, mode="max_size_cycle")