UnimodalLoss
- class lightning_pose.losses.losses.UnimodalLoss(loss_name: Literal['unimodal_mse', 'unimodal_kl', 'unimodal_js'], original_image_height: int, original_image_width: int, downsampled_image_height: int, downsampled_image_width: int, data_module: BaseDataModule | UnlabeledDataModule | None = None, prob_threshold: float = 0.0, log_weight: float = 0.0, uniform_heatmaps: bool = False, **kwargs)
Bases: Loss
Encourage heatmaps to be unimodal using various measures.
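To make "unimodal using various measures" concrete, here is a minimal sketch of the idea, not the library's implementation. It assumes the loss compares each predicted heatmap against an ideal single-peak Gaussian centered on the predicted keypoint, using either mean squared error or Jensen-Shannon divergence. The helper names (`gaussian_heatmap`, `mse_measure`, `js_measure`) and the `sigma` value are hypothetical and chosen only for illustration.

```python
import torch


def gaussian_heatmap(center_xy: torch.Tensor, height: int, width: int,
                     sigma: float = 1.5) -> torch.Tensor:
    """Hypothetical helper: a single-peak 2D Gaussian normalized to sum to 1."""
    ys = torch.arange(height, dtype=torch.float32).unsqueeze(1)  # (H, 1)
    xs = torch.arange(width, dtype=torch.float32).unsqueeze(0)   # (1, W)
    sq_dist = (xs - center_xy[0]) ** 2 + (ys - center_xy[1]) ** 2
    heatmap = torch.exp(-sq_dist / (2 * sigma ** 2))
    return heatmap / heatmap.sum()


def mse_measure(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean squared error between two heatmaps (assumed 'unimodal_mse' analogue)."""
    return torch.mean((pred - target) ** 2)


def js_measure(pred: torch.Tensor, target: torch.Tensor,
               eps: float = 1e-10) -> torch.Tensor:
    """Jensen-Shannon divergence between two heatmaps that each sum to 1
    (assumed 'unimodal_js' analogue)."""
    mix = 0.5 * (pred + target)
    kl_pred = torch.sum(pred * torch.log((pred + eps) / (mix + eps)))
    kl_target = torch.sum(target * torch.log((target + eps) / (mix + eps)))
    return 0.5 * (kl_pred + kl_target)


# A heatmap with one peak at the keypoint matches the ideal target exactly.
center = torch.tensor([8.0, 8.0])
one_peak = gaussian_heatmap(center, 16, 16)

# A heatmap with a second, spurious peak is penalized by both measures.
spurious = gaussian_heatmap(torch.tensor([2.0, 2.0]), 16, 16)
two_peaks = 0.5 * one_peak + 0.5 * spurious

ideal = gaussian_heatmap(center, 16, 16)
mse_one, mse_two = mse_measure(one_peak, ideal), mse_measure(two_peaks, ideal)
js_one, js_two = js_measure(one_peak, ideal), js_measure(two_peaks, ideal)
```

In this sketch the single-peak heatmap scores zero under both measures, while the two-peak heatmap scores strictly higher, which is the behavior a unimodality penalty is meant to have.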
Methods Summary
__call__(keypoints_pred_augmented, ...[, stage]) – Compute unimodal loss.
compute_loss(targets, predictions)
remove_nans(targets, predictions, confidences) – Remove NaNs from targets and predictions.
Methods Documentation
- __call__(keypoints_pred_augmented: Tensor[Tensor], heatmaps_pred: Tensor[Tensor], confidences: Tensor[Tensor], stage: Literal['train', 'val', 'test'] | None = None, **kwargs) -> Tuple[Tensor[Tensor], List[dict]]
Compute unimodal loss.
- Parameters:
keypoints_pred_augmented – predicted keypoints, in the augmented image space
heatmaps_pred – predicted heatmaps, also in the augmented image space, matching keypoints_pred_augmented
- remove_nans(targets: Tensor[Tensor], predictions: Tensor[Tensor], confidences: Tensor[Tensor]) -> Tuple[Tensor[Tensor], Tensor[Tensor]]
Remove NaNs from targets and predictions.
- Parameters:
targets – (batch, num_keypoints, heatmap_height, heatmap_width)
predictions – (batch, num_keypoints, heatmap_height, heatmap_width)
confidences – (batch, num_keypoints)
- Returns:
clean targets – concatenated across different images and keypoints
clean predictions – concatenated across different images and keypoints
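The shapes above suggest that filtering happens per (image, keypoint) pair, so the outputs lose the separate batch and keypoint axes. The following is a rough sketch of that pattern under stated assumptions, not the method's actual code: it assumes a pair is dropped when its confidence is NaN or falls below a threshold, and the function name `drop_invalid_heatmaps` is hypothetical.

```python
import torch


def drop_invalid_heatmaps(targets: torch.Tensor, predictions: torch.Tensor,
                          confidences: torch.Tensor, prob_threshold: float = 0.0):
    """Hypothetical stand-in for the filtering step.

    targets, predictions: (batch, num_keypoints, heatmap_height, heatmap_width)
    confidences: (batch, num_keypoints)
    Returns two tensors of shape (num_kept, heatmap_height, heatmap_width).
    """
    # Assumed criterion: keep pairs whose confidence is a number and meets the threshold.
    keep = ~torch.isnan(confidences) & (confidences >= prob_threshold)
    # Indexing with a (batch, num_keypoints) boolean mask flattens those two
    # axes into one, i.e. concatenates across images and keypoints.
    return targets[keep], predictions[keep]


targets = torch.rand(2, 3, 4, 4)
predictions = torch.rand(2, 3, 4, 4)
confidences = torch.tensor([[0.9, float("nan"), 0.8],
                            [0.1, 0.7, 0.95]])

clean_targets, clean_predictions = drop_invalid_heatmaps(
    targets, predictions, confidences, prob_threshold=0.5
)
```

Here four of the six (image, keypoint) pairs survive, so both outputs have shape (4, 4, 4), ordered row by row through the confidence matrix.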