HeatmapMSELoss
- class lightning_pose.losses.losses.HeatmapMSELoss
Bases: HeatmapLoss
Mean squared error (MSE) loss between target and predicted heatmaps.
Attributes Summary
loss_name
Methods Summary
compute_loss(targets, predictions)
Attributes Documentation
- loss_name = 'heatmap_mse'
Methods Documentation
- compute_loss(targets: TensorType['batch_x_num_keypoints', 'heatmap_height', 'heatmap_width'], predictions: TensorType['batch_x_num_keypoints', 'heatmap_height', 'heatmap_width']) → TensorType['batch_x_num_keypoints', 'heatmap_height', 'heatmap_width']
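A minimal NumPy sketch of what the compute_loss signature implies: targets and predictions share the shape (batch_x_num_keypoints, heatmap_height, heatmap_width), and the returned loss keeps that same shape, which suggests the squared error is returned per pixel before any reduction. The helper name heatmap_mse_elementwise is hypothetical, NumPy stands in for the library's torch tensors, and any extra scaling, masking, or reduction the library may apply is not shown.

```python
import numpy as np


def heatmap_mse_elementwise(targets: np.ndarray, predictions: np.ndarray) -> np.ndarray:
    """Hypothetical per-pixel squared error between two heatmap stacks.

    Both inputs are assumed to have shape
    (batch_x_num_keypoints, heatmap_height, heatmap_width); the result keeps
    that shape, i.e. no reduction is applied here.
    """
    if targets.shape != predictions.shape:
        raise ValueError(f"shape mismatch: {targets.shape} vs {predictions.shape}")
    if targets.ndim != 3:
        raise ValueError("expected a 3-D array (batch_x_num_keypoints, H, W)")
    # Elementwise squared difference; averaging this array gives the scalar MSE.
    return (predictions - targets) ** 2


# Usage: 2 images x 3 keypoints flattened into the first axis, 4x5 heatmaps.
rng = np.random.default_rng(0)
targets = rng.random((6, 4, 5))
predictions = rng.random((6, 4, 5))
loss = heatmap_mse_elementwise(targets, predictions)
print(loss.shape)  # (6, 4, 5)
print(loss.mean())  # scalar MSE over all pixels
```

Averaging the returned array gives the familiar scalar MSE; keeping per-pixel values would let a caller weight or mask individual keypoints before reducing.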
- __init__(data_module: BaseDataModule | UnlabeledDataModule | None = None, log_weight: float = 0.0, **kwargs) → None
- Parameters:
data_module – gives the loss access to the data, for computing data-specific loss parameters
epsilon – loss values below epsilon are zeroed out
log_weight – natural log of the weight applied to this loss term in the final objective function
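The epsilon and log_weight descriptions can be illustrated with a small pure-Python sketch. The function names are hypothetical, and applying the epsilon threshold to the raw loss before scaling by exp(log_weight) is an assumption; this page does not specify the order.

```python
import math


def loss_weight_from_log(log_weight: float) -> float:
    """Multiplier for a loss term, given the natural log of that weight."""
    return math.exp(log_weight)


def weighted_term(loss_value: float, log_weight: float = 0.0, epsilon: float = 0.0) -> float:
    """Hypothetical contribution of one loss term to the final objective.

    Assumption: epsilon thresholding is applied to the raw loss value
    before it is scaled by exp(log_weight).
    """
    # Loss values below epsilon are zeroed out.
    if loss_value < epsilon:
        loss_value = 0.0
    return loss_weight_from_log(log_weight) * loss_value


print(loss_weight_from_log(0.0))  # 1.0, so the default log_weight leaves the loss unscaled
print(weighted_term(2.5, log_weight=math.log(0.5)))
print(weighted_term(0.01, epsilon=0.1))  # 0.0
```

Because the weight is stored as its natural log, exp(log_weight) is always positive, and the default log_weight = 0.0 corresponds to a weight of 1.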
- __new__(**kwargs)