HeatmapLoss
- class lightning_pose.losses.losses.HeatmapLoss
Bases: Loss
Parent class for different heatmap losses (MSE, Wasserstein, etc.).
Methods Summary
__call__(heatmaps_targ, heatmaps_pred[, stage]): Call self as a function.
compute_loss(**kwargs)
remove_nans(targets, predictions)
Methods Documentation
- __call__(heatmaps_targ: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width'], heatmaps_pred: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width'], stage: Literal['train', 'val', 'test'] | None = None, **kwargs) -> tuple[TensorType[()], list[dict]]
Call self as a function.
- remove_nans(targets: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width'], predictions: TensorType['batch', 'num_keypoints', 'heatmap_height', 'heatmap_width']) -> tuple[TensorType['num_valid_keypoints', 'heatmap_height', 'heatmap_width'], TensorType['num_valid_keypoints', 'heatmap_height', 'heatmap_width']]
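The signature above implies that the batch and keypoint axes are collapsed into a single axis containing only the valid keypoints. The following is a minimal NumPy sketch of that shape change, not the library's implementation. The function name `remove_invalid_keypoints` is hypothetical, and the validity criterion (a target heatmap counts as invalid if it contains any NaN) is an assumption, since the page does not state how validity is determined.

```python
import numpy as np


def remove_invalid_keypoints(targets, predictions):
    """Hypothetical stand-in for remove_nans; the NaN criterion is an assumption."""
    # Reduce over the two spatial axes: (batch, num_keypoints, H, W) -> (batch, num_keypoints).
    valid = ~np.isnan(targets).any(axis=(2, 3))
    # Boolean indexing over the first two axes yields (num_valid_keypoints, H, W).
    return targets[valid], predictions[valid]


# Two frames, three keypoints each, 4x4 heatmaps; one keypoint is marked invalid.
targets = np.ones((2, 3, 4, 4))
targets[0, 1] = np.nan
predictions = np.zeros((2, 3, 4, 4))

valid_targets, valid_predictions = remove_invalid_keypoints(targets, predictions)
print(valid_targets.shape, valid_predictions.shape)
```

Applying the same mask to both arrays keeps each remaining target heatmap aligned with its prediction.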
- __init__(data_module: BaseDataModule | UnlabeledDataModule | None = None, log_weight: float = 0.0, **kwargs) -> None
Initialize HeatmapLoss.
- Parameters:
data_module – data module providing access to datasets; passed to the parent class.
log_weight – the final weight in front of the loss term in the objective function is computed as 1.0 / (2.0 * exp(log_weight)).
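As a quick check of the weight formula above, here is a small standalone sketch using only the Python standard library. The helper name `loss_weight` is hypothetical and is not part of the lightning_pose API.

```python
import math


def loss_weight(log_weight: float = 0.0) -> float:
    """Hypothetical helper: evaluate 1.0 / (2.0 * exp(log_weight))."""
    return 1.0 / (2.0 * math.exp(log_weight))


# With the default log_weight of 0.0, exp(0.0) is 1.0, so the weight is 0.5.
print(loss_weight(0.0))

# Larger log_weight values shrink the weight; smaller (negative) values grow it.
print(loss_weight(1.0), loss_weight(-1.0))
```

Because the weight depends on log_weight through an exponential, it is positive for any real input.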
- __new__(**kwargs)