HeatmapDataset

class lightning_pose.data.datasets.HeatmapDataset(root_directory: str, csv_path: str, header_rows: list[int] | None = [0, 1, 2], imgaug_transform: Callable | None = None, downsample_factor: Literal[1, 2, 3] = 2, do_context: bool = False, uniform_heatmaps: bool = False)

Bases: BaseTrackingDataset

Heatmap dataset that contains the images and keypoints in 2D arrays.

Attributes Summary

output_shape

Methods Summary

compute_heatmap(example_dict)

Compute 2D heatmaps from arbitrary (x, y) coordinates.

compute_heatmaps()

Compute initial 2D heatmaps for all labeled data.

Attributes Documentation

output_shape

Methods Documentation

compute_heatmap(example_dict: BaseLabeledExampleDict) → TensorType["num_keypoints", "heatmap_height", "heatmap_width"]

Compute 2D heatmaps from arbitrary (x, y) coordinates.
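A common way to turn (x, y) coordinates into 2D heatmaps is to render one Gaussian bump per keypoint on the heatmap grid. The NumPy sketch below illustrates that general technique only; it is not the lightning-pose implementation. The helper name gaussian_heatmaps, the default sigma of 1.25, the per-heatmap normalization to sum 1, and the choice to leave NaN (unlabeled) keypoints as all-zero heatmaps are all hypothetical choices made for illustration.

```python
import numpy as np


def gaussian_heatmaps(keypoints, height, width, sigma=1.25):
    """Render one 2D Gaussian per (x, y) keypoint on a (height, width) grid.

    Hypothetical helper for illustration only.
    keypoints: array-like of shape (num_keypoints, 2), in heatmap pixel coordinates.
    Returns an array of shape (num_keypoints, height, width).
    """
    keypoints = np.asarray(keypoints, dtype=float)
    # Row (y) and column (x) index grids, each of shape (height, width).
    ys, xs = np.mgrid[0:height, 0:width]
    heatmaps = np.zeros((len(keypoints), height, width))
    for k, (x, y) in enumerate(keypoints):
        if np.isnan(x) or np.isnan(y):
            continue  # assumption: missing keypoints stay as all-zero heatmaps
        sq_dist = (xs - x) ** 2 + (ys - y) ** 2
        hm = np.exp(-sq_dist / (2.0 * sigma**2))
        heatmaps[k] = hm / hm.sum()  # assumption: normalize each heatmap to sum to 1
    return heatmaps


hm = gaussian_heatmaps([[10.0, 5.0], [np.nan, np.nan]], height=16, width=24)
```

The peak of each heatmap sits at the keypoint's (row, column) = (y, x) location, so the argmax of the first heatmap above lands at (5, 10).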

compute_heatmaps()

Compute initial 2D heatmaps for all labeled data. Note that this applies the augmentation pipeline.

Shape flow: original image dims, e.g. (406, 396) -> resized image dims, e.g. (384, 384) -> potentially downsampled heatmap dims, e.g. (96, 96).
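The shape flow above can be reproduced with a little arithmetic. The sketch below assumes that each unit of downsample_factor halves the resolution, so the divisor is 2 ** downsample_factor. That is consistent with the (384, 384) -> (96, 96) example under the default factor of 2, but it is an assumption, not a statement of the library's code. The helper names heatmap_shape and to_heatmap_coords are hypothetical, dims are taken as (height, width), and sub-pixel / pixel-center conventions are ignored.

```python
def heatmap_shape(resized_hw, downsample_factor=2):
    """Hypothetical helper: heatmap (height, width) for a resized image.

    Assumes the divisor is 2 ** downsample_factor.
    """
    if downsample_factor not in (1, 2, 3):
        raise ValueError("downsample_factor must be 1, 2, or 3")
    divisor = 2**downsample_factor
    height, width = resized_hw
    return height // divisor, width // divisor


def to_heatmap_coords(x, y, original_hw, resized_hw, downsample_factor=2):
    """Hypothetical helper: map an (x, y) label from original-image pixels to heatmap pixels.

    Scales by the resize ratio, then by the assumed 2 ** downsample_factor divisor.
    """
    orig_h, orig_w = original_hw
    res_h, res_w = resized_hw
    divisor = 2**downsample_factor
    return x * res_w / orig_w / divisor, y * res_h / orig_h / divisor


print(heatmap_shape((384, 384)))  # (96, 96)
print(to_heatmap_coords(198.0, 203.0, (406, 396), (384, 384)))  # (48.0, 48.0)
```

With these assumptions, a keypoint at the center of a (406, 396) image lands at the center of the (96, 96) heatmap.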

__init__(root_directory: str, csv_path: str, header_rows: list[int] | None = [0, 1, 2], imgaug_transform: Callable | None = None, downsample_factor: Literal[1, 2, 3] = 2, do_context: bool = False, uniform_heatmaps: bool = False) → None

Initialize the Heatmap Dataset.

Parameters:
  • root_directory – path to the data directory

  • csv_path – path to a CSV or h5 file (within root_directory). The CSV file should be in the form (image_path, bodypart_1_x, bodypart_1_y, …, bodypart_n_y). Note: image_path is relative to the given root_directory.

  • header_rows – indices of the header rows in the CSV file

  • imgaug_transform – imgaug transform pipeline to apply to images

  • downsample_factor – factor by which to downsample the resized image dims to obtain smaller heatmaps

  • do_context – include additional frames of context if possible
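To make the csv_path and header_rows conventions concrete, the sketch below parses a small in-memory CSV with three header rows using pandas. The file contents, the scorer/bodyparts/coords header layout, and the load_labels helper are hypothetical illustrations of the described format, not code taken from lightning-pose. Empty cells are read as NaN, which is one way unlabeled keypoints can appear.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical labels file: three header rows, then one row per image laid out as
# (image_path, bodypart_1_x, bodypart_1_y, ..., bodypart_n_y).
CSV_TEXT = """scorer,lab,lab,lab,lab
bodyparts,nose,nose,tail,tail
coords,x,y,x,y
labeled-data/img0001.png,10.5,20.0,30.0,40.0
labeled-data/img0002.png,11.0,21.0,,
"""


def load_labels(csv_source, header_rows=(0, 1, 2)):
    """Hypothetical helper: return image paths, a (num_images, num_keypoints, 2) array, and the DataFrame."""
    # The first column (image_path) becomes the index; header rows form MultiIndex columns.
    df = pd.read_csv(csv_source, header=list(header_rows), index_col=0)
    image_paths = df.index.tolist()
    # Columns alternate x, y per bodypart, so reshape into (x, y) pairs.
    keypoints = df.to_numpy(dtype=float).reshape(len(df), -1, 2)
    return image_paths, keypoints, df


image_paths, keypoints, df = load_labels(io.StringIO(CSV_TEXT))
print(keypoints.shape)  # (2, 2, 2)
```

Because each image_path is relative to root_directory, the full image path would then be formed with something like os.path.join(root_directory, image_path).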