ReprojectionHeatmapLoss

class lightning_pose.losses.losses.ReprojectionHeatmapLoss

Bases: Loss

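Penalize the error between the ground-truth heatmap and the heatmap produced by the predicted 2D -> 3D -> 2D -> heatmap pipeline: 2D predictions are triangulated to 3D, reprojected into each camera, and rendered as heatmaps.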

Attributes Summary

loss_name

Methods Summary

__call__(heatmaps_targ, ...[, stage])

Compute the reprojection heatmap loss.

compute_loss(targets, predictions)

Compute pixel-wise MSE between reprojected and ground-truth heatmaps.

remove_nans(loss, targets)

Select only valid (non-zero-target) loss entries.

Attributes Documentation

loss_name: str = 'supervised_reprojection_heatmap_mse'

Methods Documentation

__call__(heatmaps_targ: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], keypoints_pred_2d_reprojected: Float[Tensor, 'batch num_keypoints 2'], stage: Literal['train', 'val', 'test'] | None = None, **kwargs: Any) -> tuple[Float[Tensor, ''], list[dict]]

Compute the reprojection heatmap loss.

Parameters:
  • heatmaps_targ – ground-truth heatmaps.

  • keypoints_pred_2d_reprojected – 2D keypoints obtained by projecting triangulated 3D predictions back into each camera, shape (batch, num_keypoints, 2).

  • stage – training stage for logging.

  • **kwargs – ignored extra keyword arguments.

Returns:

Tuple of scalar loss and list of logging dicts.

Raises:

ValueError – if keypoints_pred_2d_reprojected is None.
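
The class description says the reprojected 2D keypoints are converted into heatmaps before being compared with the targets. The following is a minimal sketch of one way such a conversion could work, not the library's implementation. The function name `keypoint_to_heatmap`, the Gaussian shape, the `sigma` value, the sum-to-one normalization, and the assumption that keypoints arrive in original-image pixel coordinates are all illustrative choices. NumPy is used instead of PyTorch only to keep the sketch dependency-light.

```python
import numpy as np


def keypoint_to_heatmap(x, y, orig_hw, heatmap_hw, sigma=1.25):
    """Render one (x, y) keypoint as a 2D Gaussian heatmap (hypothetical helper)."""
    orig_h, orig_w = orig_hw
    hm_h, hm_w = heatmap_hw
    # Assumption: keypoints are in original-image pixels, so rescale to heatmap pixels.
    cx = x * hm_w / orig_w
    cy = y * hm_h / orig_h
    # Pixel coordinate grids: ys varies along rows, xs along columns.
    ys, xs = np.mgrid[0:hm_h, 0:hm_w]
    heatmap = np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2.0 * sigma**2))
    # Assumption: normalize so the heatmap sums to one.
    return heatmap / heatmap.sum()


heatmap = keypoint_to_heatmap(128.0, 64.0, orig_hw=(256, 256), heatmap_hw=(64, 64))
peak_row, peak_col = np.unravel_index(heatmap.argmax(), heatmap.shape)
```

In this sketch a keypoint at (x=128, y=64) in a 256 x 256 image peaks at row 16, column 32 of a 64 x 64 heatmap, which illustrates why the constructor needs both the original and the downsampled image dimensions.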

compute_loss(targets: Float[Tensor, 'batch_x_num_keypoints heatmap_height heatmap_width'], predictions: Float[Tensor, 'batch_x_num_keypoints heatmap_height heatmap_width']) -> Float[Tensor, 'batch_x_num_keypoints heatmap_height heatmap_width']

Compute pixel-wise MSE between reprojected and ground-truth heatmaps.

Parameters:
  • targets – ground-truth heatmaps.

  • predictions – heatmaps generated from reprojected 2D keypoints.

Returns:

Element-wise MSE scaled by the number of heatmap pixels.

remove_nans(loss: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], targets: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width']) -> Float[Tensor, 'valid_losses']

Select only valid (non-zero-target) loss entries.

Parameters:
  • loss – element-wise MSE loss tensor.

  • targets – ground-truth heatmaps; all-zero heatmaps indicate unlabeled keypoints.

Returns:

Flat tensor of valid loss values, or a zero scalar if none are valid.
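
As a rough illustration of the selection described above, the sketch below keeps the loss values of every heatmap whose target is not all zeros and flattens the result. It is an assumption-laden approximation, not the library's code: the helper name `select_valid_losses` is hypothetical, the mask is applied per heatmap rather than per pixel, and NumPy arrays stand in for PyTorch tensors.

```python
import numpy as np


def select_valid_losses(loss, targets):
    """Keep loss entries whose target heatmap is not all zeros (hypothetical helper)."""
    # A heatmap counts as labeled if any of its pixels is non-zero.
    labeled = np.any(targets != 0, axis=(-2, -1))  # shape: (batch, num_keypoints)
    if not labeled.any():
        # No labeled keypoints in the batch: fall back to a zero scalar.
        return np.zeros(())
    # Boolean indexing drops unlabeled heatmaps; reshape flattens the rest.
    return loss[labeled].reshape(-1)


# Two keypoints in one frame; only the first has a non-zero target heatmap.
targets = np.zeros((1, 2, 2, 2))
targets[0, 0, 0, 0] = 1.0
loss = np.arange(8, dtype=float).reshape(1, 2, 2, 2)

valid = select_valid_losses(loss, targets)
empty = select_valid_losses(loss, np.zeros_like(targets))
```

Here `valid` holds the four loss values of the labeled keypoint, and `empty` is a zero scalar because no target heatmap is labeled.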

__init__(original_image_height: int, original_image_width: int, downsampled_image_height: int, downsampled_image_width: int, log_weight: float = 0.0, uniform_heatmaps: bool = False, **kwargs: Any) -> None

Initialize ReprojectionHeatmapLoss.

Converts 2D reprojected keypoints (obtained by projecting 3D triangulated predictions back into each camera’s image plane) into heatmaps and compares them with the ground truth heatmaps using pixel-wise MSE.

Parameters:
  • original_image_height – height of the full-resolution input image in pixels.

  • original_image_width – width of the full-resolution input image in pixels.

  • downsampled_image_height – height of the heatmap output (after backbone downsampling).

  • downsampled_image_width – width of the heatmap output.

  • log_weight – sets the weight applied to this loss term in the objective function, computed as 1.0 / (2.0 * exp(log_weight)); the default of 0.0 gives a weight of 0.5.

  • uniform_heatmaps – if True, generate uniform (flat) target heatmaps for NaN ground truth keypoints instead of ignoring them in the loss.
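
The log_weight formula above can be checked with a few lines of standard-library Python. The function name `loss_weight` is hypothetical and exists only for this illustration; the formula itself is the one stated in the parameter description.

```python
import math


def loss_weight(log_weight: float) -> float:
    """Weight applied to the loss term: 1.0 / (2.0 * exp(log_weight))."""
    return 1.0 / (2.0 * math.exp(log_weight))


default_weight = loss_weight(0.0)  # exp(0) == 1, so the weight is 0.5
larger_log_weight = loss_weight(1.0)  # a larger log_weight gives a smaller weight
```

Because the weight decreases exponentially in log_weight, raising log_weight down-weights this term relative to the other terms in the objective.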

__new__(**kwargs)