BaseDataModule

class lightning_pose.data.datamodules.BaseDataModule(dataset: Dataset, train_batch_size: int = 16, val_batch_size: int = 16, test_batch_size: int = 1, num_workers: int = 8, train_probability: float = 0.8, val_probability: float | None = None, test_probability: float | None = None, train_frames: float | int | None = None, torch_seed: int = 42)

Bases: LightningDataModule

Splits a labeled dataset into train, val, and test data loaders.
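The split can be pictured as a seeded shuffle of frame indices followed by three consecutive cuts. The plain-Python sketch below illustrates that idea under stated assumptions: `split_indices` is a hypothetical helper, `random.Random` stands in for the torch generator that `torch_seed` presumably seeds, and the split sizes are passed in directly. It is not lightning-pose's implementation.

```python
import random
from typing import List, Tuple


def split_indices(
    n: int, sizes: Tuple[int, int, int], seed: int = 42
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: shuffle range(n) with a fixed seed and cut it
    into consecutive train/val/test chunks of the requested sizes."""
    n_train, n_val, n_test = sizes
    if n_train + n_val + n_test != n:
        raise ValueError("split sizes must add up to the dataset length")
    order = list(range(n))
    random.Random(seed).shuffle(order)  # seeded, so the order is reproducible
    train = order[:n_train]
    val = order[n_train:n_train + n_val]
    test = order[n_train + n_val:]
    return train, val, test


if __name__ == "__main__":
    train, val, test = split_indices(10, (8, 1, 1), seed=42)
    print(len(train), len(val), len(test))  # 8 1 1
```

Because the shuffle is seeded, two calls with the same seed return identical, disjoint splits; this is the kind of reproducibility that torch_seed is meant to provide.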

Methods Summary

full_labeled_dataloader()

Return a data loader over the full labeled dataset, without the train/val/test split.

setup([stage])

Called at the beginning of fit (train + validate), validate, test, or predict.

test_dataloader()

An iterable or collection of iterables specifying test samples.

train_dataloader()

An iterable or collection of iterables specifying training samples.

val_dataloader()

An iterable or collection of iterables specifying validation samples.

Methods Documentation

full_labeled_dataloader() → DataLoader

Return a data loader over the full labeled dataset, without the train/val/test split.

setup(stage: str | None = None) → None

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

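import torch.nn as nn
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None  # built in setup(), once the data is known

    def prepare_data(self):
        # download_data() and tokenize() are placeholders for your own code;
        # this hook runs on one process per node by default, so only write to disk here
        download_data()
        tokenize()

        # don't do this: attributes set here are not visible on other processes
        # self.something = ...

    def setup(self, stage):
        data = load_data(...)  # placeholder loader; setup() runs on every process
        self.l1 = nn.Linear(28, data.num_classes)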

test_dataloader() → DataLoader

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.
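The prepare_data()/setup() division can be illustrated without Lightning. In the sketch below, `ToyDataModule`, its file layout, and the unshuffled 80/20 split are hypothetical; it only mirrors the pattern: prepare_data() writes to disk and assigns no attributes, while setup() reads the data back and stores the split on the instance. In distributed training Lightning runs prepare_data() on one process per node by default, whereas setup() runs on every process, which is why state belongs in setup().

```python
import json
import os
import tempfile
from typing import List, Optional


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule (no Lightning import)."""

    def __init__(self, data_dir: str, train_fraction: float = 0.8) -> None:
        self.data_dir = data_dir
        self.train_fraction = train_fraction
        self.train: Optional[List[int]] = None
        self.val: Optional[List[int]] = None

    @property
    def path(self) -> str:
        return os.path.join(self.data_dir, "samples.json")

    def prepare_data(self) -> None:
        # "Download" step: write raw data to disk only; assign no attributes.
        if not os.path.exists(self.path):
            with open(self.path, "w") as f:
                json.dump(list(range(10)), f)

    def setup(self, stage: Optional[str] = None) -> None:
        # Process and split: read from disk and store the result as state.
        with open(self.path) as f:
            samples = json.load(f)
        n_train = int(self.train_fraction * len(samples))
        self.train, self.val = samples[:n_train], samples[n_train:]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        dm = ToyDataModule(d)
        dm.prepare_data()
        dm.setup("fit")
        print(len(dm.train), len(dm.val))  # 8 2
```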

Warning

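Do not assign state in prepare_data(): it runs on only one process per node by default, so attributes set there are not available on the other processes.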

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
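Under DDP, the sampler Lightning adds is typically PyTorch's DistributedSampler. The plain-Python sketch below emulates how that sampler assigns indices to one process when shuffle=False and drop_last=False; `shard_indices` is a hypothetical name and this is a simplified restatement, not code from Lightning or PyTorch. With shuffling enabled, the real sampler first permutes the indices using a seed combined with the current epoch.

```python
import math
from typing import List


def shard_indices(n: int, num_replicas: int, rank: int) -> List[int]:
    """Emulate DistributedSampler(shuffle=False, drop_last=False) for one rank."""
    num_samples = math.ceil(n / num_replicas)  # samples per process
    total_size = num_samples * num_replicas
    indices = list(range(n))
    # Pad by wrapping around so every rank gets the same number of samples.
    padding = total_size - n
    if padding > 0:
        indices += (indices * math.ceil(padding / n))[:padding]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    for rank in range(3):
        print(rank, shard_indices(10, num_replicas=3, rank=rank))
    # 0 [0, 3, 6, 9]
    # 1 [1, 4, 7, 0]
    # 2 [2, 5, 8, 1]
```

Because of the padding, a few samples are seen twice per epoch when the dataset size is not divisible by the number of processes, which can slightly bias validation and test metrics.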

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

train_dataloader() → DataLoader

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs argument to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

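Do not assign state in prepare_data(): it runs on only one process per node by default, so attributes set there are not available on the other processes.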

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader() → DataLoader

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs argument to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). Related methods:

  • Trainer.fit()

  • Trainer.validate()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

__init__(dataset: Dataset, train_batch_size: int = 16, val_batch_size: int = 16, test_batch_size: int = 1, num_workers: int = 8, train_probability: float = 0.8, val_probability: float | None = None, test_probability: float | None = None, train_frames: float | int | None = None, torch_seed: int = 42) → None

Data module splits a dataset into train, val, and test data loaders.

Parameters:
  • dataset – base dataset to be split into train/val/test

  • train_batch_size – number of samples in training batches

  • val_batch_size – number of samples in validation batches

  • test_batch_size – number of samples in test batches

  • num_workers – number of worker processes the data loaders use to load and prefetch data

  • train_probability – fraction of full dataset used for training

  • val_probability – fraction of full dataset used for validation

  • test_probability – fraction of full dataset used for testing

  • train_frames – if an integer, select this many training frames from the initially selected train frames (defined by train_probability); if a float, it must lie strictly between 0 and 1 and gives the fraction of the initially selected train frames to keep

  • torch_seed – random seed that controls the data splits
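The interplay between the probability arguments and train_frames can be sketched as follows. These are illustrative helpers, not lightning-pose's code: the names `split_sizes` and `subsample_train`, the rounding rule (floor for val and test, remainder to train), the capping of an integer train_frames at the split size, and the requirement that all three probabilities be given explicitly are assumptions. How the real class fills in a val_probability or test_probability of None is not covered here.

```python
import math
import random
from typing import List, Tuple, Union


def split_sizes(n: int, train_p: float, val_p: float, test_p: float) -> Tuple[int, int, int]:
    """Hypothetical: turn split fractions into integer sizes that sum to n.
    Assumes all three fractions are given; any rounding remainder goes to train."""
    if not math.isclose(train_p + val_p + test_p, 1.0):
        raise ValueError("probabilities must sum to 1")
    n_val = math.floor(val_p * n)
    n_test = math.floor(test_p * n)
    return n - n_val - n_test, n_val, n_test


def subsample_train(
    train_idx: List[int], train_frames: Union[int, float, None], seed: int = 42
) -> List[int]:
    """Hypothetical: apply the documented train_frames rule to a train split."""
    if train_frames is None:
        return list(train_idx)  # keep the whole train split
    if isinstance(train_frames, float):
        if not 0.0 < train_frames < 1.0:
            raise ValueError("float train_frames must lie strictly between 0 and 1")
        k = int(train_frames * len(train_idx))  # fraction of the train split
    else:
        k = min(train_frames, len(train_idx))  # assumption: cap at the split size
    return random.Random(seed).sample(train_idx, k)


if __name__ == "__main__":
    print(split_sizes(100, 0.8, 0.1, 0.1))  # (80, 10, 10)
    print(len(subsample_train(list(range(80)), 0.25)))  # 20
```

Drawing train_frames from within the train split, rather than from the whole dataset, keeps the validation and test sets unchanged as the training set shrinks, which matches the parameter description above.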