UnimodalLoss
- class lightning_pose.losses.losses.UnimodalLoss
Bases: Loss
Encourage heatmaps to be unimodal using one of several divergence measures (pixel-wise MSE, KL divergence, or Jensen-Shannon divergence).
Attributes Summary
- LOSS_NAME_JS
- LOSS_NAME_KL
- LOSS_NAME_MSE
Methods Summary
- __call__(keypoints_pred_augmented, ...[, stage]): Compute unimodal loss.
- compute_loss(targets, predictions)
- remove_nans(targets, predictions, confidences): Remove NaNs from targets and predictions.
Attributes Documentation
- LOSS_NAME_JS = 'unimodal_js'
- LOSS_NAME_KL = 'unimodal_kl'
- LOSS_NAME_MSE = 'unimodal_mse'
Methods Documentation
- __call__(keypoints_pred_augmented: TensorType['batch', 'two_x_num_keypoints'], heatmaps_pred: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width'], confidences: TensorType['batch', 'num_keypoints'], stage: Literal['train', 'val', 'test'] | None = None, **kwargs) -> tuple[TensorType[()], list[dict]]
Compute unimodal loss.
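- Parameters:
keypoints_pred_augmented – predicted keypoint coordinates in the augmented image space, shape (batch, two_x_num_keypoints), i.e. two coordinates per keypoint.
heatmaps_pred – predicted heatmaps, also in the augmented image space and aligned with keypoints_pred_augmented, shape (batch, num_keypoints, heatmap_height, heatmap_width).
confidences – per-keypoint confidences, shape (batch, num_keypoints).
stage – optional stage name: 'train', 'val', or 'test'.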
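The keypoints arrive as a flat (batch, two_x_num_keypoints) array, while the heatmaps live at the downsampled resolution given to __init__. A minimal pure-Python sketch of how such coordinates could be regrouped per keypoint and rescaled to heatmap pixels, assuming interleaved [x0, y0, x1, y1, ...] ordering and simple proportional scaling by the original/downsampled sizes. The helper to_heatmap_coords is hypothetical and not part of lightning_pose.

```python
# Hypothetical helper, not part of lightning_pose.
# Assumes each row is interleaved as [x0, y0, x1, y1, ...] and that
# image-to-heatmap mapping is a plain proportional rescale.
def to_heatmap_coords(keypoints_flat, original_hw, downsampled_hw):
    """Regroup flat keypoints into (x, y) pairs and rescale to heatmap pixels."""
    orig_h, orig_w = original_hw
    down_h, down_w = downsampled_hw
    scale_x = down_w / orig_w
    scale_y = down_h / orig_h
    out = []
    for row in keypoints_flat:  # one row per image in the batch
        if len(row) % 2 != 0:
            raise ValueError("expected 2 * num_keypoints values per row")
        pairs = [(row[i] * scale_x, row[i + 1] * scale_y)
                 for i in range(0, len(row), 2)]
        out.append(pairs)
    return out  # nested list of shape (batch, num_keypoints, 2)


# Example: a 256x256 image mapped onto a 64x64 heatmap (factor of 4).
coords = to_heatmap_coords([[128.0, 64.0, 0.0, 255.0]], (256, 256), (64, 64))
print(coords)  # [[(32.0, 16.0), (0.0, 63.75)]]
```

Keeping the reshape separate from the scaling makes the (x, y) ordering assumption explicit; if the real layout stores all x values before all y values, only the pairing line would change.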
- compute_loss(targets: TensorType['num_valid_keypoints', 'heatmap_height', 'heatmap_width'], predictions: TensorType['num_valid_keypoints', 'heatmap_height', 'heatmap_width']) -> Tensor
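compute_loss receives only the surviving heatmaps, shaped (num_valid_keypoints, heatmap_height, heatmap_width). As a hedged sketch of the three loss_name options described under __init__ (pixel-wise MSE, KL divergence, Jensen-Shannon divergence), the pure-Python code below computes each measure per heatmap and averages over keypoints. The KL direction KL(target || prediction), the assumption that heatmaps are normalized to sum to 1 for KL/JS, the EPS constant, and the mean reduction are all assumptions; compute_loss_sketch is a hypothetical name, not the library's implementation.

```python
import math

EPS = 1e-12  # assumed small constant guarding against log(0) and division by 0


def _mse(t, p):
    # Pixel-wise mean squared error over one flattened heatmap.
    return sum((a - b) ** 2 for a, b in zip(t, p)) / len(t)


def _kl(t, p):
    # KL(t || p); pixels where t == 0 contribute 0 (0 * log 0 := 0).
    return sum(a * math.log((a + EPS) / (b + EPS)) for a, b in zip(t, p) if a > 0)


def _js(t, p):
    # Jensen-Shannon divergence: average KL of each input to their midpoint.
    m = [(a + b) / 2 for a, b in zip(t, p)]
    return 0.5 * _kl(t, m) + 0.5 * _kl(p, m)


MEASURES = {"unimodal_mse": _mse, "unimodal_kl": _kl, "unimodal_js": _js}


def compute_loss_sketch(loss_name, targets, predictions):
    """Average a per-heatmap measure over num_valid_keypoints heatmaps.

    targets, predictions: nested lists of shape (num_valid_keypoints, H, W).
    """
    measure = MEASURES[loss_name]
    losses = []
    for target_hm, pred_hm in zip(targets, predictions):
        t = [v for row in target_hm for v in row]  # flatten H x W
        p = [v for row in pred_hm for v in row]
        losses.append(measure(t, p))
    return sum(losses) / len(losses)


tgt = [[[1.0, 0.0], [0.0, 0.0]]]  # one sharp 2x2 heatmap
prd = [[[0.5, 0.5], [0.0, 0.0]]]  # mass split across two pixels
print(compute_loss_sketch("unimodal_mse", tgt, prd))  # 0.125
```

MSE compares raw pixel values, whereas KL and JS treat each heatmap as a probability distribution; JS is symmetric and bounded by log 2, which is one reason to offer it alongside KL.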
- remove_nans(targets: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width'], predictions: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width'], confidences: TensorType['batch', 'num_keypoints']) -> tuple[TensorType['num_valid_keypoints', 'heatmap_height', 'heatmap_width'], TensorType['num_valid_keypoints', 'heatmap_height', 'heatmap_width']]
Remove NaNs from targets and predictions.
- Parameters:
targets – target heatmaps, shape (batch, num_keypoints, heatmap_height, heatmap_width).
predictions – predicted heatmaps, shape (batch, num_keypoints, heatmap_height, heatmap_width).
confidences – per-keypoint confidences, shape (batch, num_keypoints).
- Returns:
clean targets and clean predictions, each concatenated across images and keypoints, with shape (num_valid_keypoints, heatmap_height, heatmap_width).
- Return type:
tuple of two tensors
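The filtering step can be sketched in pure Python on nested lists. This assumes an (image, keypoint) entry is dropped when its target heatmap contains a NaN or its confidence falls below prob_threshold (the threshold described under __init__); whether remove_nans itself applies that threshold is an assumption, and remove_nans_sketch is a hypothetical name.

```python
import math


def remove_nans_sketch(targets, predictions, confidences, prob_threshold=0.0):
    """Keep heatmaps whose target has no NaN and whose confidence is at
    least prob_threshold, concatenated across images and keypoints.

    targets, predictions: nested lists (batch, num_keypoints, H, W)
    confidences: nested lists (batch, num_keypoints)
    """
    clean_targets, clean_preds = [], []
    for img_t, img_p, img_c in zip(targets, predictions, confidences):
        for hm_t, hm_p, conf in zip(img_t, img_p, img_c):
            has_nan = any(math.isnan(v) for row in hm_t for v in row)
            if has_nan or conf < prob_threshold:
                continue  # assumed exclusion rule
            clean_targets.append(hm_t)
            clean_preds.append(hm_p)
    # Both outputs have shape (num_valid_keypoints, H, W).
    return clean_targets, clean_preds


nan = float("nan")
targets = [[[[0.0, 1.0]], [[nan, 0.0]]]]  # 1 image, 2 keypoints, 1x2 heatmaps
preds = [[[[0.1, 0.9]], [[0.5, 0.5]]]]
confs = [[0.9, 0.9]]
t, p = remove_nans_sketch(targets, preds, confs)
print(len(t))  # 1
```

Flattening the batch and keypoint axes into one num_valid_keypoints axis lets the downstream loss treat every surviving heatmap uniformly, regardless of which image it came from.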
- __init__(loss_name: Literal['unimodal_mse', 'unimodal_kl', 'unimodal_js'], original_image_height: int, original_image_width: int, downsampled_image_height: int, downsampled_image_width: int, data_module: BaseDataModule | UnlabeledDataModule | None = None, prob_threshold: float = 0.0, log_weight: float = 0.0, uniform_heatmaps: bool = False, **kwargs) -> None
Initialize UnimodalLoss.
Generates an ideal unimodal heatmap from each predicted keypoint coordinate and penalizes the difference between that ideal heatmap and the network’s predicted heatmap.
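The idea of an ideal-heatmap target can be illustrated with a small pure-Python sketch. It assumes the ideal unimodal heatmap is an isotropic 2D Gaussian (normalized to sum to 1) centred on the predicted coordinate in heatmap pixels, with an arbitrary sigma of 1.25, and uses MSE as the penalty; the library's actual kernel and sigma are not stated here, and ideal_heatmap and unimodal_mse are hypothetical names.

```python
import math


def ideal_heatmap(x, y, height, width, sigma=1.25):
    # Hypothetical target: isotropic 2D Gaussian at (x, y) in heatmap pixels,
    # normalized to sum to 1. The Gaussian shape and sigma are assumptions.
    hm = [[math.exp(-((c - x) ** 2 + (r - y) ** 2) / (2 * sigma ** 2))
           for c in range(width)]
          for r in range(height)]
    total = sum(sum(row) for row in hm)
    return [[v / total for v in row] for row in hm]


def unimodal_mse(pred_heatmap, x, y, sigma=1.25):
    # Penalize the pixel-wise squared difference between the predicted heatmap
    # and the ideal heatmap built from the predicted coordinate (x, y).
    height, width = len(pred_heatmap), len(pred_heatmap[0])
    target = ideal_heatmap(x, y, height, width, sigma)
    diffs = [(p - t) ** 2
             for prow, trow in zip(pred_heatmap, target)
             for p, t in zip(prow, trow)]
    return sum(diffs) / len(diffs)


# A heatmap that already has the ideal single-peak shape incurs zero loss;
# a bimodal heatmap evaluated at the same coordinate does not.
H, W = 8, 8
unimodal = ideal_heatmap(2.0, 2.0, H, W)
bimodal = [[(a + b) / 2 for a, b in zip(ra, rb)]
           for ra, rb in zip(ideal_heatmap(2.0, 2.0, H, W),
                             ideal_heatmap(6.0, 6.0, H, W))]
print(unimodal_mse(unimodal, 2.0, 2.0))  # 0.0
print(unimodal_mse(bimodal, 2.0, 2.0) > 0)  # True
```

Because the target is derived from the network's own prediction rather than from labels, a penalty of this kind needs no ground truth and can in principle be applied to unlabeled frames.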
- Parameters:
loss_name – divergence measure to use: "unimodal_mse" uses pixel-wise MSE; "unimodal_kl" uses KL divergence; "unimodal_js" uses Jensen-Shannon divergence.
original_image_height – height of the full-resolution input image in pixels, used when generating ideal heatmaps.
original_image_width – width of the full-resolution input image in pixels.
downsampled_image_height – height of the heatmap output (after backbone downsampling).
downsampled_image_width – width of the heatmap output.
data_module – data module providing access to datasets; passed to the parent class.
prob_threshold – predictions whose confidence is below this value are excluded from the loss computation.
log_weight – the final weight in front of the loss term in the objective function is computed as 1.0 / (2.0 * exp(log_weight)).
uniform_heatmaps – if True, generate uniform (flat) target heatmaps for NaN ground truth keypoints instead of ignoring them in the loss.
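The log_weight formula is easy to check numerically. A minimal sketch; loss_weight is a hypothetical helper name that simply evaluates the expression given for log_weight.

```python
import math


def loss_weight(log_weight: float) -> float:
    # Final weight on the unimodal loss term: 1 / (2 * exp(log_weight)).
    return 1.0 / (2.0 * math.exp(log_weight))


print(loss_weight(0.0))  # 0.5
```

With the default log_weight = 0.0 the term is weighted by 0.5; increasing log_weight shrinks the weight exponentially, and log_weight = -log(2) gives a weight of 1.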
- __new__(**kwargs)