Source code for lightning_pose.data.datamodules

"""Data modules split a dataset into train, val, and test modules."""

from __future__ import annotations

import copy
import os
from typing import TYPE_CHECKING, Literal

if TYPE_CHECKING:
    from lightning_pose.data.datasets import (
        BaseTrackingDataset,
        HeatmapDataset,
        MultiviewHeatmapDataset,
    )

import imgaug.augmenters as iaa
import lightning.pytorch as pl
import torch
from lightning.pytorch.utilities import CombinedLoader
from omegaconf import DictConfig, ListConfig
from torch.utils.data import DataLoader, Subset, random_split

from lightning_pose.data.dali import PrepareDALI
from lightning_pose.data.datatypes import SemiSupervisedDataLoaderDict
from lightning_pose.data.utils import (
    compute_num_train_frames,
    split_sizes_from_probabilities,
)
from lightning_pose.utils.io import check_video_paths

# to ignore imports for sphinx-autoapidoc: only the names listed here are
# exported as this module's public API
__all__ = [
    "BaseDataModule",
    "UnlabeledDataModule",
]


class BaseDataModule(pl.LightningDataModule):
    """Splits a labeled dataset into train, val, and test data loaders."""

    def __init__(
        self,
        dataset: BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset,
        train_batch_size: int = 16,
        val_batch_size: int = 16,
        test_batch_size: int = 1,
        num_workers: int | None = None,
        train_probability: float = 0.8,
        val_probability: float | None = None,
        test_probability: float | None = None,
        train_frames: float | int | None = None,
        torch_seed: int = 42,
    ) -> None:
        """Data module splits a dataset into train, val, and test data loaders.

        The split is performed immediately, at construction time.

        Args:
            dataset: base dataset to be split into train/val/test
            train_batch_size: number of samples in training batches
            val_batch_size: number of samples in validation batches
            test_batch_size: number of samples in test batches
            num_workers: value forwarded to ``DataLoader(num_workers=...)``; if None, the
                ``SLURM_CPUS_PER_TASK`` environment variable is used when it is set and
                non-empty, otherwise ``os.cpu_count()`` (or 0 if that returns None)
            train_probability: fraction of full dataset used for training
            val_probability: fraction of full dataset used for validation
            test_probability: fraction of full dataset used for testing
            train_frames: if integer, select this number of training frames from the
                initially selected train frames (defined by `train_probability`); if float,
                must be between 0 and 1 (exclusive) and defines the fraction of the
                initially selected train frames
            torch_seed: seed for the generators that control the data split and the
                shuffling of the training dataloader

        """
        super().__init__()
        self.dataset = dataset
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        if num_workers is not None:
            self.num_workers = num_workers
        else:
            slurm_cpus = os.getenv("SLURM_CPUS_PER_TASK")
            if slurm_cpus:
                self.num_workers = int(slurm_cpus)
            else:
                # Fallback to os.cpu_count(), which may itself return None
                self.num_workers = os.cpu_count() or 0
        self.train_probability = train_probability
        self.val_probability = val_probability
        self.test_probability = test_probability
        self.train_frames = train_frames
        # populated by `_setup`
        self.train_dataset: Subset | None = None
        self.val_dataset: Subset | None = None
        self.test_dataset: Subset | None = None
        self.torch_seed = torch_seed
        self._setup()

    @staticmethod
    def _apply_eval_transform(subset: Subset, transform: iaa.Sequential, split_name: str) -> None:
        """Install `transform` as the imgaug pipeline of the dataset wrapped by `subset`.

        Args:
            subset: subset whose underlying dataset gets the new pipeline
            transform: pipeline to assign to `imgaug_transform`
            split_name: label used as the prefix of the progress message ("val" or "test")

        """
        wrapped = subset.dataset
        wrapped.imgaug_transform = transform
        if hasattr(wrapped, "dataset"):
            # this will get triggered for multiview datasets, which hold one child dataset
            # per view; each child needs the same pipeline
            print(f"{split_name}: updating children datasets with resize imgaug pipeline")
            for _view_name, child in wrapped.dataset.items():
                child.imgaug_transform = transform

    def _setup(self) -> None:
        """Split the dataset into train, validation, and test subsets."""
        datalen = len(self.dataset)
        print(f"Number of labeled images in the full dataset (train+val+test): {datalen}")

        # split data based on provided probabilities
        data_splits_list = split_sizes_from_probabilities(
            datalen,
            train_probability=self.train_probability,
            val_probability=self.val_probability,
            test_probability=self.test_probability,
        )

        if len(self.dataset.imgaug_transform) == 1:  # type: ignore[arg-type]
            # no augmentations in the pipeline; subsets can share same underlying dataset
            self.train_dataset, self.val_dataset, self.test_dataset = random_split(
                self.dataset,
                data_splits_list,
                generator=torch.Generator().manual_seed(self.torch_seed),
            )
        else:
            # augmentations in the pipeline; we want validation and test datasets that only
            # resize. we can't simply change the imgaug pipeline in the datasets after
            # they've been split because the subsets actually point to the same underlying
            # dataset, so we create separate (deep-copied) datasets here
            train_idxs, val_idxs, test_idxs = random_split(
                range(len(self.dataset)),  # type: ignore[arg-type]
                data_splits_list,
                generator=torch.Generator().manual_seed(self.torch_seed),
            )
            self.train_dataset = Subset(
                copy.deepcopy(self.dataset),
                indices=list(train_idxs),  # type: ignore[arg-type]
            )
            val_subset = Subset(
                copy.deepcopy(self.dataset),
                indices=list(val_idxs),  # type: ignore[arg-type]
            )
            test_subset = Subset(
                copy.deepcopy(self.dataset),
                indices=list(test_idxs),  # type: ignore[arg-type]
            )
            self.val_dataset = val_subset
            self.test_dataset = test_subset

            # only use the final resize transform for the validation and test datasets
            last_aug = self.dataset.imgaug_transform[-1]  # type: ignore[index]
            if str(last_aug).startswith("Resize"):
                final_transform = iaa.Sequential([last_aug])
            else:
                # if we're here it's because the dataset is a MultiviewHeatmapDataset that
                # doesn't resize by default in the pipeline; we enforce resizing here on
                # val/test batches
                height = self.dataset.height
                width = self.dataset.width
                final_transform = iaa.Sequential(
                    [iaa.Resize({"height": height, "width": width})]
                )

            self._apply_eval_transform(val_subset, final_transform, "val")
            self._apply_eval_transform(test_subset, final_transform, "test")

        # further subsample training data if desired
        if self.train_frames is not None:
            n_frames = compute_num_train_frames(len(self.train_dataset), self.train_frames)
            if n_frames < len(self.train_dataset):
                # keep only the first `n_frames` of the train indices chosen by the split
                self.train_dataset.indices = self.train_dataset.indices[:n_frames]

        print(
            f"Dataset splits -- "
            f"train: {len(self.train_dataset)}, "
            f"val: {len(self.val_dataset)}, "
            f"test: {len(self.test_dataset)}"
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Return the training dataloader with shuffling enabled.

        The shuffle generator is re-seeded with `torch_seed` on every call.

        Returns:
            DataLoader wrapping the training subset.

        """
        return DataLoader(
            self.train_dataset,  # type: ignore[arg-type]
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            # persistent workers are only valid when there is at least one worker
            persistent_workers=self.num_workers > 0,
            shuffle=True,
            generator=torch.Generator().manual_seed(self.torch_seed),
        )

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Return the validation dataloader.

        Returns:
            DataLoader wrapping the validation subset.

        """
        return DataLoader(
            self.val_dataset,  # type: ignore[arg-type]
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            # persistent workers are only valid when there is at least one worker
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> torch.utils.data.DataLoader:
        """Return the test dataloader.

        Returns:
            DataLoader wrapping the test subset.

        """
        return DataLoader(
            self.test_dataset,  # type: ignore[arg-type]
            batch_size=self.test_batch_size,
            num_workers=self.num_workers,
        )

    def full_labeled_dataloader(self) -> torch.utils.data.DataLoader:
        """Return a dataloader covering the entire labeled dataset (all splits combined).

        Batches use `val_batch_size`.

        Returns:
            DataLoader over the full underlying dataset.

        """
        return DataLoader(
            self.dataset,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
        )
class UnlabeledDataModule(BaseDataModule):
    """Data module that contains labeled and unlabeled data loaders."""

    def __init__(
        self,
        dataset: BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset,
        video_paths_list: list[str] | str,
        dali_config: dict | DictConfig | ListConfig,
        view_names: list[str] | None = None,
        train_batch_size: int = 16,
        val_batch_size: int = 16,
        test_batch_size: int = 1,
        num_workers: int | None = None,
        train_probability: float = 0.8,
        val_probability: float | None = None,
        test_probability: float | None = None,
        train_frames: float | int | None = None,
        torch_seed: int = 42,
        imgaug: Literal["default", "dlc", "dlc-top-down"] = "default",
    ) -> None:
        """Data module that contains labeled and unlabeled data loaders.

        The labeled split and the unlabeled loader are both built at construction time.

        Args:
            dataset: pytorch Dataset for labeled data
            video_paths_list: absolute paths of videos ("unlabeled" data)
            dali_config: see `dali` entry of default config file for keys
            view_names: if fitting a non-mirrored multiview model, pass view names in order
                to correctly organize the video paths
            train_batch_size: number of samples in training batches
            val_batch_size: number of samples in validation batches
            test_batch_size: number of samples in test batches
            num_workers: number of workers used for loading labeled data
            train_probability: fraction of full dataset used for training
            val_probability: fraction of full dataset used for validation
            test_probability: fraction of full dataset used for testing
            train_frames: if integer, select this number of training frames from the
                initially selected train frames (defined by `train_probability`); if float,
                must be between 0 and 1 (exclusive) and defines the fraction of the
                initially selected train frames
            torch_seed: control data splits and randomness of labeled data loading
            imgaug: type of image augmentation to apply to unlabeled frames

        """
        super().__init__(
            dataset=dataset,
            train_batch_size=train_batch_size,
            val_batch_size=val_batch_size,
            test_batch_size=test_batch_size,
            num_workers=num_workers,
            train_probability=train_probability,
            val_probability=val_probability,
            test_probability=test_probability,
            train_frames=train_frames,
            torch_seed=torch_seed,
        )
        self.video_paths_list = video_paths_list
        self.filenames = check_video_paths(self.video_paths_list, view_names=view_names)
        self.num_workers_for_unlabeled = 1  # WARNING!! do not increase above 1, weird behavior
        self.dali_config = dali_config
        self.unlabeled_dataloader = None  # initialized in setup_unlabeled
        self.imgaug = imgaug
        self.setup_unlabeled()

    def setup_unlabeled(self) -> None:
        """Sets up the unlabeled data loader.

        Builds a `PrepareDALI` object from the stored filenames, DALI config and
        augmentation type, then calls it and stores the result in
        `self.unlabeled_dataloader`.

        """
        dali_prep = PrepareDALI(
            train_stage="train",
            # context models are selected by the labeled dataset's `do_context` flag
            model_type="context" if self.dataset.do_context else "base",
            filenames=self.filenames,
            resize_dims=[self.dataset.height, self.dataset.width],
            dali_config=self.dali_config,
            imgaug=self.imgaug,
            num_threads=self.num_workers_for_unlabeled,
        )
        self.unlabeled_dataloader = dali_prep()

    def train_dataloader(self) -> CombinedLoader:
        """Return a combined dataloader pairing labeled and unlabeled training data.

        Returns:
            ``CombinedLoader`` in ``max_size_cycle`` mode that cycles through labeled and
            unlabeled batches together.

        """
        # invariant: `setup_unlabeled` runs in `__init__`, so the loader must exist by now
        assert self.unlabeled_dataloader is not None
        loader = SemiSupervisedDataLoaderDict(
            labeled=super().train_dataloader(),
            unlabeled=self.unlabeled_dataloader,
        )
        # CombinedLoader mode="max_size_cycle" works in concert with
        # `trainer.limit_train_batches`. Assuming unlabeled data is plentiful,
        # it will cycle through labeled data until limit_train_batches.
        # limit_train_batches is set (outside this class) such that it exhausts all
        # labeled data in an epoch, or it cycles for a minimum of 10 batches.
        #
        # The reason to have a minimum number of batches is so that when labeled data is
        # scarce, the model sees more unlabeled data per epoch instead of just stopping
        # (empirically better).
        return CombinedLoader(loader, mode="max_size_cycle")