BaseDataModule

class lightning_pose.data.datamodules.BaseDataModule

Bases: LightningDataModule

Splits a labeled dataset into train, val, and test data loaders.

Methods Summary

full_labeled_dataloader()

Return a dataloader covering the entire labeled dataset (all splits combined).

test_dataloader()

Return the test dataloader.

train_dataloader()

Return the training dataloader with shuffling enabled.

val_dataloader()

Return the validation dataloader.

Methods Documentation

full_labeled_dataloader() → DataLoader

Return a dataloader covering the entire labeled dataset (all splits combined).

Returns:

DataLoader over the full underlying dataset.

test_dataloader() → DataLoader

Return the test dataloader.

Returns:

DataLoader wrapping the test subset.

train_dataloader() → DataLoader

Return the training dataloader with shuffling enabled.

Returns:

DataLoader wrapping the training subset.

val_dataloader() → DataLoader

Return the validation dataloader.

Returns:

DataLoader wrapping the validation subset.

__init__(dataset: BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset, train_batch_size: int = 16, val_batch_size: int = 16, test_batch_size: int = 1, num_workers: int | None = None, train_probability: float = 0.8, val_probability: float | None = None, test_probability: float | None = None, train_frames: float | int | None = None, torch_seed: int = 42) → None

Data module splits a dataset into train, val, and test data loaders.

Parameters:
  • dataset – base dataset to be split into train/val/test

  • train_batch_size – number of samples in training batches

  • val_batch_size – number of samples in validation batches

  • test_batch_size – number of samples in test batches

  • num_workers – number of worker subprocesses used to load and prefetch data

  • train_probability – fraction of full dataset used for training

  • val_probability – fraction of full dataset used for validation

  • test_probability – fraction of full dataset used for testing

  • train_frames – if an integer, the number of training frames to select from the initially selected train frames (defined by train_probability); if a float, it must lie strictly between 0 and 1 and gives the fraction of the initially selected train frames to keep

  • torch_seed – random seed that controls the data splits, so the same seed reproduces the same split
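The interplay of the split parameters above can be sketched in plain Python. This is a minimal illustration, not the library's implementation. The helper names (split_sizes, resolve_train_frames, seeded_split) are hypothetical. Three behaviors are assumptions rather than documented facts: dividing the remainder evenly between validation and test when both probabilities are None, capping an integer train_frames at the available count, and shuffling with Python's random module in place of a torch-seeded split.

```python
import random


def split_sizes(n_frames, train_probability=0.8,
                val_probability=None, test_probability=None):
    """Turn split fractions into integer split sizes (illustrative only)."""
    if val_probability is None and test_probability is None:
        # Assumption: share the remainder evenly between val and test.
        remainder = 1.0 - train_probability
        val_probability = test_probability = remainder / 2
    elif val_probability is None:
        val_probability = 1.0 - train_probability - test_probability
    elif test_probability is None:
        test_probability = 1.0 - train_probability - val_probability

    n_train = round(n_frames * train_probability)
    n_val = round(n_frames * val_probability)
    n_test = n_frames - n_train - n_val  # test takes whatever is left
    return n_train, n_val, n_test


def resolve_train_frames(n_train, train_frames=None):
    """Apply the documented train_frames rule to the initial train count."""
    if train_frames is None:
        return n_train
    if isinstance(train_frames, float):
        if not 0.0 < train_frames < 1.0:
            raise ValueError("float train_frames must be in (0, 1), exclusive")
        return int(n_train * train_frames)  # fraction of the train frames
    # Assumption: an integer request is capped at the available count.
    return min(int(train_frames), n_train)


def seeded_split(n_frames, sizes, seed=42):
    """Shuffle frame indices with a fixed seed, then cut them into splits."""
    indices = list(range(n_frames))
    random.Random(seed).shuffle(indices)  # same seed gives the same order
    n_train, n_val, _ = sizes
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test
```

For example, under these assumptions split_sizes(100) gives 80 training, 10 validation, and 10 test frames, and resolve_train_frames(80, 0.5) then keeps 40 of the 80 training frames.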

__new__(**kwargs)