UnimodalLoss

class lightning_pose.losses.losses.UnimodalLoss

Bases: Loss

Encourage heatmaps to be unimodal using one of three measures: pixel-wise MSE, KL divergence, or Jensen-Shannon divergence.

Attributes Summary

LOSS_NAME_JS

LOSS_NAME_KL

LOSS_NAME_MSE

Methods Summary

__call__(keypoints_pred_augmented, ...[, stage])

Compute unimodal loss.

compute_loss(targets, predictions)

Compute per-element divergence between ideal unimodal targets and predicted heatmaps.

remove_nans(targets, predictions, confidences)

Remove NaNs from targets and predictions.

Attributes Documentation

LOSS_NAME_JS = 'unimodal_js'
LOSS_NAME_KL = 'unimodal_kl'
LOSS_NAME_MSE = 'unimodal_mse'

Methods Documentation

__call__(keypoints_pred_augmented: Float[Tensor, 'batch two_x_num_keypoints'], heatmaps_pred: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], confidences: Float[Tensor, 'batch num_keypoints'], stage: Literal['train', 'val', 'test'] | None = None, **kwargs: Any) → tuple[Float[Tensor, ''], list[dict]]

Compute unimodal loss.

Parameters:
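  • keypoints_pred_augmented – predicted keypoint coordinates in the augmented image space, flattened to shape (batch, 2 * num_keypoints).

  • heatmaps_pred – predicted heatmaps, also in the augmented image space, with one heatmap per keypoint in keypoints_pred_augmented.

  • confidences – per-keypoint confidence values of shape (batch, num_keypoints); keypoints whose confidence is below prob_threshold are excluded from the loss.

  • stage – the current stage ('train', 'val', or 'test'), or None.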
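The shape contract above can be checked up front. The sketch below is a hypothetical helper (check_unimodal_inputs is not part of lightning_pose), with NumPy arrays standing in for the torch tensors; it only verifies the documented shapes and makes no assumption about how x and y are ordered within the flattened keypoint axis.

```python
import numpy as np


def check_unimodal_inputs(keypoints_pred_augmented, heatmaps_pred, confidences):
    """Hypothetical helper: check the shapes documented for UnimodalLoss.__call__.

    Returns (batch, num_keypoints, heatmap_height, heatmap_width) on success.
    """
    batch, num_keypoints, height, width = heatmaps_pred.shape
    # keypoints are flattened: two coordinates per keypoint
    expected_kp = (batch, 2 * num_keypoints)
    if keypoints_pred_augmented.shape != expected_kp:
        raise ValueError(
            f"keypoints_pred_augmented: expected {expected_kp}, "
            f"got {keypoints_pred_augmented.shape}"
        )
    # one confidence value per (image, keypoint) pair
    expected_conf = (batch, num_keypoints)
    if confidences.shape != expected_conf:
        raise ValueError(
            f"confidences: expected {expected_conf}, got {confidences.shape}"
        )
    return batch, num_keypoints, height, width


# example: 4 images, 17 keypoints, 64 x 48 heatmaps
dims = check_unimodal_inputs(
    np.zeros((4, 34)), np.zeros((4, 17, 64, 48)), np.ones((4, 17))
)
print(dims)  # (4, 17, 64, 48)
```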

compute_loss(targets: Float[Tensor, 'num_valid_keypoints heatmap_height heatmap_width'], predictions: Float[Tensor, 'num_valid_keypoints heatmap_height heatmap_width']) → Tensor

Compute per-element divergence between ideal unimodal targets and predicted heatmaps.

Parameters:
  • targets – ideal unimodal heatmaps derived from predicted keypoint coordinates.

  • predictions – predicted heatmaps from the network.

Returns:

Element-wise loss tensor.
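One plausible per-element formulation of the three measures is sketched below. It is not the library's code: the helper names, the EPS constant, the KL direction (targets relative to predictions) and the assumption that each heatmap is normalized to sum to 1 are all illustrative choices. NumPy arrays stand in for the torch tensors.

```python
import numpy as np

EPS = 1e-10  # assumption: small constant keeping log() finite


def _kl_elementwise(p, q):
    # p * log(p / q) per pixel; pixels with p == 0 contribute 0
    return p * (np.log(p + EPS) - np.log(q + EPS))


def elementwise_unimodal_loss(targets, predictions, loss_name):
    """Hypothetical per-element divergence between (N, H, W) heatmap stacks."""
    if loss_name == "unimodal_mse":
        return (targets - predictions) ** 2
    if loss_name == "unimodal_kl":
        # direction KL(targets || predictions) is an assumption
        return _kl_elementwise(targets, predictions)
    if loss_name == "unimodal_js":
        # Jensen-Shannon: average KL of each heatmap to their midpoint
        m = 0.5 * (targets + predictions)
        return 0.5 * _kl_elementwise(targets, m) + 0.5 * _kl_elementwise(predictions, m)
    raise ValueError(f"unknown loss_name: {loss_name!r}")


# example: one sharp target versus one flat prediction on a 4 x 4 grid
target = np.zeros((1, 4, 4))
target[0, 1, 2] = 1.0
prediction = np.full((1, 4, 4), 1.0 / 16)
per_pixel = elementwise_unimodal_loss(target, prediction, "unimodal_js")
print(per_pixel.shape, per_pixel.sum())
```

Summing the per-pixel KL or JS values over the height and width axes recovers the usual divergence between two normalized heatmaps; how the element-wise tensor is reduced to a scalar is left to the caller in this sketch.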

remove_nans(targets: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], predictions: Float[Tensor, 'batch num_keypoints heatmap_height heatmap_width'], confidences: Float[Tensor, 'batch num_keypoints']) → tuple[Float[Tensor, 'num_valid_keypoints heatmap_height heatmap_width'], Float[Tensor, 'num_valid_keypoints heatmap_height heatmap_width']]

Remove NaNs from targets and predictions.

Parameters:
  • targets – shape (batch, num_keypoints, heatmap_height, heatmap_width).

  • predictions – shape (batch, num_keypoints, heatmap_height, heatmap_width).

  • confidences – shape (batch, num_keypoints).

Returns:

Clean targets and clean predictions, each concatenated across images and keypoints into shape (num_valid_keypoints, heatmap_height, heatmap_width).

Return type:

tuple of (clean targets, clean predictions)
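The flattening described above can be sketched as follows. This is a hypothetical re-creation (remove_nans_sketch is not a library function) using NumPy arrays in place of torch tensors. The filtering criteria, dropping any (image, keypoint) row whose heatmaps contain a NaN or whose confidence is below a prob_threshold argument that mirrors the constructor argument, are assumptions.

```python
import numpy as np


def remove_nans_sketch(targets, predictions, confidences, prob_threshold=0.0):
    """Hypothetical re-creation of remove_nans; the exact criteria are assumptions."""
    batch, num_keypoints, height, width = targets.shape
    # flatten the batch and keypoint axes: one row per (image, keypoint) pair
    targets = targets.reshape(batch * num_keypoints, height, width)
    predictions = predictions.reshape(batch * num_keypoints, height, width)
    confidences = confidences.reshape(batch * num_keypoints)
    # keep rows whose heatmaps contain no NaN and whose confidence passes the threshold
    finite = ~np.isnan(targets).any(axis=(1, 2)) & ~np.isnan(predictions).any(axis=(1, 2))
    keep = finite & (confidences >= prob_threshold)
    return targets[keep], predictions[keep]


# example: 2 images x 3 keypoints; one NaN target and one low-confidence keypoint
t = np.zeros((2, 3, 4, 4))
p = np.zeros((2, 3, 4, 4))
c = np.ones((2, 3))
t[0, 1, 0, 0] = np.nan
c[1, 2] = 0.1
clean_t, clean_p = remove_nans_sketch(t, p, c, prob_threshold=0.5)
print(clean_t.shape)  # (4, 4, 4)
```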

__init__(loss_name: Literal['unimodal_mse', 'unimodal_kl', 'unimodal_js'], original_image_height: int, original_image_width: int, downsampled_image_height: int, downsampled_image_width: int, data_module: BaseDataModule | UnlabeledDataModule | None = None, prob_threshold: float = 0.0, log_weight: float = 0.0, uniform_heatmaps: bool = False, **kwargs: Any) → None

Initialize UnimodalLoss.

Generates an ideal unimodal heatmap from each predicted keypoint coordinate and penalizes the difference between that ideal heatmap and the network’s predicted heatmap.
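To make "ideal unimodal heatmap" concrete, here is a minimal sketch that assumes the ideal heatmap is a normalized 2D Gaussian. The helper name ideal_heatmap, the Gaussian shape, the sigma value and the linear rescaling of coordinates are assumptions, not the library's implementation. It uses the four image-size arguments from the signature to map a keypoint from full-resolution pixels to heatmap pixels.

```python
import numpy as np


def ideal_heatmap(x, y, original_image_height, original_image_width,
                  downsampled_image_height, downsampled_image_width, sigma=1.25):
    """Hypothetical sketch: a normalized 2D Gaussian centered on one keypoint.

    The Gaussian shape, sigma value, and linear rescaling are assumptions.
    """
    # map the keypoint from full-resolution pixels to heatmap pixels
    x_hm = x * downsampled_image_width / original_image_width
    y_hm = y * downsampled_image_height / original_image_height
    # coordinate grids: rows index height, columns index width
    ys, xs = np.mgrid[0:downsampled_image_height, 0:downsampled_image_width]
    heatmap = np.exp(-((xs - x_hm) ** 2 + (ys - y_hm) ** 2) / (2 * sigma ** 2))
    # normalize so the heatmap sums to 1; it has a single peak (unimodal)
    return heatmap / heatmap.sum()


# example: keypoint at (x=128, y=64) in a 256 x 256 image, 64 x 64 heatmap
hm = ideal_heatmap(128.0, 64.0, 256, 256, 64, 64)
print(np.unravel_index(hm.argmax(), hm.shape))  # peak at row 16, column 32
```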

Parameters:
  • loss_name – divergence measure to use. "unimodal_mse" uses pixel-wise MSE; "unimodal_kl" uses KL divergence; "unimodal_js" uses Jensen-Shannon divergence.

  • original_image_height – height of the full-resolution input image in pixels, used when generating ideal heatmaps.

  • original_image_width – width of the full-resolution input image in pixels.

  • downsampled_image_height – height of the heatmap output (after backbone downsampling).

  • downsampled_image_width – width of the heatmap output.

  • data_module – data module providing access to datasets; passed to the parent class.

  • prob_threshold – predictions whose confidence is below this value are excluded from the loss computation.

  • log_weight – sets the weight applied to this loss term in the objective function, computed as 1.0 / (2.0 * exp(log_weight)); the default of 0.0 gives a weight of 0.5.

  • uniform_heatmaps – if True, generate uniform (flat) target heatmaps for keypoints whose ground-truth coordinates are NaN, instead of ignoring them in the loss.
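The log_weight formula can be evaluated directly. The function name below is illustrative only; the arithmetic is exactly the formula documented for log_weight, and it shows that larger values of log_weight shrink the term's weight.

```python
import math


def unimodal_loss_weight(log_weight: float) -> float:
    """Illustrative helper: weight in front of the loss term, 1 / (2 * exp(log_weight))."""
    return 1.0 / (2.0 * math.exp(log_weight))


print(unimodal_loss_weight(0.0))            # 0.5 (the default)
print(unimodal_loss_weight(math.log(0.5)))  # approximately 1.0
```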

__new__(**kwargs)