HeatmapLoss

class lightning_pose.losses.losses.HeatmapLoss(data_module: BaseDataModule | UnlabeledDataModule | None = None, log_weight: float = 0.0, **kwargs)

Bases: Loss

Parent class for different heatmap losses (MSE, Wasserstein, etc.).
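The sketch below illustrates the parent/subclass pattern in PyTorch. Everything in it is hypothetical and is not lightning_pose's code: the class names HeatmapLossSketch and MSEHeatmapLossSketch, the plain weight argument (a stand-in for log_weight, whose exact effect is not documented on this page), and the log-dict format. It assumes that a subclass only overrides compute_loss to return an elementwise loss, which __call__ then averages and weights.

```python
from typing import List, Literal, Optional, Tuple

import torch


class HeatmapLossSketch:
    """Hypothetical parent class: subclasses only define compute_loss."""

    def __init__(self, weight: float = 1.0) -> None:
        # Assumption: a plain multiplicative weight; the real class takes log_weight.
        self.weight = weight

    def compute_loss(self, **kwargs) -> torch.Tensor:
        # Each concrete heatmap loss supplies its own comparison.
        raise NotImplementedError

    def __call__(
        self,
        heatmaps_targ: torch.Tensor,
        heatmaps_pred: torch.Tensor,
        stage: Optional[Literal["train", "val", "test"]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[dict]]:
        # Filtering of missing keypoints (cf. remove_nans) is omitted in this sketch.
        elementwise = self.compute_loss(targets=heatmaps_targ, predictions=heatmaps_pred)
        scalar = torch.mean(elementwise)
        # Hypothetical log format: one dict per logged quantity.
        logs = [{"name": f"{stage}_heatmap_loss", "value": scalar.item()}]
        return self.weight * scalar, logs


class MSEHeatmapLossSketch(HeatmapLossSketch):
    """Example subclass: squared error between predicted and target heatmaps."""

    def compute_loss(
        self, targets: torch.Tensor, predictions: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        return (predictions - targets) ** 2


# Heatmaps shaped (batch, num_keypoints, height, width).
targ = torch.zeros(2, 3, 4, 4)
pred = torch.full((2, 3, 4, 4), 0.5)
loss, logs = MSEHeatmapLossSketch(weight=2.0)(targ, pred, stage="train")
```

Keeping the reduction and logging in the parent means a new heatmap loss only has to describe how a predicted heatmap is compared with its target.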

Methods Summary

__call__(heatmaps_targ, heatmaps_pred[, stage])

Call self as a function.

compute_loss(**kwargs)

remove_nans(targets, predictions)

Methods Documentation

__call__(heatmaps_targ: Tensor[Tensor], heatmaps_pred: Tensor[Tensor], stage: Literal['train', 'val', 'test'] | None = None, **kwargs) → Tuple[Tensor[Tensor], List[dict]]

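Call self as a function. Compares the target heatmaps (heatmaps_targ) with the predicted heatmaps (heatmaps_pred) and returns a tuple of a loss tensor and a list of log dictionaries; stage optionally names the phase ('train', 'val', or 'test').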

compute_loss(**kwargs)

remove_nans(targets: Tensor[Tensor], predictions: Tensor[Tensor]) → Tuple[Tensor[Tensor], Tensor[Tensor]]
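This page does not say how remove_nans decides what to drop. The sketch below assumes that a missing keypoint label is encoded as a target heatmap that is entirely zero or contains NaNs, and removes those heatmaps from both tensors. remove_nans_sketch is a hypothetical function, not the library method.

```python
from typing import Tuple

import torch


def remove_nans_sketch(
    targets: torch.Tensor, predictions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Drop (sample, keypoint) heatmaps whose target marks a missing label.

    Assumption: a missing label shows up as a target heatmap that is all
    zeros or contains NaNs. Inputs are (batch, num_keypoints, height, width);
    outputs are (num_kept, height, width).
    """
    flat = targets.reshape(targets.shape[0], targets.shape[1], -1)
    has_nan = torch.isnan(flat).any(dim=-1)
    all_zero = (flat == 0).all(dim=-1)
    keep = ~(has_nan | all_zero)  # boolean mask of shape (batch, num_keypoints)
    # Boolean indexing over the first two dims flattens them into one.
    return targets[keep], predictions[keep]


targ = torch.rand(2, 3, 8, 8)
targ[0, 1] = 0.0            # missing label encoded as zeros
targ[1, 2] = float("nan")   # missing label encoded as NaNs
pred = torch.rand(2, 3, 8, 8)
clean_targ, clean_pred = remove_nans_sketch(targ, pred)
print(clean_targ.shape)  # torch.Size([4, 8, 8])
```

Filtering both tensors with the same mask keeps each remaining prediction aligned with its target, so a downstream loss never sees a heatmap whose label is missing.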