HeatmapLoss
- class lightning_pose.losses.losses.HeatmapLoss
Bases: Loss
Parent class for different heatmap losses (MSE, Wasserstein, etc.).
Methods Summary
__call__(heatmaps_targ, heatmaps_pred[, stage]): Compute the heatmap loss.
compute_loss(**kwargs): Compute element-wise divergence between target and predicted heatmaps.
remove_nans(targets, predictions): Remove heatmap entries where all target pixels are zero (NaN/unlabeled keypoints).
Methods Documentation
- __call__(heatmaps_targ: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], heatmaps_pred: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], stage: Literal['train', 'val', 'test'] | None = None, **kwargs: Any) → tuple[Float[Tensor, ''], list[dict]]
Compute the heatmap loss.
- Parameters:
heatmaps_targ – ground-truth heatmaps.
heatmaps_pred – predicted heatmaps.
stage – stage of the run ('train', 'val', or 'test'), used for logging.
**kwargs – extra keyword arguments; ignored.
- Returns:
A tuple of the scalar loss and a list of logging dicts.
- compute_loss(**kwargs: Any) → Tensor
Compute element-wise divergence between target and predicted heatmaps.
Subclasses must override this method with the specific divergence measure.
- Raises:
NotImplementedError – always, unless overridden by a subclass.
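The split of work between __call__ and compute_loss can be sketched as below. This is a minimal, dependency-free illustration under stated assumptions, not the library's implementation: HeatmapLossSketch and SquaredErrorHeatmapLoss are hypothetical names, heatmaps are flat Python lists rather than torch Tensors, the mean reduction is assumed, the NaN-removal step is omitted, and the log-dict layout is invented for illustration.

```python
from typing import Any, Optional


class HeatmapLossSketch:
    """Hypothetical stand-in for HeatmapLoss; heatmaps are flat lists of pixels."""

    def compute_loss(self, **kwargs: Any) -> list[float]:
        # Like the documented parent method: no divergence is defined here.
        raise NotImplementedError

    def __call__(
        self,
        heatmaps_targ: list[float],
        heatmaps_pred: list[float],
        stage: Optional[str] = None,
        **kwargs: Any,  # extra keyword arguments are ignored
    ) -> tuple[float, list[dict]]:
        # The subclass supplies the element-wise divergence ...
        elementwise = self.compute_loss(targets=heatmaps_targ, predictions=heatmaps_pred)
        # ... and the parent reduces it to a scalar (mean is an assumption here).
        loss = sum(elementwise) / len(elementwise)
        # Hypothetical log-dict layout tagged with the stage.
        logs = [{"stage": stage, "loss": loss}]
        return loss, logs


class SquaredErrorHeatmapLoss(HeatmapLossSketch):
    """Hypothetical subclass that overrides compute_loss with squared error."""

    def compute_loss(
        self, targets: list[float], predictions: list[float], **kwargs: Any
    ) -> list[float]:
        return [(t - p) ** 2 for t, p in zip(targets, predictions)]


loss, logs = SquaredErrorHeatmapLoss()(
    [0.0, 1.0, 0.5, 0.0], [0.0, 0.5, 0.5, 0.5], stage="train"
)
print(loss)  # 0.125
```

A subclass only supplies the element-wise divergence, while the parent handles reduction and logging. Calling the parent sketch directly raises NotImplementedError, matching the documented behavior of compute_loss.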
- remove_nans(targets: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], predictions: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width']) → tuple[Float[Tensor, 'num_valid_keypoints heatmap_height heatmap_width'], Float[Tensor, 'num_valid_keypoints heatmap_height heatmap_width']]
Remove heatmap entries where all target pixels are zero (NaN/unlabeled keypoints).
- Parameters:
targets – ground-truth heatmaps.
predictions – predicted heatmaps.
- Returns:
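Tuple (clean_targets, clean_predictions), with the batch and keypoint dimensions flattened into a single num_valid_keypoints dimension and every heatmap whose target is all zeros removed.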
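The documented behavior of remove_nans can be sketched without torch as follows. The function name remove_empty_heatmaps is hypothetical, and heatmaps are nested Python lists indexed [batch][keypoint][row][col] instead of Tensors. The real method returns Tensors of shape (num_valid_keypoints, heatmap_height, heatmap_width).

```python
Heatmap = list[list[float]]  # one keypoint heatmap: rows of pixel values


def remove_empty_heatmaps(
    targets: list[list[Heatmap]],
    predictions: list[list[Heatmap]],
) -> tuple[list[Heatmap], list[Heatmap]]:
    """Hypothetical re-implementation of the documented remove_nans behavior.

    Flattens the batch and keypoint dimensions and drops every
    (batch, keypoint) pair whose target heatmap is entirely zero.
    """
    clean_targets: list[Heatmap] = []
    clean_predictions: list[Heatmap] = []
    for targ_sample, pred_sample in zip(targets, predictions):
        for targ_hm, pred_hm in zip(targ_sample, pred_sample):
            # A keypoint counts as unlabeled when every target pixel is zero.
            if all(pixel == 0.0 for row in targ_hm for pixel in row):
                continue
            clean_targets.append(targ_hm)
            clean_predictions.append(pred_hm)
    return clean_targets, clean_predictions


targets = [
    [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]]],  # sample 0: keypoint 0 unlabeled
    [[[0.5, 0.5], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],  # sample 1: keypoint 1 unlabeled
]
predictions = [
    [[[0.1, 0.0], [0.0, 0.0]], [[0.0, 0.9], [0.0, 0.0]]],
    [[[0.4, 0.6], [0.0, 0.0]], [[0.2, 0.0], [0.0, 0.0]]],
]
clean_targets, clean_predictions = remove_empty_heatmaps(targets, predictions)
print(len(clean_targets))  # 2
```

The filter looks only at the target: an unlabeled keypoint has an all-zero target, so its prediction is dropped as well and contributes nothing to the loss.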
- __init__(data_module: BaseDataModule | UnlabeledDataModule | None = None, log_weight: float = 0.0, **kwargs: Any) → None
Initialize HeatmapLoss.
- Parameters:
data_module – data module providing access to datasets; passed to the parent class.
log_weight – sets the weight of this loss term in the objective function; the final weight is computed as 1.0 / (2.0 * exp(log_weight)).
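The log_weight formula can be checked numerically with a small helper. The name loss_weight is hypothetical; the expression is the one documented for log_weight.

```python
import math


def loss_weight(log_weight: float) -> float:
    """Final weight in front of the loss term: 1.0 / (2.0 * exp(log_weight))."""
    return 1.0 / (2.0 * math.exp(log_weight))


print(loss_weight(0.0))  # 0.5
print(loss_weight(1.0) < loss_weight(0.0))  # True
```

With the default log_weight = 0.0 the weight is 1.0 / (2.0 * 1.0) = 0.5. Increasing log_weight shrinks the weight, and log_weight = log(0.5) gives a weight of 1.0.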
- __new__(**kwargs)