TemporalHeatmapLoss

class lightning_pose.losses.losses.TemporalHeatmapLoss

Bases: Loss

Penalize differences between the predicted heatmaps of consecutive frames, separately for each keypoint.

Motion model: x_t = x_(t-1) + e_t, e_t ~ N(0, s)
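
Under this random-walk model, the negative log-likelihood of a sequence given its first value is a constant plus the sum of squared frame-to-frame differences divided by 2s. This is why penalizing temporal differences acts as a smoothness prior. The sketch below checks that relationship numerically for a scalar trajectory. It assumes s is the variance of e_t (the text does not say whether s is a variance or a standard deviation), and the function names are illustrative only.

```python
import math


def random_walk_nll(xs: list[float], s: float) -> float:
    """Negative log-likelihood of x_1..x_T given x_0 under x_t = x_(t-1) + e_t.

    Assumes e_t ~ N(0, s) with s interpreted as the variance.
    """
    nll = 0.0
    for prev, cur in zip(xs[:-1], xs[1:]):
        e = cur - prev  # innovation e_t implied by two consecutive values
        nll += 0.5 * math.log(2 * math.pi * s) + e * e / (2 * s)
    return nll


def scaled_squared_diffs(xs: list[float], s: float) -> float:
    """Data-dependent part of the NLL: sum of squared differences over 2s."""
    return sum((cur - prev) ** 2 for prev, cur in zip(xs[:-1], xs[1:])) / (2 * s)
```

For xs = [0.0, 1.0, 3.0] and s = 1.0, the differences are 1 and 2, so the scaled squared differences are 2.5 and the NLL exceeds them by exactly log(2π), the constant for two frame pairs.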

Attributes Summary

LOSS_NAME_KL

LOSS_NAME_MSE

Methods Summary

__call__(heatmaps_pred, confidences[, stage])

Compute the loss on a batch of consecutive frames.

compute_loss(predictions)

Compute the loss between heatmaps of consecutive frames.

rectify_epsilon(loss)

Rectify the loss with epsilon, supporting a list of epsilons (one per bodypart).

remove_nans(confidences, loss)

Attributes Documentation

LOSS_NAME_KL = 'temporal_heatmap_kl'
LOSS_NAME_MSE = 'temporal_heatmap_mse'

Methods Documentation

__call__(heatmaps_pred: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"], confidences: TensorType["batch", "num_keypoints"], stage: Literal['train', 'val', 'test'] | None = None, **kwargs) → tuple[TensorType[()], list[dict]]

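Compute the temporal heatmap loss on a batch of consecutive frames. Returns the scalar loss together with a list of dictionaries holding values to log.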

compute_loss(predictions: TensorType["batch", "num_valid_keypoints", "heatmap_height", "heatmap_width"]) → TensorType["batch_minus_one", "num_valid_keypoints"]

Compute the loss between heatmaps of consecutive frames, giving one value per frame pair and keypoint.

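
A minimal NumPy sketch of what the two loss_name options could compute between consecutive frames. It assumes the batch dimension holds consecutive frames in time order, that the MSE is averaged over pixels, that heatmaps are already normalized to sum to 1 for the KL option, and that the KL is taken as KL(frame t || frame t-1) with a small eps for numerical stability. These are assumptions, and temporal_heatmap_loss is a hypothetical helper, not part of lightning_pose.

```python
import numpy as np


def temporal_heatmap_loss(heatmaps: np.ndarray,
                          loss_name: str = "temporal_heatmap_mse",
                          eps: float = 1e-10) -> np.ndarray:
    """Per-pair, per-keypoint loss between heatmaps of consecutive frames.

    heatmaps: (batch, num_keypoints, height, width), frames in time order.
    Returns: (batch - 1, num_keypoints).
    """
    prev, cur = heatmaps[:-1], heatmaps[1:]  # frames t-1 and t, paired up
    if loss_name == "temporal_heatmap_mse":
        # pixel-wise squared error, averaged over the spatial dimensions
        return ((cur - prev) ** 2).mean(axis=(2, 3))
    if loss_name == "temporal_heatmap_kl":
        # KL(cur || prev), treating each heatmap as a distribution over pixels
        p = cur + eps
        q = prev + eps
        return (p * np.log(p / q)).sum(axis=(2, 3))
    raise ValueError(f"unknown loss_name: {loss_name}")
```

Whether the real method averages or sums over pixels, and which KL direction it uses, changes only the scale and asymmetry of the penalty, not its shape of (batch_minus_one, num_valid_keypoints).
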
rectify_epsilon(loss: TensorType["batch_minus_one", "num_valid_keypoints"]) → TensorType["batch_minus_one", "num_valid_keypoints"]

Rectify the loss with epsilon, supporting a list of epsilons (one per bodypart). This is not implemented in the base Loss class because the shapes needed for broadcasting may differ between losses.
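
Following the epsilon parameter description (loss values below the threshold are zeroed out), a per-keypoint epsilon only needs to broadcast along the last (keypoint) axis of a (batch_minus_one, num_valid_keypoints) loss. A sketch under that assumption; the library may rectify differently (for example by subtracting epsilon), and rectify_epsilon_sketch is a hypothetical name.

```python
import numpy as np


def rectify_epsilon_sketch(loss: np.ndarray, epsilon) -> np.ndarray:
    """Zero out loss values below epsilon.

    loss: (batch_minus_one, num_keypoints)
    epsilon: a scalar, or a sequence with one threshold per keypoint
    """
    eps = np.asarray(epsilon, dtype=loss.dtype)
    # a 1-D eps of shape (num_keypoints,) broadcasts across the pair axis
    return np.where(loss < eps, 0.0, loss)
```

With a scalar epsilon every keypoint shares one threshold; with a list, each column of the loss is compared against its own value.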

remove_nans(confidences: TensorType["batch", "num_keypoints"], loss: TensorType["batch_minus_one", "num_keypoints"]) → TensorType["batch_minus_one", "num_keypoints"]

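
remove_nans receives per-frame confidences of shape (batch, num_keypoints) but a per-pair loss of shape (batch_minus_one, num_keypoints), so frame confidences must be aligned with frame pairs. One plausible rule, shown here purely as an assumption, keeps a pair only when both of its frames meet prob_threshold and treats a NaN confidence as 0; the actual method may drop entries instead of zeroing them. mask_low_confidence_pairs is a hypothetical helper.

```python
import numpy as np


def mask_low_confidence_pairs(confidences: np.ndarray, loss: np.ndarray,
                              prob_threshold: float) -> np.ndarray:
    """Zero the loss for frame pairs containing a low-confidence (or NaN) frame.

    confidences: (batch, num_keypoints); loss: (batch - 1, num_keypoints).
    """
    conf = np.nan_to_num(confidences, nan=0.0)  # treat NaN confidence as 0
    # pair t covers frames t and t+1; require both to pass the threshold
    keep = (conf[:-1] >= prob_threshold) & (conf[1:] >= prob_threshold)
    return np.where(keep, loss, 0.0)
```

Zeroing keeps the tensor shape fixed, which is convenient for batching; removing entries instead would avoid biasing a subsequent mean toward zero.
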
__init__(loss_name: Literal['temporal_heatmap_mse', 'temporal_heatmap_kl'], data_module: BaseDataModule | UnlabeledDataModule | None = None, epsilon: float | list[float] = 0.0, prob_threshold: float = 0.0, log_weight: float = 0.0, **kwargs) → None

Initialize TemporalHeatmapLoss.

Parameters:
  • loss_name – "temporal_heatmap_mse" uses the pixel-wise MSE between consecutive heatmaps; "temporal_heatmap_kl" uses the KL divergence between consecutive heatmaps.

  • data_module – data module providing access to datasets; passed to the parent class.

  • epsilon – loss values below this threshold are zeroed out. May be a scalar or a list with one value per keypoint.

  • prob_threshold – predictions whose confidence is below this value are excluded from the loss computation.

  • log_weight – the weight applied to this loss term in the objective function is computed as 1.0 / (2.0 * exp(log_weight)).
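
For example, the default log_weight = 0.0 gives a weight of 0.5, and each unit increase in log_weight divides the weight by e. A minimal sketch of the formula above; the helper name loss_weight is hypothetical.

```python
import math


def loss_weight(log_weight: float) -> float:
    """Weight on the loss term: 1 / (2 * exp(log_weight))."""
    return 1.0 / (2.0 * math.exp(log_weight))
```

Read as a log-variance, log_weight = log σ² makes the weight 1/(2σ²), the same scaling that squared differences receive in a Gaussian negative log-likelihood; this is an interpretation, not something the documentation states.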

__new__(**kwargs)