DataExtractor
- class lightning_pose.data.extractor.DataExtractor
Bases: object
Helper class to extract all data from a data module.
Attributes Summary
dataset_length  Number of examples in the selected data split.
Methods Summary
__call__()  Extract all keypoints (and optionally images) from the selected data split.
get_loader()  Return the dataloader for the selected split.
iterate_over_dataloader(loader)  Iterate over a dataloader and collect keypoints (and optionally images).
verify_labeled_loader(loader)  Extract and return the labeled DataLoader from a potentially combined loader.
Attributes Documentation
- dataset_length
Number of examples in the selected data split.
- Returns:
Length of the train, val, or test dataset, depending on self.cond.
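As a rough illustration of how the length can follow `self.cond`, the sketch below looks up the split by name. It is a pure-Python stand-in, not the library's implementation: `ToyDataModule` is hypothetical, and the `train_dataset` / `val_dataset` / `test_dataset` attribute names are an assumption about how the data module stores its splits.

```python
class ToyDataModule:
    """Hypothetical stand-in: each split is a plain list of examples."""

    def __init__(self):
        self.train_dataset = list(range(80))
        self.val_dataset = list(range(10))
        self.test_dataset = list(range(10))


def dataset_length(data_module, cond: str) -> int:
    """Return the number of examples in the split named by ``cond``."""
    if cond not in ("train", "val", "test"):
        raise ValueError(f"cond must be 'train', 'val', or 'test', got {cond!r}")
    # Assumed attribute naming convention: "<cond>_dataset".
    return len(getattr(data_module, f"{cond}_dataset"))


dm = ToyDataModule()
print(dataset_length(dm, "train"))  # 80
print(dataset_length(dm, "val"))    # 10
```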
Methods Documentation
- __call__() → tuple[Tensor, Float[Tensor, 'num_examples 3 image_width image_height'] | Float[Tensor, 'num_examples frames 3 image_width image_height'] | None]
Extract all keypoints (and optionally images) from the selected data split.
- Returns:
A pair: the concatenated keypoints tensor of shape (num_examples, num_targets), and the concatenated image tensor (None if self.extract_images is False).
- Return type:
Tuple of (Tensor, Tensor or None).
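Judging from the method summary, `__call__` plausibly chains the other methods: fetch the loader for the selected split, unwrap the labeled sub-loader, then iterate. The sketch below shows only that ordering; `run_extraction` is a hypothetical helper, and the lambdas are toy stand-ins in which a combined loader is modeled as a dict and each batch is a list of keypoint rows.

```python
from typing import Any, Callable, Optional, Tuple


def run_extraction(
    get_loader: Callable[[], Any],
    verify_labeled_loader: Callable[[Any], Any],
    iterate_over_dataloader: Callable[[Any], Tuple[list, Optional[list]]],
) -> Tuple[list, Optional[list]]:
    """Hypothetical sketch of the order in which the documented steps run."""
    loader = get_loader()                     # 1. loader for the split in cond
    labeled = verify_labeled_loader(loader)   # 2. keep only the labeled sub-loader
    return iterate_over_dataloader(labeled)   # 3. concatenate keypoints (and images)


# Toy stand-ins for each step; a "combined" loader is modeled as a dict.
combined = {"labeled": [[[0.0, 1.0]], [[2.0, 3.0]]], "unlabeled": []}
keypoints, images = run_extraction(
    get_loader=lambda: combined,
    verify_labeled_loader=lambda ld: ld["labeled"] if isinstance(ld, dict) else ld,
    iterate_over_dataloader=lambda ld: ([row for batch in ld for row in batch], None),
)
print(keypoints)  # [[0.0, 1.0], [2.0, 3.0]]
print(images)     # None
```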
- get_loader() → DataLoader | CombinedLoader
Return the dataloader for the selected split.
- Returns:
DataLoader or CombinedLoader corresponding to self.cond.
- Raises:
ValueError – if self.cond is not "train", "val", or "test".
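The split dispatch, including the `ValueError`, can be sketched as follows. This assumes the data module follows Lightning's `train_dataloader()` / `val_dataloader()` / `test_dataloader()` convention; `ToyDataModule` and its string return values are hypothetical placeholders for real loaders.

```python
class ToyDataModule:
    """Hypothetical stand-in following Lightning's *_dataloader() naming."""

    def train_dataloader(self):
        return "train-loader"

    def val_dataloader(self):
        return "val-loader"

    def test_dataloader(self):
        return "test-loader"


def get_loader(data_module, cond: str):
    """Return the loader for ``cond`` or raise ValueError for unknown splits."""
    dispatch = {
        "train": data_module.train_dataloader,
        "val": data_module.val_dataloader,
        "test": data_module.test_dataloader,
    }
    if cond not in dispatch:
        raise ValueError(f"cond must be 'train', 'val', or 'test', got {cond!r}")
    return dispatch[cond]()  # call only the selected method


dm = ToyDataModule()
print(get_loader(dm, "val"))  # val-loader
```

Mapping names to bound methods, rather than to already-built loaders, means only the selected split's loader is constructed.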
- iterate_over_dataloader(loader: DataLoader) → tuple[Tensor, Float[Tensor, 'num_examples 3 image_width image_height'] | Float[Tensor, 'num_examples frames 3 image_width image_height'] | None]
Iterate over a dataloader and collect keypoints (and optionally images).
- Parameters:
loader – labeled dataloader to iterate over.
- Returns:
A pair: the concatenated keypoints tensor of shape (num_examples, num_targets), and the concatenated image tensor (None if self.extract_images is False).
- Return type:
Tuple of (Tensor, Tensor or None).
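A minimal pure-Python sketch of the collection loop, assuming each batch is a dict with `"keypoints"` and `"images"` entries (an assumption about the batch format). Lists stand in for tensors; with real tensors the per-batch pieces would typically be joined with `torch.cat(..., dim=0)`.

```python
from typing import Iterable, List, Optional, Tuple


def iterate_over_dataloader(
    loader: Iterable[dict], extract_images: bool = False
) -> Tuple[List[list], Optional[list]]:
    """Collect per-batch keypoints (and optionally images) in a single pass.

    Assumes each batch is a dict with "keypoints" and "images" entries, each a
    list whose first axis is the batch dimension.
    """
    keypoints: List[list] = []
    images: list = []
    for batch in loader:  # one pass keeps keypoints and images aligned
        keypoints.extend(batch["keypoints"])  # concatenate along the example axis
        if extract_images:
            images.extend(batch["images"])
    return keypoints, (images if extract_images else None)


batches = [
    {"keypoints": [[1.0, 2.0, 3.0, 4.0]], "images": ["img0"]},
    {"keypoints": [[5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], "images": ["img1", "img2"]},
]
kps, imgs = iterate_over_dataloader(batches, extract_images=True)
print(len(kps), imgs)  # 3 ['img0', 'img1', 'img2']
```

Collecting both outputs in one pass matters for a shuffled loader: iterating twice could pair keypoints with images from different examples.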
- static verify_labeled_loader(loader: DataLoader | CombinedLoader) → DataLoader
Extract and return the labeled DataLoader from a potentially combined loader.
- Parameters:
loader – either a plain DataLoader or a CombinedLoader containing labeled and unlabeled sub-loaders.
- Returns:
The labeled DataLoader.
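Unwrapping can be sketched with duck typing. Lightning's `CombinedLoader` exposes its wrapped loaders through the `iterables` property; the `"labeled"` key and the `ToyCombinedLoader` stand-in are assumptions for illustration.

```python
from typing import Any


class ToyCombinedLoader:
    """Hypothetical stand-in: exposes its sub-loaders via ``.iterables``,
    mirroring the property name on Lightning's CombinedLoader."""

    def __init__(self, iterables: dict):
        self.iterables = iterables


def verify_labeled_loader(loader: Any) -> Any:
    """Return the labeled sub-loader, or the loader itself if it is not combined."""
    iterables = getattr(loader, "iterables", None)
    if isinstance(iterables, dict):
        # Assumed key name for the labeled sub-loader.
        return iterables["labeled"]
    return loader  # already a plain loader


plain = ["batch0", "batch1"]
combined = ToyCombinedLoader({"labeled": plain, "unlabeled": ["frame0"]})
print(verify_labeled_loader(combined) is plain)  # True
print(verify_labeled_loader(plain) is plain)     # True
```

Checking for an `iterables` dict keeps the sketch free of a Lightning import; real code would more likely test `isinstance(loader, CombinedLoader)`.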
- __init__(data_module: BaseDataModule | UnlabeledDataModule, cond: Literal['train', 'test', 'val'] = 'train', extract_images: bool = False, remove_augmentations: bool = True) → None
Initialize DataExtractor.
- Parameters:
data_module – data module containing the labeled dataset and splits.
cond – which data split to extract ("train", "val", or "test").
extract_images – if True, also extract and return image tensors.
remove_augmentations – if True, rebuild the dataset with only resize augmentation before extracting, to avoid contaminating PCA fits with augmented data.
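The `Literal['train', 'test', 'val']` annotation on `cond` is enforced by static type checkers but not by Python at runtime; per the `get_loader()` documentation, an invalid value surfaces as a `ValueError` there. If failing at construction time is preferred, a check like the hypothetical helper `check_cond` below could be used; it is an illustration, not part of the library.

```python
from typing import Literal, get_args

Split = Literal["train", "test", "val"]  # same values as the cond annotation


def check_cond(cond: str) -> str:
    """Raise early if ``cond`` is not one of the allowed split names."""
    allowed = get_args(Split)  # ('train', 'test', 'val')
    if cond not in allowed:
        raise ValueError(f"cond must be one of {allowed}, got {cond!r}")
    return cond


print(check_cond("val"))  # val
try:
    check_cond("validation")
except ValueError as err:
    print(err)  # cond must be one of ('train', 'test', 'val'), got 'validation'
```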
- __new__(**kwargs)