BaseDataModule
- class lightning_pose.data.datamodules.BaseDataModule
Bases: LightningDataModule
Splits a labeled dataset into train, val, and test data loaders.
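To show roughly what "splitting" means here, the sketch below partitions dataset indices into three disjoint, reproducible subsets using only the standard library. `split_indices` is a hypothetical helper, not part of lightning_pose. It assumes the test split receives whatever remains after train and val. The class itself may do this differently, for example with `torch.utils.data.random_split` and a seeded generator.

```python
import random


def split_indices(n, train_prob=0.8, val_prob=0.1, seed=42):
    """Hypothetical helper: partition range(n) into train/val/test index lists.

    Assumption: the test split gets whatever remains after train and val, so
    the three lists are disjoint and together cover every index exactly once.
    """
    if not 0.0 < train_prob <= 1.0 or val_prob < 0.0 or train_prob + val_prob > 1.0:
        raise ValueError("fractions must be non-negative and sum to at most 1")
    n_train = round(train_prob * n)
    n_val = min(round(val_prob * n), n - n_train)  # guard against rounding overshoot
    # One seeded RNG makes the split reproducible, analogous to torch_seed.
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


train, val, test = split_indices(100, train_prob=0.8, val_prob=0.1, seed=42)
print(len(train), len(val), len(test))  # 80 10 10
```

Calling the helper again with the same seed returns the identical split. This is the property a fixed seed is meant to guarantee.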
Methods Summary
full_labeled_dataloader() – Return a dataloader covering the entire labeled dataset (all splits combined).
test_dataloader() – Return the test dataloader.
train_dataloader() – Return the training dataloader with shuffling enabled.
val_dataloader() – Return the validation dataloader.
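The four loaders differ mainly in which indices they cover and whether they shuffle. The standard-library sketch below mimics that behavior. `iter_batches` is a hypothetical stand-in for `torch.utils.data.DataLoader`. Keeping the final partial batch is an assumption borrowed from PyTorch's default `drop_last=False`; it is not documented behavior of this class.

```python
import random


def iter_batches(indices, batch_size, shuffle=False, rng=None):
    """Hypothetical DataLoader stand-in that yields batches of dataset indices.

    With shuffle=True the order is redrawn on every call (i.e. every epoch),
    as for the training loader; otherwise the stored order is preserved.
    The final partial batch is kept (assumed, like PyTorch's drop_last=False).
    """
    order = list(indices)
    if shuffle:
        (rng if rng is not None else random.Random()).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


train_idx, val_idx, test_idx = list(range(8)), [8, 9], [10, 11]
rng = random.Random(0)

train_batches = list(iter_batches(train_idx, batch_size=3, shuffle=True, rng=rng))
val_batches = list(iter_batches(val_idx, batch_size=16))   # fixed order
test_batches = list(iter_batches(test_idx, batch_size=1))  # default test_batch_size
full_batches = list(iter_batches(train_idx + val_idx + test_idx, batch_size=16))  # all splits combined

print([len(b) for b in train_batches])  # [3, 3, 2]
print(test_batches)                      # [[10], [11]]
```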
Methods Documentation
- full_labeled_dataloader() → DataLoader
Return a dataloader covering the entire labeled dataset (all splits combined).
- Returns:
DataLoader over the full underlying dataset.
- test_dataloader() → DataLoader
Return the test dataloader.
- Returns:
DataLoader wrapping the test subset.
- train_dataloader() → DataLoader
Return the training dataloader with shuffling enabled.
- Returns:
DataLoader wrapping the training subset.
- val_dataloader() → DataLoader
Return the validation dataloader.
- Returns:
DataLoader wrapping the validation subset.
- __init__(dataset: BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset, train_batch_size: int = 16, val_batch_size: int = 16, test_batch_size: int = 1, num_workers: int | None = None, train_probability: float = 0.8, val_probability: float | None = None, test_probability: float | None = None, train_frames: float | int | None = None, torch_seed: int = 42) → None
Data module that splits a dataset into train, val, and test data loaders.
- Parameters:
dataset – base dataset to be split into train/val/test
train_batch_size – number of samples in training batches
val_batch_size – number of samples in validation batches
test_batch_size – number of samples in test batches
num_workers – number of worker processes used to load and prefetch data
train_probability – fraction of the full dataset used for training
val_probability – fraction of the full dataset used for validation
test_probability – fraction of the full dataset used for testing
train_frames – if integer, select this number of training frames from the initially selected train frames (defined by train_probability); if float, must be between 0 and 1 (exclusive) and defines the fraction of the initially selected train frames
torch_seed – random seed that controls the data splits
- __new__(**kwargs)
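The `train_frames` rule described for `__init__` can be sketched as follows. `subsample_train` is a hypothetical helper. Two of its behaviors are assumptions of this sketch, not documented behavior of the class: it rejects out-of-range values instead of clipping them, and it draws the subset with a seeded `random.Random`.

```python
import random


def subsample_train(train_indices, train_frames=None, seed=42):
    """Hypothetical helper: shrink the initial train split per train_frames.

    None  -> keep every initially selected training index
    int   -> keep exactly that many indices
    float -> keep that fraction (strictly between 0 and 1) of the indices
    """
    indices = list(train_indices)
    n = len(indices)
    if train_frames is None:
        return indices
    if isinstance(train_frames, int):
        # Assumption: counts outside [1, n] are rejected rather than clipped.
        if not 1 <= train_frames <= n:
            raise ValueError(f"integer train_frames must be in [1, {n}]")
        n_keep = train_frames
    elif isinstance(train_frames, float):
        if not 0.0 < train_frames < 1.0:
            raise ValueError("float train_frames must be strictly between 0 and 1")
        n_keep = max(1, int(train_frames * n))  # keep at least one frame
    else:
        raise TypeError("train_frames must be None, an int, or a float")
    # Assumption: a dedicated seeded RNG makes the subset reproducible.
    return random.Random(seed).sample(indices, n_keep)


initial_train = list(range(80))  # e.g. 80 frames chosen via train_probability=0.8 of 100
print(len(subsample_train(initial_train, 20)))    # 20
print(len(subsample_train(initial_train, 0.25)))  # 20
```

Both calls keep 20 of the 80 initial training frames. The integer form fixes an absolute count, and the float form scales with the size of the initial split.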