HeatmapLoss

class lightning_pose.losses.losses.HeatmapLoss

Bases: Loss

Parent class for different heatmap losses (MSE, Wasserstein, etc.).

Methods Summary

__call__(heatmaps_targ, heatmaps_pred[, stage])

Compute the heatmap loss.

compute_loss(**kwargs)

Compute element-wise divergence between target and predicted heatmaps.

remove_nans(targets, predictions)

Remove heatmap entries where all target pixels are zero (NaN/unlabeled keypoints).

Methods Documentation

__call__(heatmaps_targ: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], heatmaps_pred: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], stage: Literal['train', 'val', 'test'] | None = None, **kwargs: Any) → tuple[Float[Tensor, ''], list[dict]]

Compute the heatmap loss.

Parameters:
  • heatmaps_targ – ground-truth heatmaps.

  • heatmaps_pred – predicted heatmaps.

  • stage – training stage ('train', 'val', or 'test') used when logging the loss.

  • **kwargs – ignored extra keyword arguments.

Returns:

Tuple of (scalar loss, list of logging dicts).
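The entry above does not spell out how __call__ combines the other methods. One plausible reading, stated here as an assumption rather than the library's code: drop unlabeled keypoints as remove_nans describes, compute an element-wise divergence as compute_loss describes, average it to a scalar, scale it by the weight documented for log_weight, and return the scalar with a list of log dicts. The sketch uses NumPy in place of torch and squared error as the divergence; the function name, the mean reduction, and the log-dict keys are hypothetical.

```python
import math

import numpy as np


def heatmap_loss_call_sketch(heatmaps_targ, heatmaps_pred, stage=None, log_weight=0.0):
    """Hypothetical end-to-end flow of a heatmap loss call (not the library's code)."""
    b, k = heatmaps_targ.shape[:2]
    # 1. Keep only (sample, keypoint) heatmaps whose target is not all zeros.
    keep = ~np.all(heatmaps_targ.reshape(b, k, -1) == 0.0, axis=-1)
    targ, pred = heatmaps_targ[keep], heatmaps_pred[keep]
    # 2. Element-wise divergence; squared error stands in for compute_loss.
    #    (If every keypoint is unlabeled, the mean below is NaN.)
    elementwise = (pred - targ) ** 2
    # 3. Reduce to a scalar by averaging (assumed reduction).
    scalar_loss = float(elementwise.mean())
    # 4. Scale by the documented weight 1 / (2 * exp(log_weight)).
    weight = 1.0 / (2.0 * math.exp(log_weight))
    # Hypothetical log-dict layout.
    logs = [{"name": f"{stage}_heatmap_loss", "value": scalar_loss}]
    return weight * scalar_loss, logs


# One sample, two keypoints of 3x3 heatmaps; keypoint 1 is unlabeled (all zeros).
targ = np.zeros((1, 2, 3, 3))
targ[0, 0, 1, 1] = 1.0
pred = np.zeros((1, 2, 3, 3))
pred[0, 1] = 5.0  # ignored, because its target is all zeros
loss, logs = heatmap_loss_call_sketch(targ, pred, stage="train")
print(loss, logs)
```

Only keypoint 0 contributes here, so the large error on the unlabeled keypoint does not affect the loss.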

compute_loss(**kwargs: Any) → Tensor

Compute element-wise divergence between target and predicted heatmaps.

Subclasses must override this method with the specific divergence measure.

Raises:

NotImplementedError – if called on this base class; subclasses must override the method.
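A minimal sketch of the override pattern, using hypothetical stand-in classes rather than the real lightning_pose classes. It assumes compute_loss receives the heatmaps as targets= and predictions= keyword arguments (the documented signature only shows **kwargs) and uses NumPy arrays in place of torch tensors. The subclass returns an unreduced, element-wise squared error, matching "element-wise divergence" above.

```python
import numpy as np


class HeatmapLossBase:
    """Hypothetical stand-in for the documented base class."""

    def compute_loss(self, **kwargs) -> np.ndarray:
        # The base class defines no divergence; subclasses must supply one.
        raise NotImplementedError


class SquaredErrorHeatmapLoss(HeatmapLossBase):
    """Hypothetical subclass: element-wise squared error (MSE before averaging)."""

    def compute_loss(self, targets: np.ndarray, predictions: np.ndarray, **kwargs) -> np.ndarray:
        # Same shape as the inputs; reduction to a scalar happens elsewhere.
        return (predictions - targets) ** 2


targets = np.zeros((2, 4, 4))
predictions = np.full((2, 4, 4), 0.5)
elementwise = SquaredErrorHeatmapLoss().compute_loss(targets=targets, predictions=predictions)
print(elementwise.shape, elementwise.mean())  # (2, 4, 4) 0.25

base_raised = False
try:
    HeatmapLossBase().compute_loss(targets=targets, predictions=predictions)
except NotImplementedError:
    base_raised = True
print(base_raised)  # True
```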

remove_nans(targets: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], predictions: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width']) → tuple[Float[Tensor, 'num_valid_keypoints heatmap_height heatmap_width'], Float[Tensor, 'num_valid_keypoints heatmap_height heatmap_width']]

Remove heatmap entries where all target pixels are zero (NaN/unlabeled keypoints).

Parameters:
  • targets – ground-truth heatmaps.

  • predictions – predicted heatmaps.

Returns:

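Tuple of (clean_targets, clean_predictions) from which every (sample, keypoint) heatmap with an all-zero target has been removed; the batch and keypoint dimensions are merged into a single num_valid_keypoints dimension.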
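A minimal NumPy sketch of the masking remove_nans describes, assuming a keypoint counts as unlabeled exactly when every pixel of its target heatmap is zero. NumPy arrays stand in for torch tensors; in both libraries, indexing a (batch, num_keypoints, H, W) array with a (batch, num_keypoints) boolean mask merges the first two dimensions, which yields the num_valid_keypoints shape in the signature. The function name remove_nans_sketch is hypothetical.

```python
import numpy as np


def remove_nans_sketch(targets: np.ndarray, predictions: np.ndarray):
    """Drop (sample, keypoint) heatmaps whose target is all zeros (hypothetical helper).

    Inputs have shape (batch, num_keypoints, heatmap_height, heatmap_width).
    """
    batch, num_keypoints = targets.shape[:2]
    # Flatten each heatmap so all of its pixels can be tested at once.
    flat = targets.reshape(batch, num_keypoints, -1)
    # True where the keypoint is unlabeled (every target pixel is zero).
    all_zero = np.all(flat == 0.0, axis=-1)  # shape (batch, num_keypoints)
    keep = ~all_zero
    # Boolean-mask indexing merges batch and keypoint dims:
    # result shape is (num_valid_keypoints, heatmap_height, heatmap_width).
    return targets[keep], predictions[keep]


# 2 samples x 3 keypoints of 4x4 heatmaps, with one keypoint unlabeled.
rng = np.random.default_rng(0)
targ = rng.random((2, 3, 4, 4))
pred = rng.random((2, 3, 4, 4))
targ[1, 2] = 0.0  # mark sample 1, keypoint 2 as unlabeled
clean_targ, clean_pred = remove_nans_sketch(targ, pred)
print(clean_targ.shape, clean_pred.shape)  # (5, 4, 4) (5, 4, 4)
```

Surviving heatmaps keep row-major (sample, keypoint) order, so clean_pred[-1] is the prediction for sample 1, keypoint 1.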

__init__(data_module: BaseDataModule | UnlabeledDataModule | None = None, log_weight: float = 0.0, **kwargs: Any) → None

Initialize HeatmapLoss.

Parameters:
  • data_module – data module providing access to datasets; passed to the parent class.

  • log_weight – controls the final weight in front of the loss term in the objective function, which is computed as 1.0 / (2.0 * exp(log_weight)).
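The log_weight rule can be checked numerically. The helper below is a hypothetical illustration of the stated formula, not a library function. With the default log_weight = 0.0 the weight is 0.5, and larger log_weight values shrink the loss term's contribution.

```python
import math


def loss_weight(log_weight: float = 0.0) -> float:
    """Final weight on the loss term: 1 / (2 * exp(log_weight)) (hypothetical helper)."""
    return 1.0 / (2.0 * math.exp(log_weight))


print(loss_weight(0.0))  # 0.5
print(loss_weight(math.log(5.0)))  # approximately 0.1 (floating-point rounding)
```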

__new__(**kwargs)