"""Dataset/data module utilities."""
import os
from typing import Any, Literal
import imgaug.augmenters as iaa
import lightning.pytorch as pl
import numpy as np
import torch
from torchtyping import TensorType
from typeguard import typechecked
from lightning_pose.data.datatypes import (
HeatmapLabeledBatchDict,
MultiviewHeatmapLabeledBatchDict,
MultiviewUnlabeledBatchDict,
SemiSupervisedDataLoaderDict,
UnlabeledBatchDict,
)
# explicit public API of this module; defining __all__ keeps sphinx-autoapidoc from
# documenting the names this module merely imports
# NOTE(review): "DataExtractor", "undo_affine_transform" and "undo_affine_transform_batch"
# are not defined in the portion of the module reviewed here -- confirm they exist, otherwise
# `from <this module> import *` raises AttributeError
__all__ = [
    "DataExtractor",
    "split_sizes_from_probabilities",
    "clean_any_nans",
    "count_frames",
    "compute_num_train_frames",
    "generate_heatmaps",
    "evaluate_heatmaps_at_location",
    "undo_affine_transform",
    "undo_affine_transform_batch",
    "normalized_to_bbox",
    "convert_bbox_coords",
    "convert_original_to_model_coords",
    "original_to_model",
]
[docs]
@typechecked
def split_sizes_from_probabilities(
total_number: int,
train_probability: float,
val_probability: float | None = None,
test_probability: float | None = None,
) -> list[int]:
"""Returns the number of examples for train, val and test given split probs.
Args:
total_number: total number of examples in dataset
train_probability: fraction of examples used for training
val_probability: fraction of examples used for validation
test_probability: fraction of examples used for test. Defaults to None. Can be computed
as the remaining examples.
Returns:
[num training examples, num validation examples, num test examples]
"""
if test_probability is None and val_probability is None:
remaining_probability = 1.0 - train_probability
# round each to 5 decimal places (issue with floating point precision)
val_probability = round(remaining_probability / 2, 5)
test_probability = round(remaining_probability / 2, 5)
elif test_probability is None:
test_probability = 1.0 - train_probability - val_probability
# probabilities should add to one
assert test_probability + train_probability + val_probability == 1.0
# compute numbers from probabilities
train_number = int(np.floor(train_probability * total_number))
val_number = int(np.floor(val_probability * total_number))
# if we lose extra examples by flooring, send these to train_number or test_number, depending
leftover = total_number - train_number - val_number
if leftover < 5:
# very few samples, let's bulk up train
train_number += leftover
test_number = 0
else:
test_number = leftover
# make sure that we have at least one validation sample
if val_number == 0:
train_number -= 1
val_number += 1
if train_number < 1:
raise ValueError("Must have at least two labeled frames, one train and one validation")
# assert that we're using all datapoints
assert train_number + test_number + val_number == total_number
return [train_number, val_number, test_number]
[docs]
@typechecked
def clean_any_nans(data: torch.Tensor, dim: int) -> torch.Tensor:
    """Remove samples from a 2D data array that contain nans.

    Args:
        data: 2D tensor
        dim: dimension along which nans are searched. 0 removes every column that contains
            at least one nan; 1 removes every row that contains at least one nan.

    Returns:
        new tensor with the nan-containing columns (dim=0) or rows (dim=1) removed

    Raises:
        ValueError: if dim is not 0 or 1

    """
    # currently supports only 2D arrays; other dims used to fall through and return None
    if dim not in (0, 1):
        raise ValueError(f"dim must be 0 or 1, got {dim}")
    # e.g., when dim == 0, True for those columns (keypoints) that have >0 nans
    nan_bool = torch.isnan(data).any(dim=dim)
    if dim == 0:
        return data[:, ~nan_bool]
    return data[~nan_bool]
[docs]
@typechecked
def count_frames(video_file: str) -> int:
    """Count the number of frames in a video.

    Args:
        video_file: path to the video file

    Returns:
        frame count reported by OpenCV's ``CAP_PROP_FRAME_COUNT`` property

    Raises:
        FileNotFoundError: if ``video_file`` is not an existing file

    """
    # explicit check rather than ``assert`` so the validation survives ``python -O``
    if not os.path.isfile(video_file):
        raise FileNotFoundError(f"video file not found: {video_file}")

    # imported locally so opencv is only loaded when this function is called
    import cv2

    cap = cv2.VideoCapture(video_file)
    try:
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # release the capture handle even if querying the property fails
        cap.release()
    return num_frames
[docs]
@typechecked
def compute_num_train_frames(
len_train_dataset: int,
train_frames: int | float | None = None,
) -> int:
"""Quickly compute number of training frames for a given dataset.
Args:
len_train_dataset: total number of frames in training dataset
train_frames:
<=1 - fraction of total train frames used for training
>1 - number of total train frames used for training
Returns:
total number of train frames
"""
if train_frames is None:
n_train_frames = len_train_dataset
else:
if train_frames >= len_train_dataset:
# take max number of train frames
print("Warning! Requested training frames exceeds training set size; using all")
n_train_frames = len_train_dataset
elif train_frames == 1:
# assume this is a fraction; use full dataset
n_train_frames = len_train_dataset
elif train_frames > 1:
# take this number of train frames
n_train_frames = int(train_frames)
elif train_frames > 0:
# take this fraction of train frames
n_train_frames = int(train_frames * len_train_dataset)
else:
raise ValueError("train_frames must be >0")
return n_train_frames
# @typechecked
[docs]
def generate_heatmaps(
    keypoints: TensorType["batch", "num_keypoints", 2],
    height: int,
    width: int,
    output_shape: tuple[int, int],
    sigma: float = 1.25,
    uniform_heatmaps: bool = False,
    keep_gradients: bool = False,
) -> TensorType["batch", "num_keypoints", "height", "width"]:
    """Generate 2D Gaussian heatmaps from mean and sigma.

    Args:
        keypoints: coordinates that serve as mean of gaussian bump; last dim is (x, y) in
            pixels of the reshaped image
        height: height of reshaped image (pixels, e.g., 128, 256, 512...)
        width: width of reshaped image (pixels, e.g., 128, 256, 512...)
        output_shape: dimensions of downsampled heatmap, (height, width)
        sigma: control spread of gaussian
        uniform_heatmaps: output uniform heatmaps if missing ground truth label, rather than skip
        keep_gradients: True to not detach gradients from keypoints before creating heatmaps

    Returns:
        batch of 2D heatmaps of shape (batch, num_keypoints, out_height, out_width). Each
        valid heatmap sums to one. Keypoints that are nan (in either coordinate) or lie more
        than one heatmap pixel outside the heatmap get an all-zero heatmap (ignored by the
        supervised heatmap loss) or, if ``uniform_heatmaps``, a uniform one.

    """
    if keep_gradients:
        keypoints = keypoints.clone()
    else:
        keypoints = keypoints.detach().clone()

    out_height = output_shape[0]
    out_width = output_shape[1]

    # rescale keypoints from reshaped-image pixels to heatmap pixels
    keypoints[:, :, 0] *= out_width / width
    keypoints[:, :, 1] *= out_height / height
    x = keypoints[:, :, 0]
    y = keypoints[:, :, 1]

    # invalid keypoints: nan in either coordinate (previously only x was checked, so a nan y
    # produced nan heatmaps that were never replaced), or out of bounds by more than a pixel
    invalid = (
        torch.isnan(keypoints).any(dim=2)
        | (x < -1)
        | (x > out_width + 1)
        | (y < -1)
        | (y > out_height + 1)
    )

    # clamp so that far-away keypoints do not produce gaussians that underflow to all zeros
    # (which would become 0 / 0 in the normalization below); nans pass through clamp
    x = torch.clamp(x, -1, out_width + 1)
    y = torch.clamp(y, -1, out_height + 1)

    # integer row / column index grids, each (1, 1, out_height, out_width)
    rows, cols = torch.meshgrid(
        torch.arange(out_height, device=keypoints.device),
        torch.arange(out_width, device=keypoints.device),
        indexing="ij",
    )
    rows = rows[None, None]
    cols = cols[None, None]

    # keypoint means broadcast to (batch, num_keypoints, 1, 1)
    mean_x = x[:, :, None, None]
    mean_y = y[:, :, None, None]

    # evaluate 2d gaussian with mean equal to the keypoint and var equal to sigma^2
    sq_dist = (cols - mean_x) ** 2 + (rows - mean_y) ** 2
    heatmaps = torch.exp(-sq_dist / (2 * sigma**2))

    # normalize all heatmaps to one
    heatmaps = heatmaps / torch.sum(heatmaps, dim=(2, 3), keepdim=True)

    # replace invalid keypoints with filler heatmaps
    # (all zeros heatmaps are ignored in the supervised heatmap loss)
    if uniform_heatmaps:
        filler_heatmap = torch.ones(
            (out_height, out_width), device=keypoints.device, dtype=heatmaps.dtype
        ) / (out_height * out_width)
    else:
        filler_heatmap = torch.zeros(
            (out_height, out_width), device=keypoints.device, dtype=heatmaps.dtype
        )
    heatmaps[invalid] = filler_heatmap

    return heatmaps
# @typechecked
[docs]
def evaluate_heatmaps_at_location(
    heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"],
    locs: TensorType["batch", "num_keypoints", 2],
    sigma: float = 1.25,  # sigma used for generating heatmaps
    num_stds: int = 2,  # num standard deviations of pixels to compute confidence
) -> TensorType["batch", "num_keypoints"]:
    """Evaluate 4D heatmaps using a 3D location tensor (last dim is x, y coords).

    Since the model outputs heatmaps with a standard deviation of sigma, confidence will be
    spread across neighboring pixels. To account for this, confidence is computed by summing
    the heatmap over the square window of half-width ``floor(sigma * num_stds)`` pixels
    centered on each location (a single pixel when that half-width is 0).

    Args:
        heatmaps: heatmaps to evaluate
        locs: locations (x, y) in heatmap pixels; values are truncated to integers
        sigma: standard deviation used when generating the heatmaps
        num_stds: number of standard deviations that sets the window half-width

    Returns:
        windowed heatmap sums of shape (batch, num_keypoints)

    """
    pix_to_consider = int(np.floor(sigma * num_stds))  # window half-width in pixels
    num_pad = pix_to_consider
    height = heatmaps.shape[2]
    width = heatmaps.shape[3]

    # zero-pad so that windows touching the border can be indexed directly
    heatmaps_padded = torch.zeros(
        (
            heatmaps.shape[0],
            heatmaps.shape[1],
            height + num_pad * 2,
            width + num_pad * 2,
        ),
        device=heatmaps.device,
    )
    # explicit end indices: ``num_pad:-num_pad`` is the empty slice ``0:0`` when num_pad == 0
    heatmaps_padded[:, :, num_pad:num_pad + height, num_pad:num_pad + width] = heatmaps

    # batch / keypoint index grids that broadcast against the (batch, num_keypoints, 1, 1)
    # row / column indices below
    i = torch.arange(heatmaps_padded.shape[0], device=heatmaps_padded.device).reshape(
        -1, 1, 1, 1
    )
    j = torch.arange(heatmaps_padded.shape[1], device=heatmaps_padded.device).reshape(
        1, -1, 1, 1
    )
    # row (from y) and column (from x) of each location, shifted into padded coordinates
    k = locs[:, :, 1, None, None].type(torch.int64) + num_pad
    m = locs[:, :, 0, None, None].type(torch.int64) + num_pad
    # NOTE(review): locations further than num_pad pixels outside the heatmap would index
    # out of range; presumably locs come from within the heatmap -- confirm with callers

    offsets = range(-pix_to_consider, pix_to_consider + 1)
    vals_all = []
    for row_offset in offsets:
        k_offset = k + row_offset
        for col_offset in offsets:
            m_offset = m + col_offset
            # gather one pixel of every window; get rid of singleton dims
            vals = heatmaps_padded[i, j, k_offset, m_offset].squeeze(-1).squeeze(-1)
            vals_all.append(vals)
    vals = torch.stack(vals_all, 0).sum(0)
    return vals
# @typechecked
[docs]
def normalized_to_bbox(
    keypoints: TensorType["batch", "num_keypoints", "xy":2],
    bbox: TensorType["batch", "xyhw":4]
) -> TensorType["batch", "num_keypoints", "xy":2]:
    """Transform keypoints from normalized coordinates to bbox coordinates.

    Works in place: ``keypoints`` is modified and the same tensor is returned.

    Args:
        keypoints: normalized (x, y) keypoints
        bbox: boxes with columns (x offset, y offset, height, width). When its batch size
            differs from that of ``keypoints`` it is treated as a context batch and its
            first two and last two rows are skipped.

    Returns:
        ``keypoints``, now holding bbox (frame) coordinates

    """
    # context batch; we don't have predictions for first/last two frames
    boxes = bbox if keypoints.shape[0] == bbox.shape[0] else bbox[2:-2]

    x_offset = boxes[:, 0].unsqueeze(1)
    y_offset = boxes[:, 1].unsqueeze(1)
    box_height = boxes[:, 2].unsqueeze(1)
    box_width = boxes[:, 3].unsqueeze(1)

    # x: scale by box width, then shift by the box x offset
    keypoints[:, :, 0] *= box_width
    keypoints[:, :, 0] += x_offset
    # y: scale by box height, then shift by the box y offset
    keypoints[:, :, 1] *= box_height
    keypoints[:, :, 1] += y_offset
    return keypoints
[docs]
def convert_bbox_coords(
    batch_dict: (
        HeatmapLabeledBatchDict
        | MultiviewHeatmapLabeledBatchDict
        | MultiviewUnlabeledBatchDict
        | UnlabeledBatchDict
    ),
    predicted_keypoints: TensorType["batch", "num_targets"],
    in_place: bool = True,
) -> TensorType["batch", "num_targets"]:
    """Transform keypoints from bbox coordinates to absolute frame coordinates.

    Args:
        batch_dict: batch whose ``"images"`` (labeled) or ``"frames"`` (unlabeled) tensor
            gives the model input height/width in its last two dims, and whose ``"bbox"``
            holds boxes as (x, y, h, w), 4 columns per view for multiview batches
        predicted_keypoints: flattened keypoints in (x, y) order, in model input pixels
        in_place: if True, work on a reshape of ``predicted_keypoints`` itself (shares its
            storage when the reshape is a view); otherwise work on a clone

    Returns:
        keypoints in frame coordinates, reshaped to (batch, num_targets)

    Raises:
        ValueError: if elements of a multiview batch report different numbers of views

    """
    num_targets = predicted_keypoints.shape[1]
    num_keypoints = num_targets // 2

    source = predicted_keypoints if in_place else predicted_keypoints.clone()
    # (batch, n_targets) -> (batch, n_key, 2), in x,y order
    keypoints = source.reshape((-1, num_keypoints, 2))

    # labeled batches carry "images"; unlabeled batches carry "frames" instead
    pixels = batch_dict["images"] if "images" in batch_dict else batch_dict["frames"]
    # divide by model input dims to get 0-1 normalized coordinates
    keypoints[:, :, 0] /= pixels.shape[-1]  # -1 dim is width "x"
    keypoints[:, :, 1] /= pixels.shape[-2]  # -2 dim is height "y"

    # num_views covers labeled batches, is_multiview covers unlabeled batches
    labeled_multiview = (
        "num_views" in batch_dict and int(batch_dict["num_views"].max()) > 1
    )
    if not (labeled_multiview or batch_dict.get("is_multiview", False)):
        # single view: one box per batch element
        keypoints = normalized_to_bbox(keypoints, batch_dict["bbox"])
        return keypoints.reshape((-1, num_targets))

    if "num_views" in batch_dict:
        unique = batch_dict["num_views"].unique()
        if len(unique) != 1:
            raise ValueError(
                f"each batch element must contain the same number of views; "
                f"found elements with {unique} views"
            )
        num_views = int(unique)
    else:
        # MultiviewUnlabeledBatchDict: bbox has shape [seq_len, num_views * 4]
        num_views = batch_dict["bbox"].shape[1] // 4

    per_view = num_keypoints // num_views
    for view in range(num_views):
        kp_cols = slice(per_view * view, per_view * (view + 1))
        view_bbox = batch_dict["bbox"][:, 4 * view:4 * (view + 1)]
        keypoints[:, kp_cols, :] = normalized_to_bbox(keypoints[:, kp_cols, :], view_bbox)

    return keypoints.reshape((-1, num_targets))
[docs]
def convert_original_to_model_coords(
    batch_dict: MultiviewHeatmapLabeledBatchDict,
    original_keypoints: TensorType["batch", "num_views", "num_keypoints", 2],
) -> TensorType["batch", "num_views", "num_keypoints", 2]:
    """Transform keypoints from original frame coordinates to model input coordinates.

    Args:
        batch_dict: multiview batch; the last two dims of ``"images"`` give the model input
            height and width, and ``"bbox"`` holds 4 columns (x, y, h, w) per view
        original_keypoints: keypoints in original frame coordinates, one slice per view

    Returns:
        new tensor (the input is left unmodified) with keypoints in model input coordinates

    """
    _, num_views, _, _ = original_keypoints.shape
    image_height, image_width = batch_dict["images"].shape[-2:]

    # work on a copy so the caller's tensor is untouched
    model_keypoints = original_keypoints.clone()
    for view in range(num_views):
        view_bbox = batch_dict["bbox"][:, 4 * view:4 * (view + 1)]  # (batch, 4)
        model_keypoints[:, view] = original_to_model(
            original_keypoints[:, view],  # (batch, num_keypoints, 2)
            view_bbox,
            image_width,
            image_height,
        )
    return model_keypoints
[docs]
def original_to_model(
    keypoints: TensorType["batch", "num_keypoints", 2],
    bbox: TensorType["batch", 4],
    model_width: float,
    model_height: float,
) -> TensorType["batch", "num_keypoints", 2]:
    """Convert keypoints from original image coordinates to model input coordinates.

    Two steps are applied:
      1. original -> bbox: subtract the box offset, divide by the box dimensions
      2. bbox -> model: multiply by the model input dimensions

    bbox format: [x, y, h, w] where x,y is top-left corner. When the batch size of ``bbox``
    differs from that of ``keypoints`` (context batch), its first two and last two rows are
    skipped.

    Returns:
        new tensor; ``keypoints`` itself is not modified

    """
    # context batch; we don't have predictions for first/last two frames
    boxes = bbox if keypoints.shape[0] == bbox.shape[0] else bbox[2:-2]

    x_offset = boxes[:, 0].unsqueeze(1)
    y_offset = boxes[:, 1].unsqueeze(1)
    box_height = boxes[:, 2].unsqueeze(1)
    box_width = boxes[:, 3].unsqueeze(1)

    model_keypoints = keypoints.clone()

    # step 1: original -> bbox-relative coordinates (0-1 inside the box)
    model_keypoints[:, :, 0] -= x_offset
    model_keypoints[:, :, 0] /= box_width
    model_keypoints[:, :, 1] -= y_offset
    model_keypoints[:, :, 1] /= box_height

    # step 2: bbox-relative -> model input pixels
    model_keypoints[:, :, 0] *= model_width
    model_keypoints[:, :, 1] *= model_height

    return model_keypoints