HeatmapDataset

class lightning_pose.data.datasets.HeatmapDataset(root_directory: str, csv_path: str, header_rows: List[int] | None = [0, 1, 2], imgaug_transform: Callable | None = None, downsample_factor: Literal[1, 2, 3] = 2, do_context: bool = False, uniform_heatmaps: bool = False)

Bases: BaseTrackingDataset

Heatmap dataset that contains the images and keypoints in 2D arrays.

Attributes Summary

output_shape

Methods Summary

compute_heatmap(example_dict)

Compute 2D heatmaps from arbitrary (x, y) coordinates.

compute_heatmaps()

Compute initial 2D heatmaps for all labeled data.

Attributes Documentation

output_shape

Methods Documentation

compute_heatmap(example_dict: BaseLabeledExampleDict) -> Tensor

Compute 2D heatmaps from arbitrary (x, y) coordinates.
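
To illustrate what turning an (x, y) coordinate into a 2D heatmap can look like, here is a minimal pure-Python sketch. It is not the library's implementation: the helper names gaussian_heatmap and scale_keypoint, the Gaussian form, the sigma value, and the sum-to-one normalization are all assumptions made for illustration.

```python
import math


def scale_keypoint(x, y, image_shape, heatmap_shape):
    """Map an (x, y) keypoint from image pixels to heatmap pixels.

    Hypothetical helper; shapes are (height, width) tuples.
    """
    img_h, img_w = image_shape
    hm_h, hm_w = heatmap_shape
    return x * hm_w / img_w, y * hm_h / img_h


def gaussian_heatmap(x, y, height, width, sigma=1.25):
    """Return a height x width grid holding a 2D Gaussian centered at (x, y).

    Hypothetical helper; x indexes columns and y indexes rows. The sigma
    default is an arbitrary illustrative value.
    """
    grid = [
        [
            math.exp(-((col - x) ** 2 + (row - y) ** 2) / (2.0 * sigma ** 2))
            for col in range(width)
        ]
        for row in range(height)
    ]
    total = sum(sum(row) for row in grid)
    if total == 0.0:
        # Keypoint lies so far outside the grid that every value underflowed.
        return grid
    # Normalize so the heatmap sums to 1 (an assumed convention).
    return [[value / total for value in row] for row in grid]


# A keypoint at (192, 96) in a 384 x 384 image, rendered on a 96 x 96 heatmap.
hx, hy = scale_keypoint(192, 96, (384, 384), (96, 96))
heatmap = gaussian_heatmap(hx, hy, 96, 96)
```

The keypoint is rescaled before the heatmap is drawn because the heatmap grid is smaller than the image, so coordinates in image pixels would otherwise land in the wrong cell.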

compute_heatmaps()

Compute initial 2D heatmaps for all labeled data. Note that this applies augmentations.

Dimensions change in stages: original image dims, e.g., (406, 396) -> resized image dims, e.g., (384, 384) -> potentially downsampled heatmap dims, e.g., (96, 96).
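
The arithmetic behind the last step can be sketched as follows. This assumes that each unit of downsample_factor halves both dimensions, which is inferred from the example above (384 -> 96 at the default factor of 2) rather than stated by the documentation; the function name heatmap_shape is hypothetical.

```python
def heatmap_shape(resized_shape, downsample_factor=2):
    """Return the (height, width) of a heatmap for a resized image.

    Assumption: the factor is an exponent of 2, so a factor of 2
    divides each dimension by 4.
    """
    height, width = resized_shape
    scale = 2 ** downsample_factor
    return height // scale, width // scale


print(heatmap_shape((384, 384), downsample_factor=2))  # (96, 96)
```

Under the same assumption, the allowed factors 1, 2, and 3 would give (192, 192), (96, 96), and (48, 48) for a 384 x 384 resized image.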