lightning_pose.data

lightning_pose.data.augmentations Module

Functions to build augmentation pipeline.

Functions

imgaug_transform(cfg)

Create a simple data transform pipeline that augments images.

lightning_pose.data.dali Module

Data pipelines based on efficient video reading with the NVIDIA DALI package.

Functions

video_pipe(filenames[, resize_dims, ...])

Generic video reader pipeline that loads videos, resizes, augments, and normalizes.

Classes

LitDaliWrapper(*args, eval_mode[, ...])

Wrapper around a DALI pipeline that returns batches for PyTorch Lightning (ptl).

PrepareDALI(train_stage, model_type, ...[, ...])

Gathers all DALI-related setup in one place.

lightning_pose.data.datamodules Module

Data modules that split a dataset into train, val, and test data loaders.

Classes

BaseDataModule(dataset[, train_batch_size, ...])

Splits a labeled dataset into train, val, and test data loaders.
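A train/val/test split can be pictured as a seeded random permutation of example indices cut into consecutive chunks. The sketch below is a minimal, hypothetical illustration under that assumption; `split_indices` is not part of lightning_pose, and the actual data module may rely on PyTorch utilities such as `torch.utils.data.random_split` instead.

```python
import random
from typing import List, Sequence, Tuple


def split_indices(n: int, sizes: Sequence[int], seed: int = 0) -> Tuple[List[int], ...]:
    """Hypothetical helper: shuffle range(n) with a fixed seed and cut it into chunks of `sizes`."""
    if sum(sizes) != n:
        raise ValueError("split sizes must add up to the dataset length")
    idxs = list(range(n))
    random.Random(seed).shuffle(idxs)  # seeded so the split is reproducible
    chunks = []
    start = 0
    for size in sizes:
        chunks.append(idxs[start:start + size])
        start += size
    return tuple(chunks)


train_idxs, val_idxs, test_idxs = split_indices(10, [8, 1, 1], seed=42)
```

Fixing the seed keeps the same frames in the validation and test sets across runs, which matters when comparing models trained separately.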

UnlabeledDataModule(dataset, ...[, ...])

Data module that contains labeled and unlabeled data loaders.

lightning_pose.data.datasets Module

Dataset objects store images, labels, and functions for manipulation.

Classes

BaseTrackingDataset(root_directory, csv_path)

Base dataset that contains images and keypoints as (x, y) pairs.

HeatmapDataset(root_directory, csv_path[, ...])

Heatmap dataset that contains the images and keypoints in 2D arrays.

MultiviewHeatmapDataset(root_directory, ...)

Heatmap dataset that contains the images and keypoints in 2D arrays from all the cameras.

lightning_pose.data.utils Module

Dataset/data module utilities.

Functions

split_sizes_from_probabilities(total_number, ...)

Return the number of examples for train, val, and test, given split probabilities.
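One plausible way to turn split probabilities into integer counts is to round the validation and test counts down and give the remainder to training, so the three counts always add up to the total. This is a hypothetical sketch: `split_sizes_sketch` and its keyword arguments are illustrative names, and the rounding rule is an assumption rather than the documented behavior of `split_sizes_from_probabilities`.

```python
import math
from typing import List


def split_sizes_sketch(
    total_number: int, train_prob: float = 0.8, val_prob: float = 0.1, test_prob: float = 0.1
) -> List[int]:
    """Hypothetical: convert split probabilities into [train, val, test] example counts."""
    if not math.isclose(train_prob + val_prob + test_prob, 1.0):
        raise ValueError("split probabilities must sum to 1")
    # Round val and test down, then assign the remainder to train so counts sum to the total.
    val_n = int(total_number * val_prob)
    test_n = int(total_number * test_prob)
    train_n = total_number - val_n - test_n
    return [train_n, val_n, test_n]
```

Assigning the rounding remainder to the training split guarantees no example is dropped, at the cost of very small datasets getting empty val/test splits (for example, 7 examples give `[7, 0, 0]`).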

clean_any_nans(data, dim)

Remove samples that contain NaNs from a data array.
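As an illustration of NaN filtering, the NumPy sketch below drops every sample (a slice along `sample_axis`) that contains at least one NaN. The function name and the `sample_axis` parameter are hypothetical; the library function operates on its own array type, and its `dim` argument may refer to the reduced axis rather than the sample axis.

```python
import numpy as np


def drop_nan_samples(data: np.ndarray, sample_axis: int = 0) -> np.ndarray:
    """Hypothetical sketch: remove every slice along `sample_axis` that contains a NaN."""
    # Reduce over all other axes so there is one True/False flag per sample.
    other_axes = tuple(ax for ax in range(data.ndim) if ax != sample_axis)
    has_nan = np.isnan(data).any(axis=other_axes)
    # Keep only the samples whose flag is False.
    return np.compress(~has_nan, data, axis=sample_axis)


preds = np.array([[1.0, 2.0], [np.nan, 3.0], [4.0, 5.0]])
clean = drop_nan_samples(preds)  # keeps rows 0 and 2
```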

count_frames(video_list)

Count the number of frames in a video or a list of videos.

compute_num_train_frames(len_train_dataset)

Compute the number of training frames for a given dataset.
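A common convention for this kind of helper is to accept an optional budget that is either a fraction of the dataset or an absolute frame count. The sketch below assumes that convention; `num_train_frames_sketch` and its `train_frames` argument are hypothetical and not taken from the documented signature, which only shows `len_train_dataset`.

```python
from typing import Optional, Union


def num_train_frames_sketch(
    len_train_dataset: int, train_frames: Optional[Union[int, float]] = None
) -> int:
    """Hypothetical: how many labeled frames to train on, given an optional budget."""
    if train_frames is None:
        return len_train_dataset  # no budget: use every training frame
    if train_frames < 0:
        raise ValueError("train_frames must be non-negative")
    if train_frames > 1:
        # Values above 1 are treated as an absolute count, capped at the dataset size.
        return min(int(train_frames), len_train_dataset)
    # Values in [0, 1] are treated as a fraction of the dataset.
    return int(train_frames * len_train_dataset)
```

Under this convention a budget of exactly 1 means "the whole dataset", not "one frame"; any real implementation has to pick one reading of that boundary.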

generate_heatmaps(keypoints, height, width, ...)

Generate 2D Gaussian heatmaps from mean and sigma.
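Each heatmap is an isotropic 2D Gaussian evaluated on the pixel grid, H[y, x] = exp(-((x - x0)^2 + (y - y0)^2) / (2 * sigma^2)), centered on the keypoint (x0, y0). A minimal NumPy sketch follows; the name `gaussian_heatmaps`, the default sigma, and the (num_keypoints, 2) input layout are assumptions, and the library may additionally normalize each map or handle missing keypoints differently.

```python
import numpy as np


def gaussian_heatmaps(
    keypoints: np.ndarray, height: int, width: int, sigma: float = 1.25
) -> np.ndarray:
    """Hypothetical sketch: keypoints is (num_keypoints, 2) with (x, y) in heatmap pixels.

    Returns a (num_keypoints, height, width) array with one unnormalized Gaussian
    (peak value 1 at the keypoint) per keypoint.
    """
    ys = np.arange(height, dtype=float)[:, None]  # (height, 1)
    xs = np.arange(width, dtype=float)[None, :]   # (1, width)
    x0 = keypoints[:, 0][:, None, None]           # (num_keypoints, 1, 1)
    y0 = keypoints[:, 1][:, None, None]
    # Broadcasting yields squared distances of shape (num_keypoints, height, width).
    sq_dist = (xs - x0) ** 2 + (ys - y0) ** 2
    return np.exp(-sq_dist / (2 * sigma ** 2))


hm = gaussian_heatmaps(np.array([[3.0, 2.0]]), height=5, width=8)
```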

evaluate_heatmaps_at_location(heatmaps, locs)

Evaluate 4D heatmaps at the locations given by a 3D tensor (the last dimension holds x, y coordinates).
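Given heatmaps of shape (batch, num_keypoints, height, width) and locations of shape (batch, num_keypoints, 2), evaluating the heatmaps amounts to a gathered lookup at each (x, y). The sketch below uses nearest-pixel indexing with clipping; the function name, the rounding, and the clipping are assumptions, and the library function may instead interpolate or aggregate over a window around each location.

```python
import numpy as np


def values_at_locations(heatmaps: np.ndarray, locs: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: read heatmaps[b, k] at the pixel nearest to locs[b, k] = (x, y)."""
    batch, num_kp, height, width = heatmaps.shape
    # Round to the nearest pixel and clip so out-of-frame locations stay in bounds.
    x = np.clip(np.rint(locs[..., 0]).astype(int), 0, width - 1)
    y = np.clip(np.rint(locs[..., 1]).astype(int), 0, height - 1)
    b = np.arange(batch)[:, None]   # (batch, 1), broadcasts against (batch, num_kp)
    k = np.arange(num_kp)[None, :]  # (1, num_kp)
    return heatmaps[b, k, y, x]     # (batch, num_kp)


maps = np.zeros((2, 3, 4, 5))
maps[1, 2, 3, 4] = 7.0
locs = np.zeros((2, 3, 2))
locs[1, 2] = [4.2, 2.8]  # rounds to x=4, y=3
vals = values_at_locations(maps, locs)
```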

undo_affine_transform(keypoints, transform)

Undo an affine transform given a tensor of keypoints and the transform matrix.
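If the augmentation is stored as a 2x3 matrix [A | t] that maps an original point p to q = A p + t, undoing it means solving A p = q - t for each keypoint. A NumPy sketch under that assumption follows; the matrix layout and the name `undo_affine_sketch` are assumptions, not the library's documented conventions.

```python
import numpy as np


def undo_affine_sketch(keypoints: np.ndarray, transform: np.ndarray) -> np.ndarray:
    """Hypothetical: keypoints (N, 2) in the augmented frame; transform (2, 3) = [A | t].

    Returns the keypoints mapped back to the original (un-augmented) frame.
    """
    A = transform[:, :2]  # linear part: rotation, scale, shear
    t = transform[:, 2]   # translation
    # Solve A @ p = q - t for every point at once; avoids forming inv(A) explicitly.
    return np.linalg.solve(A, (keypoints - t).T).T


original = np.array([[1.0, 2.0], [3.0, -1.0]])
M = np.array([[0.0, -2.0, 5.0], [2.0, 0.0, -1.0]])  # 90-degree rotation, 2x scale, shift
augmented = original @ M[:, :2].T + M[:, 2]
recovered = undo_affine_sketch(augmented, M)
```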

undo_affine_transform_batch(...[, is_multiview])

Potentially undo an affine transform given a tensor of keypoints and the transform matrix.

Classes

BaseLabeledExampleDict(*args, **kwargs)

Return type when calling __getitem__() on BaseTrackingDataset.

HeatmapLabeledExampleDict(*args, **kwargs)

Return type when calling __getitem__() on HeatmapDataset.
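The (*args, **kwargs) signatures of these return and batch types are consistent with `typing.TypedDict` subclasses, that is, plain dicts with declared keys. Assuming that, the pattern looks roughly like the sketch below; the class names and the field names (`images`, `keypoints`, `idxs`, `heatmaps`) are hypothetical, and nested lists stand in for the tensors the real classes would hold.

```python
from typing import List, TypedDict


class ExampleDictSketch(TypedDict):
    """Hypothetical stand-in for a labeled example type; field names are assumptions."""
    images: List[List[float]]  # placeholder for an image tensor
    keypoints: List[float]     # flattened (x, y) pairs
    idxs: int                  # position of the example in its dataset


class HeatmapExampleDictSketch(ExampleDictSketch):
    """Extends the base example with one 2D heatmap per keypoint."""
    heatmaps: List[List[List[float]]]


example = HeatmapExampleDictSketch(
    images=[[0.0, 0.1]], keypoints=[12.0, 30.5], idxs=0, heatmaps=[[[0.0]]]
)
```

Because these are ordinary dicts at runtime, fields are accessed by key (for example `example["keypoints"]`), while a static type checker can still verify that the expected keys are present.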

MultiviewLabeledExampleDict(*args, **kwargs)

Return type when calling __getitem__() on MultiviewDataset.

MultiviewHeatmapLabeledExampleDict(*args, ...)

Return type when calling __getitem__() on MultiviewHeatmapDataset.

BaseLabeledBatchDict(*args, **kwargs)

Batch type for base labeled data.

HeatmapLabeledBatchDict(*args, **kwargs)

Batch type for heatmap labeled data.

MultiviewLabeledBatchDict(*args, **kwargs)

Batch type for multiview labeled data.

MultiviewHeatmapLabeledBatchDict(*args, **kwargs)

Batch type for multiview heatmap labeled data.

UnlabeledBatchDict(*args, **kwargs)

Batch type for unlabeled data.

MultiviewUnlabeledBatchDict(*args, **kwargs)

Batch type for multiview unlabeled data.

SemiSupervisedBatchDict(*args, **kwargs)

Batch type for base labeled+unlabeled data.

SemiSupervisedHeatmapBatchDict(*args, **kwargs)

Batch type for heatmap labeled+unlabeled data.

SemiSupervisedDataLoaderDict(*args, **kwargs)

Return type when calling train/val/test_dataloader() on semi-supervised models.

DataExtractor(data_module[, cond, ...])

Helper class to extract all data from a data module.