UnimodalLoss

class lightning_pose.losses.losses.UnimodalLoss[source]

Bases: Loss

Encourage predicted heatmaps to be unimodal. The measure is selected by loss_name: 'unimodal_mse', 'unimodal_kl', or 'unimodal_js'.

Methods Summary

__call__(keypoints_pred_augmented, ...[, stage])

Compute unimodal loss.

compute_loss(targets, predictions)

remove_nans(targets, predictions, confidences)

Remove nans from targets and predictions.

Methods Documentation

__call__(keypoints_pred_augmented: TensorType['batch', 'two_x_num_keypoints'], heatmaps_pred: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width'], confidences: TensorType['batch', 'num_keypoints'], stage: Literal['train', 'val', 'test'] | None = None, **kwargs) → Tuple[TensorType[()], list[dict]][source]

Compute unimodal loss.

Parameters:
  • keypoints_pred_augmented – predicted keypoints, in the augmented image space

  • heatmaps_pred – predicted heatmaps, also in the augmented image space, matching keypoints_pred_augmented
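
The shapes in the signature can be made concrete with a small sketch. It builds an idealized single-peak heatmap at each predicted keypoint and compares it with the predicted heatmap using MSE, which is one plausible reading of the 'unimodal_mse' option and not necessarily what the library does. The helper gaussian_heatmaps and its sigma argument are hypothetical, the interleaved (x0, y0, x1, y1, ...) layout of the flat keypoint vector is an assumption, and keypoints are assumed to already be in heatmap pixel coordinates.

```python
import torch
import torch.nn.functional as F


def gaussian_heatmaps(keypoints, height, width, sigma=1.0):
    """Build one normalized Gaussian bump per keypoint (hypothetical helper).

    keypoints: (batch, num_keypoints, 2), (x, y) in heatmap pixel coordinates.
    Returns: (batch, num_keypoints, height, width), each map summing to 1.
    """
    ys = torch.arange(height, dtype=keypoints.dtype).view(1, 1, height, 1)
    xs = torch.arange(width, dtype=keypoints.dtype).view(1, 1, 1, width)
    x0 = keypoints[..., 0, None, None]  # (batch, num_keypoints, 1, 1)
    y0 = keypoints[..., 1, None, None]
    bumps = torch.exp(-((xs - x0) ** 2 + (ys - y0) ** 2) / (2 * sigma**2))
    return bumps / bumps.sum(dim=(-2, -1), keepdim=True)


torch.manual_seed(0)
batch, num_keypoints, height, width = 2, 3, 8, 8

# (batch, two_x_num_keypoints); assumed layout: x0, y0, x1, y1, ...
keypoints_pred = torch.rand(batch, 2 * num_keypoints) * (width - 1)

# (batch, num_keypoints, heatmap_height, heatmap_width); each map sums to 1
heatmaps_pred = torch.softmax(
    torch.randn(batch, num_keypoints, height * width), dim=-1
).reshape(batch, num_keypoints, height, width)

# Idealized unimodal targets centered on the predicted keypoints
targets = gaussian_heatmaps(
    keypoints_pred.reshape(batch, num_keypoints, 2), height, width
)

# Scalar tensor, matching the TensorType[()] in the return annotation
loss = F.mse_loss(heatmaps_pred, targets)
```

A predicted heatmap with several peaks sits far from its single-peak target, so a loss of this form grows as the prediction becomes less unimodal.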

compute_loss(targets: TensorType['num_valid_keypoints', 'heatmap_height', 'heatmap_width'], predictions: TensorType['num_valid_keypoints', 'heatmap_height', 'heatmap_width']) → Tensor[source]

remove_nans(targets: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width'], predictions: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width'], confidences: TensorType['batch', 'num_keypoints']) → Tuple[TensorType['num_valid_keypoints', 'heatmap_height', 'heatmap_width']][source]

Remove nans from targets and predictions.

Parameters:
  • targets – (batch, num_keypoints, heatmap_height, heatmap_width)

  • predictions – (batch, num_keypoints, heatmap_height, heatmap_width)

  • confidences – (batch, num_keypoints)

Returns:
  • clean targets – concatenated across different images and keypoints

  • clean predictions – concatenated across different images and keypoints
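
The shape change from (batch, num_keypoints, heatmap_height, heatmap_width) to (num_valid_keypoints, heatmap_height, heatmap_width) can be illustrated with boolean indexing over the first two dimensions. This is a minimal sketch, not the library's implementation: drop_invalid is a hypothetical stand-in, and treating entries with confidence below prob_threshold as invalid is an assumption about how confidences is used.

```python
import torch


def drop_invalid(targets, predictions, confidences, prob_threshold=0.0):
    """Hypothetical stand-in for remove_nans, not the library's code."""
    # True where every pixel of both heatmaps is finite: (batch, num_keypoints)
    finite = (
        torch.isfinite(targets).flatten(start_dim=2).all(dim=-1)
        & torch.isfinite(predictions).flatten(start_dim=2).all(dim=-1)
    )
    # Assumption: low-confidence keypoints are also excluded
    keep = finite & (confidences >= prob_threshold)
    # Boolean indexing over the first two dims merges them into one
    return targets[keep], predictions[keep]


batch, num_keypoints, height, width = 2, 3, 4, 4
targets = torch.rand(batch, num_keypoints, height, width)
predictions = torch.rand(batch, num_keypoints, height, width)
confidences = torch.ones(batch, num_keypoints)

targets[0, 1] = float("nan")  # a keypoint with no usable target
confidences[1, 2] = 0.1       # a low-confidence prediction

clean_targets, clean_predictions = drop_invalid(
    targets, predictions, confidences, prob_threshold=0.5
)
```

Of the six (image, keypoint) pairs, two are dropped, so both outputs have a leading dimension of four.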

__init__(loss_name: Literal['unimodal_mse', 'unimodal_kl', 'unimodal_js'], original_image_height: int, original_image_width: int, downsampled_image_height: int, downsampled_image_width: int, data_module: BaseDataModule | UnlabeledDataModule | None = None, prob_threshold: float = 0.0, log_weight: float = 0.0, uniform_heatmaps: bool = False, **kwargs) → None[source]

Parameters:
  • data_module – give losses access to data for computing data-specific loss params

  • epsilon – loss values below epsilon are zeroed out (epsilon is not a named argument in the signature above)

  • log_weight – natural log of the weight in front of the loss term in the final objective function
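
Since log_weight is the natural log of the weight, the weight itself is exp(log_weight), so the default of 0.0 corresponds to a weight of 1. A minimal sketch of that relationship, with hypothetical helper names that are not part of the library:

```python
import math


def effective_weight(log_weight: float) -> float:
    """Weight applied to the loss term, recovered from its natural log."""
    return math.exp(log_weight)


def weighted_loss(raw_loss: float, log_weight: float) -> float:
    """Contribution of one loss term to the final objective (illustrative)."""
    return effective_weight(log_weight) * raw_loss


default_weight = effective_weight(0.0)        # default log_weight
five_x_weight = effective_weight(math.log(5.0))
scaled = weighted_loss(2.0, math.log(5.0))
```

Parameterizing by the log keeps the weight strictly positive for any real value of log_weight.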

__new__(**kwargs)