HeatmapDataset

class lightning_pose.data.datasets.HeatmapDataset

Bases: BaseTrackingDataset

Heatmap dataset that contains the images and keypoints in 2D arrays.

Attributes Summary

output_shape

Spatial shape of the heatmap output (height, width) after downsampling.

Methods Summary

compute_heatmap(example_dict[, ignore_nans])

Compute 2D heatmaps from arbitrary (x, y) coordinates.

compute_heatmaps()

Compute initial 2D heatmaps for all labeled data.

Attributes Documentation

output_shape

Spatial shape of the heatmap output (height, width) after downsampling.

Returns:

Tuple of (heatmap_height, heatmap_width).
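As a rough illustration of how this shape relates to the constructor arguments, the sketch below derives a heatmap shape from the resized image size. The helper name `heatmap_output_shape` is hypothetical, and the power-of-two scaling is an assumption chosen because it matches the documented example of (384, 384) images yielding (96, 96) heatmaps with the default `downsample_factor=2`; it is not taken from the library source.

```python
def heatmap_output_shape(
    image_resize_height: int,
    image_resize_width: int,
    downsample_factor: int = 2,
) -> tuple[int, int]:
    """Hypothetical helper, not part of lightning_pose.

    Assumes the heatmap is smaller than the resized image by a factor of
    2 ** downsample_factor along each spatial dimension.
    """
    scale = 2 ** downsample_factor
    return image_resize_height // scale, image_resize_width // scale


# Resized (384, 384) images with the default factor of 2.
shape = heatmap_output_shape(384, 384, downsample_factor=2)
print(shape)
```

Under the same assumption, `downsample_factor=1` would halve each dimension and `downsample_factor=3` would divide each by eight.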

Methods Documentation

compute_heatmap(example_dict: BaseLabeledExampleDict, ignore_nans: bool = False) -> Float[Tensor, 'num_keypoints heatmap_height heatmap_width']

Compute 2D heatmaps from arbitrary (x, y) coordinates.
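To make the idea concrete, here is a minimal pure-Python sketch of turning one (x, y) coordinate into a 2D heatmap. It is an illustration only: the function name `gaussian_heatmap`, the Gaussian form, the `sigma` value, and the `uniform_if_missing` flag are all assumptions, not the library's implementation. The flag loosely mirrors the `uniform_heatmaps` option described for `__init__`, where missing keypoints produce either uniform or all-zero heatmaps.

```python
import math


def gaussian_heatmap(x, y, height, width, sigma=1.25, uniform_if_missing=False):
    """Hypothetical helper: one normalized 2D Gaussian centred at (x, y).

    (x, y) are assumed to already be in heatmap pixel coordinates.
    A NaN coordinate is treated as a missing keypoint.
    """
    if math.isnan(x) or math.isnan(y):
        # Missing data: uniform heatmap if requested, otherwise all zeros.
        fill = 1.0 / (height * width) if uniform_if_missing else 0.0
        return [[fill] * width for _ in range(height)]

    heatmap = [
        [
            math.exp(-((col - x) ** 2 + (row - y) ** 2) / (2 * sigma ** 2))
            for col in range(width)
        ]
        for row in range(height)
    ]
    total = sum(map(sum, heatmap))
    if total == 0.0:
        # Keypoint so far outside the frame that every value underflowed.
        return [[0.0] * width for _ in range(height)]
    # Normalize so the heatmap sums to one.
    return [[value / total for value in row] for row in heatmap]


hm = gaussian_heatmap(x=48.0, y=20.0, height=96, width=96)
peak_row = max(range(96), key=lambda r: max(hm[r]))
peak_col = max(range(96), key=lambda c: hm[peak_row][c])
print(peak_row, peak_col)
```

Stacking one such map per keypoint would give an array with the `num_keypoints heatmap_height heatmap_width` layout named in the return annotation.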

compute_heatmaps() -> Tensor

Compute initial 2D heatmaps for all labeled data. Note that this will apply augmentations.

Image dimensions change along the way: original image dims, e.g., (406, 396) -> resized image dims, e.g., (384, 384) -> potentially downsampled heatmaps, e.g., (96, 96).
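The sketch below follows a single keypoint through the same three coordinate frames. It assumes a simple proportional rescale at each stage and ignores pixel-centre conventions and augmentations; `rescale_keypoint` is a hypothetical helper, and the keypoint value is made up for illustration.

```python
def rescale_keypoint(x, y, src_hw, dst_hw):
    """Hypothetical helper: proportionally map (x, y) between two frames.

    Both frames are given as (height, width) tuples.
    """
    src_h, src_w = src_hw
    dst_h, dst_w = dst_hw
    return x * dst_w / src_w, y * dst_h / src_h


original_hw = (406, 396)  # original image dims from the example above
resized_hw = (384, 384)   # resized image dims
heatmap_hw = (96, 96)     # downsampled heatmap dims

# A made-up keypoint at the centre of the original image.
x_resized, y_resized = rescale_keypoint(198.0, 203.0, original_hw, resized_hw)
x_heatmap, y_heatmap = rescale_keypoint(x_resized, y_resized, resized_hw, heatmap_hw)
print((x_resized, y_resized), (x_heatmap, y_heatmap))
```

Because height and width are scaled independently, a non-square original image changes aspect ratio when resized to a square, so x and y use different scale factors in the first step.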

__init__(root_directory: str | Path, csv_path: str, image_resize_height: int, image_resize_width: int, header_rows: list[int] | None = [0, 1, 2], imgaug_transform: Sequential | None = None, downsample_factor: Literal[1, 2, 3] = 2, do_context: bool = False, resize: bool = True, uniform_heatmaps: bool = False, bbox_path: str | None = None) -> None

Initialize the Heatmap Dataset.

Parameters:
  • root_directory – path to data directory

  • csv_path – path to CSV or h5 file (within root_directory). The CSV file should be in the form (image_path, bodypart_1_x, bodypart_1_y, …, bodypart_n_y). Note: image_path is relative to the given root_directory.

  • image_resize_height – height to resize images before sending to network

  • image_resize_width – width to resize images before sending to network

  • header_rows – which rows in the csv are header rows

  • imgaug_transform – imgaug transform pipeline to apply to images

  • downsample_factor – factor by which to downsample the resized image dims to produce a smaller heatmap

  • do_context – include additional frames of context if possible

  • resize – True to add a final resizing augmentation before sending data to the network. This can be set to False if inheritors of this class need to implement more sophisticated augmentations before resizing (e.g., 3D augmentations). Note that when this is False, it is up to the child class to perform this resizing on both images and keypoints before returning a batch of data.

  • uniform_heatmaps – True to force the model to output uniform heatmaps for missing data; False will output all-zero heatmaps

  • bbox_path – path to CSV file that contains bounding box information; rows must be in the same order as in the label file given by csv_path
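To illustrate the label-file layout that `csv_path` and `header_rows` describe, the sketch below parses a small in-memory CSV with the standard library. The three-row header (scorer, bodyparts, coords) is an assumption modelled on DeepLabCut-style label files; it is consistent with the default `header_rows=[0, 1, 2]` but is not confirmed by this page. The file contents, body part names, and the `read_labels` helper are all hypothetical.

```python
import csv
import io

# Hypothetical label file: three header rows, then one row per labeled image.
# Empty cells stand for keypoints that were not labeled.
csv_text = """scorer,scorer1,scorer1,scorer1,scorer1
bodyparts,paw,paw,nose,nose
coords,x,y,x,y
labeled-data/img0001.png,101.5,220.0,310.2,98.7
labeled-data/img0002.png,,,305.0,99.1
"""


def read_labels(text, header_rows=(0, 1, 2)):
    """Hypothetical helper: map image_path -> {"<bodypart>_<coord>": value}."""
    rows = list(csv.reader(io.StringIO(text)))
    header = [rows[i] for i in header_rows]
    # Assumed layout: second header row names body parts, third names x/y.
    bodyparts, coords = header[1][1:], header[2][1:]
    columns = [f"{part}_{coord}" for part, coord in zip(bodyparts, coords)]

    labels = {}
    for index, row in enumerate(rows):
        if index in header_rows:
            continue
        # The first column is the image path, relative to root_directory.
        values = [float(v) if v else float("nan") for v in row[1:]]
        labels[row[0]] = dict(zip(columns, values))
    return labels


labels = read_labels(csv_text)
print(labels["labeled-data/img0001.png"])
```

Each data row flattens to (image_path, bodypart_1_x, bodypart_1_y, …), matching the form given for `csv_path`; unlabeled keypoints become NaN, which is the kind of missing data the `uniform_heatmaps` option refers to.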

__new__(**kwargs)