ReprojectionHeatmapLoss
- class lightning_pose.losses.losses.ReprojectionHeatmapLoss
Bases: Loss
Penalize the error between the heatmap rendered from the predicted 2D->3D->2D reprojection and the ground-truth heatmap.
Attributes Summary
Methods Summary
__call__(heatmaps_targ, ...[, stage]) – Compute the reprojection heatmap loss.
compute_loss(targets, predictions) – Compute pixel-wise MSE between reprojected and ground-truth heatmaps.
remove_nans(loss, targets) – Select only valid (non-zero-target) loss entries.
Attributes Documentation
- loss_name: str = 'supervised_reprojection_heatmap_mse'
Methods Documentation
- __call__(heatmaps_targ: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], keypoints_pred_2d_reprojected: Float[Tensor, 'batch num_keypoints 2'], stage: Literal['train', 'val', 'test'] | None = None, **kwargs: Any) -> tuple[Float[Tensor, ''], list[dict]]
Compute the reprojection heatmap loss.
- Parameters:
heatmaps_targ – ground-truth heatmaps, shape (batch, num_keypoints, heatmap_height, heatmap_width).
keypoints_pred_2d_reprojected – 2D keypoints obtained by projecting triangulated 3D predictions back into each camera, shape (batch, num_keypoints, 2).
stage – training stage used for logging: 'train', 'val', 'test', or None.
**kwargs – extra keyword arguments; ignored.
- Returns:
Tuple of scalar loss and list of logging dicts.
- Raises:
ValueError – if keypoints_pred_2d_reprojected is None.
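The conversion from a reprojected 2D keypoint to a heatmap can be pictured with the minimal sketch below. It is not the library's implementation: the helper name `keypoint_to_heatmap`, the Gaussian shape, the `sigma` value, and the sum-to-one normalization are all assumptions made for illustration, and plain Python lists stand in for tensors so the sketch has no dependencies.

```python
import math


def keypoint_to_heatmap(x, y, orig_hw, heatmap_hw, sigma=1.25):
    """Render one (x, y) keypoint, given in original-image pixels, as a heatmap.

    Hypothetical helper: the Gaussian shape, `sigma`, and the sum-to-one
    normalization are assumptions, not taken from lightning_pose.
    """
    orig_h, orig_w = orig_hw
    hm_h, hm_w = heatmap_hw
    # Rescale the keypoint from image coordinates to heatmap coordinates.
    cx = x * hm_w / orig_w
    cy = y * hm_h / orig_h
    # Evaluate an isotropic Gaussian centred on the rescaled keypoint.
    heatmap = [
        [
            math.exp(-((col - cx) ** 2 + (row - cy) ** 2) / (2.0 * sigma**2))
            for col in range(hm_w)
        ]
        for row in range(hm_h)
    ]
    # Normalize so that the heatmap sums to one.
    total = sum(sum(row) for row in heatmap)
    return [[value / total for value in row] for row in heatmap]


# A keypoint at (128, 64) in a 256x256 image lands at column 32, row 16
# of a 64x64 heatmap.
heatmap = keypoint_to_heatmap(128.0, 64.0, orig_hw=(256, 256), heatmap_hw=(64, 64))
```

The image and heatmap sizes in the sketch play the role of original_image_height, original_image_width, downsampled_image_height, and downsampled_image_width passed to __init__.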
- compute_loss(targets: Float[Tensor, 'batch_x_num_keypoints heatmap_height heatmap_width'], predictions: Float[Tensor, 'batch_x_num_keypoints heatmap_height heatmap_width']) -> Float[Tensor, 'batch_x_num_keypoints heatmap_height heatmap_width']
Compute pixel-wise MSE between reprojected and ground-truth heatmaps.
- Parameters:
targets – ground-truth heatmaps.
predictions – heatmaps generated from reprojected 2D keypoints.
- Returns:
Element-wise MSE scaled by the number of heatmap pixels.
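As a rough illustration of the return value, the sketch below computes a per-pixel squared error and multiplies it by the pixel count H * W. Reading "scaled by" as multiplication is an assumption, as is the helper name `scaled_pixelwise_mse`; nested Python lists shaped (N, H, W) stand in for tensors.

```python
def scaled_pixelwise_mse(targets, predictions):
    """Per-pixel squared error multiplied by the number of heatmap pixels.

    Hypothetical helper operating on nested lists shaped (N, H, W).
    Multiplying (rather than dividing) by H * W is an assumption.
    """
    height = len(targets[0])
    width = len(targets[0][0])
    scale = height * width
    # The output keeps the (N, H, W) shape of the inputs: no reduction is applied.
    return [
        [
            [scale * (t - p) ** 2 for t, p in zip(t_row, p_row)]
            for t_row, p_row in zip(t_map, p_map)
        ]
        for t_map, p_map in zip(targets, predictions)
    ]


targets = [[[1.0, 0.0], [0.0, 0.0]]]
predictions = [[[0.5, 0.0], [0.0, 0.5]]]
loss = scaled_pixelwise_mse(targets, predictions)
```

The output has the same shape as the inputs, matching the return annotation above, so any averaging happens in a later step.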
- remove_nans(loss: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], targets: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width']) -> Float[Tensor, 'valid_losses']
Select only valid (non-zero-target) loss entries.
- Parameters:
loss – element-wise MSE loss tensor.
targets – ground-truth heatmaps; all-zero heatmaps indicate unlabeled keypoints.
- Returns:
Flat tensor of valid loss values, or a zero scalar if none are valid.
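The selection described above can be sketched as follows. This is an illustrative approximation only: `select_valid_losses` is a hypothetical name, nested lists shaped (batch, num_keypoints, H, W) stand in for tensors, and treating a keypoint as valid when any pixel of its target heatmap is non-zero is an assumption drawn from the parameter description.

```python
def select_valid_losses(loss, targets):
    """Keep loss values only for keypoints whose target heatmap is not all zeros.

    Hypothetical helper operating on nested lists shaped
    (batch, num_keypoints, H, W). Returns a flat list of the kept values,
    or 0.0 when no keypoint is labeled.
    """
    valid = []
    for loss_batch, target_batch in zip(loss, targets):
        for loss_map, target_map in zip(loss_batch, target_batch):
            # An all-zero target heatmap marks an unlabeled keypoint.
            if any(value != 0.0 for row in target_map for value in row):
                valid.extend(value for row in loss_map for value in row)
    return valid if valid else 0.0


# One batch element with two keypoints on a 1x2 heatmap; the second is unlabeled.
loss = [[[[1.0, 2.0]], [[3.0, 4.0]]]]
targets = [[[[0.0, 1.0]], [[0.0, 0.0]]]]
kept = select_valid_losses(loss, targets)  # values from the first keypoint only
```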
- __init__(original_image_height: int, original_image_width: int, downsampled_image_height: int, downsampled_image_width: int, log_weight: float = 0.0, uniform_heatmaps: bool = False, **kwargs: Any) -> None
Initialize ReprojectionHeatmapLoss.
Converts 2D reprojected keypoints (obtained by projecting 3D triangulated predictions back into each camera's image plane) into heatmaps and compares them with the ground-truth heatmaps using pixel-wise MSE.
- Parameters:
original_image_height – height of the full-resolution input image in pixels.
original_image_width – width of the full-resolution input image in pixels.
downsampled_image_height – height of the heatmap output (after backbone downsampling).
downsampled_image_width – width of the heatmap output (after backbone downsampling).
log_weight – log-scale weight for this loss; the final weight in front of the loss term in the objective function is computed as 1.0 / (2.0 * exp(log_weight)).
uniform_heatmaps – if True, generate uniform (flat) target heatmaps for NaN ground-truth keypoints instead of ignoring them in the loss.
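To make the log_weight formula concrete, here is a small sketch that evaluates 1.0 / (2.0 * exp(log_weight)) directly. The function name `loss_weight` is hypothetical and exists only for this example; the formula itself is the one stated above.

```python
import math


def loss_weight(log_weight: float) -> float:
    """Weight applied to the loss term: 1.0 / (2.0 * exp(log_weight))."""
    return 1.0 / (2.0 * math.exp(log_weight))


print(loss_weight(0.0))  # 0.5, the weight for the default log_weight
print(loss_weight(1.0))  # a larger log_weight gives a smaller weight
```

With the default log_weight of 0.0 the loss term is weighted by 0.5, and each unit increase in log_weight shrinks the weight by a factor of e.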
- __new__(**kwargs)