TemporalHeatmapLoss

class lightning_pose.losses.losses.TemporalHeatmapLoss(loss_name: Literal['temporal_heatmap_mse', 'temporal_heatmap_kl'], data_module: BaseDataModule | UnlabeledDataModule | None = None, epsilon: float | List[float] = 0.0, prob_threshold: float = 0.0, log_weight: float = 0.0, **kwargs)

Bases: Loss

Penalize temporal differences between the heatmaps of consecutive frames, separately for each keypoint.

Motion model: x_t = x_(t-1) + e_t, e_t ~ N(0, s)
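Under this random-walk model, reading s as the noise variance, the negative log-likelihood of x_t given x_(t-1) is (x_t - x_(t-1))^2 / (2s) plus a constant, so a squared temporal difference is a natural penalty. The sketch below illustrates the two penalties named by `loss_name` when applied to heatmaps instead of coordinates. It is a minimal sketch under stated assumptions, not the library's implementation: it assumes heatmaps of shape (T, K, H, W) for T consecutive frames and K keypoints, compares each frame with the previous one, and for the KL variant normalizes each heatmap into a distribution over pixels and computes KL(current || previous). The function `temporal_heatmap_diffs` is hypothetical.

```python
import torch


def temporal_heatmap_diffs(
    heatmaps: torch.Tensor, loss_name: str = "temporal_heatmap_mse"
) -> torch.Tensor:
    """Hypothetical per-keypoint penalty between consecutive frames.

    heatmaps: (T, K, H, W) heatmaps for T consecutive frames (assumed layout).
    Returns a (T - 1, K) tensor: one value per frame pair and keypoint.
    """
    prev, curr = heatmaps[:-1], heatmaps[1:]
    if loss_name == "temporal_heatmap_mse":
        # Squared pixel-wise difference, averaged over the spatial dims (H, W).
        return ((curr - prev) ** 2).mean(dim=(-2, -1))
    if loss_name == "temporal_heatmap_kl":
        tiny = 1e-10  # guards against log(0) and division by zero
        # Flatten H and W, then normalize each heatmap to sum to 1
        # (an assumption: the real loss may expect pre-normalized heatmaps).
        p = curr.flatten(-2)
        q = prev.flatten(-2)
        p = p / (p.sum(-1, keepdim=True) + tiny)
        q = q / (q.sum(-1, keepdim=True) + tiny)
        # KL(p || q) = sum_i p_i * (log p_i - log q_i), summed over pixels.
        return (p * (torch.log(p + tiny) - torch.log(q + tiny))).sum(-1)
    raise ValueError(f"unknown loss_name: {loss_name}")
```

For a stationary sequence (the same heatmap repeated in every frame), both variants return zeros of shape (T - 1, K). Any change between frames yields positive per-keypoint values.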

Methods Summary

__call__(heatmaps_pred, confidences[, stage])

Call self as a function.

compute_loss(predictions)

rectify_epsilon(loss)

Rectify the loss using epsilon, supporting a list of epsilons (one per body part).

remove_nans(confidences, loss)

Methods Documentation

__call__(heatmaps_pred: Tensor[Tensor], confidences: Tensor[Tensor], stage: Literal['train', 'val', 'test'] | None = None, **kwargs) -> Tuple[Tensor[Tensor], List[dict]]

Call self as a function.

compute_loss(predictions: Tensor[Tensor]) -> Tensor[Tensor]

rectify_epsilon(loss: Tensor[Tensor]) -> Tensor[Tensor]

Rectify the loss using epsilon, supporting a list of epsilons (one per body part). This method is not implemented in the base Loss class because the broadcasting shapes may vary from one loss to another.
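The following is a minimal sketch of epsilon rectification with per-body-part thresholds. It assumes that "rectify" means a hinge, ReLU(loss - epsilon), so that losses below epsilon contribute nothing, and that the loss tensor holds keypoints on its last axis; the function below is hypothetical and only illustrates the broadcasting involved.

```python
from typing import List, Union

import torch


def rectify_epsilon(loss: torch.Tensor, epsilon: Union[float, List[float]]) -> torch.Tensor:
    """Hypothetical hinge: zero out loss values below epsilon.

    loss: (..., K) per-keypoint loss, keypoints on the last axis (assumed layout).
    epsilon: a single float, or a list of K floats (one per body part).
    """
    eps = torch.as_tensor(epsilon, dtype=loss.dtype, device=loss.device)
    # A scalar eps broadcasts everywhere; a (K,) eps broadcasts against
    # the trailing keypoint axis, giving each body part its own threshold.
    return torch.relu(loss - eps)


loss = torch.tensor([[0.5, 2.0, 3.0], [1.5, 0.1, 4.0]])  # (frame pairs, K=3)
print(rectify_epsilon(loss, [1.0, 1.0, 3.5]))
```

The (K,) reshaping works here only because keypoints sit on the last axis. A loss whose keypoint axis comes earlier, for example (batch, K, H, W), would need epsilon reshaped to (K, 1, 1) instead, which illustrates why each subclass handles this itself.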

remove_nans(confidences: Tensor[Tensor], loss: Tensor[Tensor]) -> Tensor[Tensor]