TemporalHeatmapLoss

class lightning_pose.losses.losses.TemporalHeatmapLoss

Bases: Loss

Penalize temporal differences for each heatmap.

Motion model: x_t = x_(t-1) + e_t, e_t ~ N(0, s)
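The motion model above is a Gaussian random walk. As a minimal sketch, assuming `s` denotes the variance of the step noise (the docstring does not say whether it is a variance or a standard deviation), the negative log-likelihood of a trajectory grows with the sum of squared frame-to-frame differences, which is why penalizing temporal differences favors smooth predictions. The function name `random_walk_nll` is hypothetical and not part of the library.

```python
import math

def random_walk_nll(xs, s=1.0):
    """Negative log-likelihood of the steps x_t - x_(t-1) under N(0, s).

    Assumption: s is the variance of the step noise e_t.
    """
    steps = [b - a for a, b in zip(xs[:-1], xs[1:])]
    # Normalization constant of the Gaussian, one term per step.
    const = 0.5 * math.log(2.0 * math.pi * s) * len(steps)
    # Data term: sum of squared temporal differences, scaled by 1 / (2 s).
    return const + sum(e * e for e in steps) / (2.0 * s)

smooth = [0.0, 0.1, 0.2, 0.3]
jumpy = [0.0, 2.0, -1.0, 3.0]

# The jumpy trajectory is less likely under the model, so its NLL is larger.
print(random_walk_nll(smooth) < random_walk_nll(jumpy))  # True
```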

Attributes Summary

LOSS_NAME_KL

LOSS_NAME_MSE

Methods Summary

__call__(heatmaps_pred, confidences[, stage])

Compute the temporal heatmap loss for a batch of predicted heatmaps.

compute_loss(predictions)

Compute per-keypoint temporal heatmap differences between consecutive frames.

rectify_epsilon(loss)

Rectify the loss using a scalar epsilon or a list of epsilons, one per body part.

remove_nans(confidences, loss)

Zero out heatmap temporal difference losses where adjacent frames are low-confidence.

Attributes Documentation

LOSS_NAME_KL = 'temporal_heatmap_kl'
LOSS_NAME_MSE = 'temporal_heatmap_mse'

Methods Documentation

__call__(heatmaps_pred: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], confidences: Float[Tensor, 'batch num_keypoints'], stage: Literal['train', 'val', 'test'] | None = None, **kwargs: Any) → tuple[Float[Tensor, ''], list[dict]]

Compute the temporal heatmap loss for a batch of predicted heatmaps.

Parameters:
  • heatmaps_pred – predicted heatmaps of shape (batch, num_keypoints, heatmap_height, heatmap_width).

  • confidences – per-frame confidence scores of shape (batch, num_keypoints).

  • stage – training stage for logging.

  • **kwargs – ignored extra keyword arguments.

Returns:

Tuple of scalar loss and list of logging dicts.

compute_loss(predictions: Float[Tensor, 'batch num_valid_keypoints heatmap_height heatmap_width']) → Float[Tensor, 'batch_minus_one num_valid_keypoints']

Compute per-keypoint temporal heatmap differences between consecutive frames.

Parameters:

predictions – predicted heatmaps of shape (batch, num_keypoints, heatmap_height, heatmap_width).

Returns:

Per-keypoint temporal divergence of shape (batch-1, num_keypoints).
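To illustrate the shapes involved, here is a minimal sketch of a per-keypoint temporal difference between consecutive frames. It is not the library's implementation: plain nested lists stand in for tensors so the example runs without dependencies, and the reduction over pixels (a mean for MSE), the normalization of heatmaps to sum to one, the KL direction (previous frame relative to current frame), and the `eps` stabilizer are all assumptions. The helper names are hypothetical.

```python
import math

def _flatten(heatmap):
    """Flatten a (height, width) nested list into a flat list of pixels."""
    return [v for row in heatmap for v in row]

def temporal_mse(heatmaps):
    """heatmaps: (batch, num_keypoints, height, width) -> (batch - 1, num_keypoints)."""
    out = []
    for prev, curr in zip(heatmaps[:-1], heatmaps[1:]):
        row = []
        for hm_prev, hm_curr in zip(prev, curr):
            p, q = _flatten(hm_prev), _flatten(hm_curr)
            # Mean squared pixel difference (assumed reduction).
            row.append(sum((a - b) ** 2 for a, b in zip(p, q)) / len(p))
        out.append(row)
    return out

def temporal_kl(heatmaps, eps=1e-12):
    """Same shapes as temporal_mse, using KL(previous || current) per keypoint."""
    out = []
    for prev, curr in zip(heatmaps[:-1], heatmaps[1:]):
        row = []
        for hm_prev, hm_curr in zip(prev, curr):
            p, q = _flatten(hm_prev), _flatten(hm_curr)
            # Normalize each heatmap to a probability distribution (assumed).
            p_sum, q_sum = sum(p) or 1.0, sum(q) or 1.0
            p = [v / p_sum for v in p]
            q = [v / q_sum for v in q]
            row.append(sum(a * math.log((a + eps) / (b + eps)) for a, b in zip(p, q)))
        out.append(row)
    return out

# Three frames, one keypoint, 2x2 heatmaps: the peak moves only in the last frame.
frames = [
    [[[1.0, 0.0], [0.0, 0.0]]],
    [[[1.0, 0.0], [0.0, 0.0]]],
    [[[0.0, 1.0], [0.0, 0.0]]],
]
mse = temporal_mse(frames)
kl = temporal_kl(frames)
print(mse)  # [[0.0], [0.5]]
```

Both results have one row per pair of consecutive frames, matching the documented (batch-1, num_keypoints) output shape.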

rectify_epsilon(loss: Float[Tensor, 'batch_minus_one num_valid_keypoints']) → Float[Tensor, 'batch_minus_one num_valid_keypoints']

Rectify the loss using a scalar epsilon or a list of epsilons, one per body part. This is not implemented in the base Loss class because the broadcasting shapes may vary.
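As a rough sketch of per-keypoint rectification, the example below follows the description of the epsilon parameter ("loss values below this threshold are zeroed out") and broadcasts one threshold per keypoint across the time dimension. Whether the library also shifts the surviving values by epsilon is not stated here, so this sketch leaves them unchanged; the function name `rectify_epsilon_sketch` is hypothetical.

```python
def rectify_epsilon_sketch(loss, epsilon):
    """Zero out loss entries that fall below a per-keypoint threshold.

    loss: nested lists shaped (batch - 1, num_keypoints).
    epsilon: a single number, or a list with one threshold per keypoint.
    """
    num_keypoints = len(loss[0])
    if isinstance(epsilon, (int, float)):
        # A scalar threshold applies to every keypoint.
        epsilon = [float(epsilon)] * num_keypoints
    # Each column (keypoint) is compared against its own threshold.
    return [[0.0 if v < e else v for v, e in zip(row, epsilon)] for row in loss]

loss = [[0.05, 0.40], [0.20, 0.10]]
print(rectify_epsilon_sketch(loss, [0.1, 0.3]))  # [[0.0, 0.4], [0.2, 0.0]]
```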

remove_nans(confidences: Float[Tensor, 'batch num_keypoints'], loss: Float[Tensor, 'batch_minus_one num_keypoints']) → Float[Tensor, 'batch_minus_one num_keypoints']

Zero out heatmap temporal difference losses where adjacent frames are low-confidence.

Parameters:
  • confidences – per-frame confidence scores of shape (batch, num_keypoints).

  • loss – temporal difference losses of shape (batch-1, num_keypoints).

Returns:

Loss tensor with entries zeroed where confidence falls below self.prob_threshold.
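The masking step can be sketched as follows. This is an illustration under assumptions, not the library's code: it assumes that entry t of the loss compares frames t and t + 1, and that the entry is zeroed when either of those two frames has confidence below the threshold. The function name `zero_low_confidence` is hypothetical.

```python
def zero_low_confidence(confidences, loss, prob_threshold):
    """Zero temporal-difference losses that involve a low-confidence frame.

    confidences: nested lists shaped (batch, num_keypoints).
    loss: nested lists shaped (batch - 1, num_keypoints).
    """
    out = []
    for t, row in enumerate(loss):
        new_row = []
        for k, value in enumerate(row):
            # loss[t] is assumed to compare frames t and t + 1,
            # so both confidences must clear the threshold.
            keep = (confidences[t][k] >= prob_threshold
                    and confidences[t + 1][k] >= prob_threshold)
            new_row.append(value if keep else 0.0)
        out.append(new_row)
    return out

# Keypoint 0 is low-confidence in the middle frame only.
confidences = [[0.9, 0.9], [0.2, 0.9], [0.9, 0.9]]
loss = [[1.0, 2.0], [3.0, 4.0]]
masked = zero_low_confidence(confidences, loss, prob_threshold=0.5)
print(masked)  # [[0.0, 2.0], [0.0, 4.0]]
```

Because the middle frame takes part in both frame pairs, a single low-confidence frame removes two loss entries for that keypoint under this assumption.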

__init__(loss_name: Literal['temporal_heatmap_mse', 'temporal_heatmap_kl'], data_module: BaseDataModule | UnlabeledDataModule | None = None, epsilon: float | list[float] = 0.0, prob_threshold: float = 0.0, log_weight: float = 0.0, **kwargs: Any) → None

Initialize TemporalHeatmapLoss.

Parameters:
  • loss_name – "temporal_heatmap_mse" uses pixel-wise MSE between consecutive heatmaps; "temporal_heatmap_kl" uses the KL divergence.

  • data_module – data module providing access to datasets; passed to the parent class.

  • epsilon – loss values below this threshold are zeroed out. May be a scalar or a list with one value per keypoint.

  • prob_threshold – predictions whose confidence is below this value are excluded from the loss computation.

  • log_weight – log-scale weight for this loss term; the final weight applied to the loss in the objective function is 1.0 / (2.0 * exp(log_weight)).
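The log_weight formula can be checked numerically. The snippet below only evaluates the expression quoted in the parameter description; the helper name `loss_weight` is hypothetical.

```python
import math

def loss_weight(log_weight):
    """Weight in front of the loss term: 1.0 / (2.0 * exp(log_weight))."""
    return 1.0 / (2.0 * math.exp(log_weight))

# The default log_weight of 0.0 gives a weight of 0.5;
# larger log_weight values shrink the weight, smaller ones grow it.
for lw in (-1.0, 0.0, 1.0):
    print(lw, loss_weight(lw))
```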

__new__(**kwargs)