DataExtractor

class lightning_pose.data.extractor.DataExtractor[source]

Bases: object

Helper class to extract all data from a data module.

Attributes Summary

dataset_length

Number of examples in the selected data split.

Methods Summary

__call__()

Extract all keypoints (and optionally images) from the selected data split.

get_loader()

Return the dataloader for the selected split.

iterate_over_dataloader(loader)

Iterate over a dataloader and collect keypoints (and optionally images).

verify_labeled_loader(loader)

Extract and return the labeled DataLoader from a potentially combined loader.

Attributes Documentation

dataset_length

Number of examples in the selected data split.

Returns:

Length of the train, val, or test dataset depending on self.cond.

Methods Documentation

__call__() → tuple[Tensor, Float[Tensor, 'num_examples 3 image_width image_height'] | Float[Tensor, 'num_examples frames 3 image_width image_height'] | None][source]

Extract all keypoints (and optionally images) from the selected data split.

Returns:

Tuple of:

  • concatenated keypoints tensor of shape (num_examples, num_targets).

  • concatenated image tensor, or None if self.extract_images is False.

get_loader() → DataLoader | CombinedLoader[source]

Return the dataloader for the selected split.

Returns:

DataLoader or CombinedLoader corresponding to self.cond.

Raises:

ValueError – if self.cond is not "train", "val", or "test".
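The split selection described above can be sketched as a simple dispatch on cond. This is an illustration only, not the library's implementation: select_loader and ToyDataModule are hypothetical names, and the sketch assumes the data module exposes the usual train_dataloader(), val_dataloader(), and test_dataloader() methods.

```python
from typing import Any


def select_loader(data_module: Any, cond: str) -> Any:
    """Return the loader for the requested split (hypothetical helper)."""
    if cond == "train":
        return data_module.train_dataloader()
    if cond == "val":
        return data_module.val_dataloader()
    if cond == "test":
        return data_module.test_dataloader()
    # Any other value is rejected, mirroring the documented ValueError.
    raise ValueError(f"cond must be 'train', 'val', or 'test'; got {cond!r}")


class ToyDataModule:
    """Stand-in for a real data module; returns labels instead of loaders."""

    def train_dataloader(self) -> str:
        return "train-loader"

    def val_dataloader(self) -> str:
        return "val-loader"

    def test_dataloader(self) -> str:
        return "test-loader"


selected = select_loader(ToyDataModule(), "val")
```

Validating cond at the point of use, rather than only in the constructor, guards against the attribute being changed after the extractor was created.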

iterate_over_dataloader(loader: DataLoader) → tuple[Tensor, Float[Tensor, 'num_examples 3 image_width image_height'] | Float[Tensor, 'num_examples frames 3 image_width image_height'] | None][source]

Iterate over a dataloader and collect keypoints (and optionally images).

Parameters:

loader – labeled dataloader to iterate over.

Returns:

Tuple of:

  • concatenated keypoints tensor of shape (num_examples, num_targets).

  • concatenated image tensor, or None if self.extract_images is False.
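The collect-and-concatenate pattern described here can be sketched without any tensor library. The sketch below is a simplification under stated assumptions: collect_batches is a hypothetical name, each batch is assumed to be a dict with "keypoints" and "images" entries, and plain Python lists stand in for tensors (with real tensors the per-batch pieces would typically be joined along the first dimension, for example with torch.cat).

```python
from typing import Iterable, Optional

# Assumption: a batch is a dict of per-example lists; real batches hold tensors.
Batch = dict[str, list]


def collect_batches(
    loader: Iterable[Batch], extract_images: bool = False
) -> tuple[list, Optional[list]]:
    """Gather keypoints (and optionally images) across all batches."""
    keypoints: list = []
    images: list = []
    for batch in loader:
        # Appending per-example rows is the list analogue of concatenating
        # along the example dimension.
        keypoints.extend(batch["keypoints"])
        if extract_images:
            images.extend(batch["images"])
    return keypoints, (images if extract_images else None)


# Toy loader: two batches holding 2 and 1 examples, 2 targets per example.
toy_loader = [
    {"keypoints": [[0.0, 1.0], [2.0, 3.0]], "images": ["img0", "img1"]},
    {"keypoints": [[4.0, 5.0]], "images": ["img2"]},
]

kps_only, no_images = collect_batches(toy_loader)
kps, imgs = collect_batches(toy_loader, extract_images=True)
```

Skipping the image collection unless it is requested keeps memory use low, since images are far larger than keypoint vectors.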

static verify_labeled_loader(loader: DataLoader | CombinedLoader) → DataLoader[source]

Extract and return the labeled DataLoader from a potentially combined loader.

Parameters:

loader – either a plain DataLoader or a CombinedLoader containing labeled and unlabeled sub-loaders.

Returns:

The labeled DataLoader.
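A minimal sketch of the unwrapping logic follows. It is not the library's code: ToyCombinedLoader is a hypothetical stand-in for CombinedLoader, and the assumption that the sub-loaders are stored in a mapping named iterables under the key "labeled" is made purely for illustration.

```python
from typing import Any


class ToyCombinedLoader:
    """Hypothetical stand-in for a combined loader holding named sub-loaders."""

    def __init__(self, iterables: dict[str, Any]) -> None:
        # Assumption: sub-loaders are kept in a dict keyed by name.
        self.iterables = iterables


def pick_labeled(loader: Any) -> Any:
    """Return the labeled sub-loader, or the loader itself if it is plain."""
    if isinstance(loader, ToyCombinedLoader):
        return loader.iterables["labeled"]
    return loader


plain = ["labeled-batch-0", "labeled-batch-1"]
combined = ToyCombinedLoader({"labeled": plain, "unlabeled": ["video-batch-0"]})

from_plain = pick_labeled(plain)
from_combined = pick_labeled(combined)
```

Either way the caller ends up with a single labeled loader, so downstream iteration code does not need to know whether unlabeled data was configured.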

__init__(data_module: BaseDataModule | UnlabeledDataModule, cond: Literal['train', 'test', 'val'] = 'train', extract_images: bool = False, remove_augmentations: bool = True) → None[source]

Initialize DataExtractor.

Parameters:
  • data_module – data module containing the labeled dataset and splits.

  • cond – which data split to extract ("train", "val", or "test").

  • extract_images – if True, also extract and return image tensors.

  • remove_augmentations – if True, rebuild the dataset with only resize augmentation before extracting, to avoid contaminating PCA fits with augmented data.
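Putting the constructor and __call__ together, typical usage might look like the sketch below. The wrapper function extract_split is a hypothetical name, and data_module is assumed to be an already configured BaseDataModule or UnlabeledDataModule built elsewhere. The import is placed inside the function so the snippet can be loaded even where lightning_pose is not installed.

```python
def extract_split(data_module, cond="val", extract_images=False):
    """Extract keypoints (and optionally images) for one split.

    Hypothetical convenience wrapper; `data_module` must be built by the caller.
    """
    # Lazy import: only needed when the function is actually called.
    from lightning_pose.data.extractor import DataExtractor

    extractor = DataExtractor(
        data_module=data_module,
        cond=cond,
        extract_images=extract_images,
        # Keep only the resize step so extracted data is not augmented.
        remove_augmentations=True,
    )
    # Calling the extractor returns (keypoints, images_or_None).
    keypoints, images = extractor()
    return keypoints, images
```

Leaving remove_augmentations at True matches the documented default and is the safer choice when the extracted keypoints feed a PCA fit.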

__new__(**kwargs)