DataExtractor

class lightning_pose.data.utils.DataExtractor(data_module: LightningDataModule, cond: Literal['train', 'test', 'val'] = 'train', extract_images: bool = False, remove_augmentations: bool = True)

Bases: object

Helper class to extract all data for one split of a data module. The cond argument selects the split: 'train', 'test', or 'val'.

Attributes Summary

dataset_length

Methods Summary

__call__()

Run the extraction by calling the instance directly.

get_loader()

iterate_over_dataloader(loader)

verify_labeled_loader(loader)

Attributes Documentation

dataset_length

Number of examples in the dataset for the split selected by cond.
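
A minimal sketch of how a split-dependent length could be computed. It assumes, hypothetically, that the data module stores one dataset per split under the attribute names train_dataset, val_dataset, and test_dataset. ToyDataModule and split_length are illustrative stand-ins, not part of lightning_pose.

```python
from typing import Literal


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule with one dataset per split."""

    def __init__(self) -> None:
        self.train_dataset = list(range(80))
        self.val_dataset = list(range(10))
        self.test_dataset = list(range(10))


def split_length(data_module: object, cond: Literal["train", "test", "val"]) -> int:
    # Assumption: the dataset for each split lives at the attribute "<cond>_dataset".
    return len(getattr(data_module, f"{cond}_dataset"))


print(split_length(ToyDataModule(), "train"))  # 80
print(split_length(ToyDataModule(), "val"))  # 10
```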

Methods Documentation

__call__() Tuple[Tensor[Tensor], Tensor[Tensor] | Tensor[Tensor] | None]

Run the extraction by calling the instance directly. The return type is the same tuple type as iterate_over_dataloader().
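
The sketch below shows one way a callable extractor could chain the three documented methods: select the loader, reduce it to the labeled loader, then iterate. It is a hypothetical stand-in (CallableExtractor is not a lightning_pose class): the call order is an assumption, plain lists replace torch tensors so it runs without torch, the "labeled" dict key is assumed, and the remove_augmentations and extract_images options are left out.

```python
from typing import List, Optional, Tuple

Rows = List[List[float]]


class CallableExtractor:
    """Hypothetical sketch: __call__ chains select -> unwrap -> iterate."""

    def __init__(self, loaders: dict, cond: str = "train") -> None:
        self.loaders = loaders  # split name -> loader (or dict of loaders)
        self.cond = cond

    def get_loader(self):
        # Pick the loader for the requested split.
        return self.loaders[self.cond]

    @staticmethod
    def verify_labeled_loader(loader):
        # Assumption: a semi-supervised loader is a dict with a "labeled" entry.
        return loader["labeled"] if isinstance(loader, dict) else loader

    def iterate_over_dataloader(self, loader) -> Tuple[Rows, Optional[Rows]]:
        # Flatten the keypoint rows of every batch into one list.
        keypoints = [row for batch in loader for row in batch["keypoints"]]
        return keypoints, None  # this sketch never collects images

    def __call__(self) -> Tuple[Rows, Optional[Rows]]:
        loader = self.verify_labeled_loader(self.get_loader())
        return self.iterate_over_dataloader(loader)


batches = [{"keypoints": [[0.0, 1.0]]}, {"keypoints": [[2.0, 3.0]]}]
extractor = CallableExtractor({"train": {"labeled": batches}})
keypoints, images = extractor()
print(keypoints)  # [[0.0, 1.0], [2.0, 3.0]]
print(images)  # None
```

Making the instance callable keeps the common case to a single expression, while the individual steps stay available for callers that need only part of the pipeline.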

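get_loader() DataLoader | SemiSupervisedDataLoaderDict

Return the dataloader for the split selected by cond, either a DataLoader or a SemiSupervisedDataLoaderDict.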
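
A minimal sketch of dispatching on a split name, assuming the loader comes from the standard LightningDataModule hooks train_dataloader(), val_dataloader(), and test_dataloader(). ToyDataModule and get_loader here are hypothetical stand-ins that return strings instead of real loaders so the example runs without Lightning installed.

```python
from typing import Literal


class ToyDataModule:
    """Hypothetical stand-in exposing the standard Lightning dataloader hooks."""

    def train_dataloader(self):
        return "train loader"

    def val_dataloader(self):
        return "val loader"

    def test_dataloader(self):
        return "test loader"


def get_loader(data_module, cond: Literal["train", "test", "val"] = "train"):
    # Map each split name to the matching hook, e.g. "val" -> val_dataloader().
    hooks = {
        "train": data_module.train_dataloader,
        "val": data_module.val_dataloader,
        "test": data_module.test_dataloader,
    }
    if cond not in hooks:
        raise ValueError(f"cond must be 'train', 'val', or 'test', got {cond!r}")
    return hooks[cond]()


print(get_loader(ToyDataModule(), "val"))  # val loader
```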
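
iterate_over_dataloader(loader: DataLoader) Tuple[Tensor[Tensor], Tensor[Tensor] | Tensor[Tensor] | None]

Iterate over every batch in loader and return the collected data as a tuple. The second element may be None.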
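
A minimal sketch of collecting data across batches, assuming each batch is a dict with a "keypoints" entry and, optionally, an "images" entry. With real tensors this would typically be torch.cat(..., dim=0) over the per-batch tensors. Here plain lists stand in so the sketch runs without torch. iterate_over_batches is a hypothetical name.

```python
from typing import Dict, Iterable, List, Optional, Tuple

Rows = List[List[float]]


def iterate_over_batches(
    loader: Iterable[Dict[str, Rows]], extract_images: bool = False
) -> Tuple[Rows, Optional[Rows]]:
    """Hypothetical sketch: concatenate per-batch rows along the example axis."""
    keypoints: Rows = []
    images: Rows = []
    for batch in loader:
        keypoints.extend(batch["keypoints"])
        if extract_images:
            images.extend(batch["images"])
    # Every example should keep the same number of keypoint columns.
    assert len({len(row) for row in keypoints}) <= 1, "ragged keypoint rows"
    # Images are only returned when requested; otherwise the slot is None.
    return keypoints, (images if extract_images else None)


loader = [
    {"keypoints": [[0.0, 1.0], [2.0, 3.0]], "images": [[0.1], [0.2]]},
    {"keypoints": [[4.0, 5.0]], "images": [[0.3]]},
]
kps, imgs = iterate_over_batches(loader)
print(len(kps), imgs)  # 3 None
kps, imgs = iterate_over_batches(loader, extract_images=True)
print(imgs)  # [[0.1], [0.2], [0.3]]
```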
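
static verify_labeled_loader(loader: DataLoader | SemiSupervisedDataLoaderDict) DataLoader

Accept either a DataLoader or a SemiSupervisedDataLoaderDict and return a DataLoader for the labeled data.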
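
A minimal sketch of reducing either input type to a single loader. It assumes, hypothetically, that a semi-supervised loader is a dict whose labeled loader sits under the key "labeled". A plain loader is passed through unchanged. This is an illustrative stand-in, not the library's implementation.

```python
from typing import Any, Iterable


def verify_labeled_loader(loader: Any) -> Iterable:
    """Hypothetical sketch: reduce a loader or a dict of loaders to the labeled loader."""
    if isinstance(loader, dict):
        # Assumption: semi-supervised setups key the labeled loader as "labeled".
        if "labeled" not in loader:
            raise KeyError("expected a 'labeled' entry in the loader dict")
        return loader["labeled"]
    return loader  # already a single labeled loader


plain = [{"keypoints": [[0.0, 1.0]]}]
print(verify_labeled_loader(plain) is plain)  # True
print(verify_labeled_loader({"labeled": plain, "unlabeled": []}) is plain)  # True
```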