Source code for lightning_pose.losses.losses

"""Supervised and unsupervised losses implemented in pytorch.

The lightning pose package defines each loss as its own class; an initialized loss
object, in addition to computing the loss, stores hyperparameters related to the loss
(weight in the final objective function, epsilon-insensitivity parameter, etc.)

A separate LossFactory class (defined in lightning_pose.losses.factory) collects all
losses for a given model and orchestrates their execution, logging, etc.

The general flow of each loss class is as follows:
- input: predicted and ground truth data
- step 0: remove ground truth samples containing nans if desired
- step 1: compute loss for each batch element/keypoint/etc
- step 2: epsilon-insensitivity: set loss to zero for any batch element with loss < epsilon
- step 3: reduce loss (usually mean)
- step 4: log values to a dict
- step 5: return loss

"""

import os
from typing import Any, Literal

import torch
from jaxtyping import Float
from kornia.losses import js_div_loss_2d, kl_div_loss_2d
from omegaconf import ListConfig
from torch.nn import functional as F

from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.utils import generate_heatmaps
from lightning_pose.utils.pca import KeypointPCA

# Explicit public API: restricts what sphinx-autoapidoc documents to the loss
# classes below, so that names imported at the top of this module are ignored.
__all__ = [
    "Loss",
    "HeatmapLoss",
    "HeatmapMSELoss",
    "HeatmapKLLoss",
    "HeatmapJSLoss",
    "PCALoss",
    "TemporalLoss",
    "TemporalHeatmapLoss",
    "UnimodalLoss",
    "RegressionMSELoss",
    "RegressionRMSELoss",
    "PairwiseProjectionsLoss",
    "ReprojectionHeatmapLoss",
]

# Use the CPU unless CUDA is available. When running with multiple GPUs, the
# LOCAL_RANK environment variable holds the DDP local rank, which is also the
# CUDA device index; device 0 is used when the variable is not set.
_DEFAULT_TORCH_DEVICE = (
    f"cuda:{int(os.environ.get('LOCAL_RANK', '0'))}"
    if torch.cuda.is_available()
    else "cpu"
)


class Loss:
    """Parent class for all losses.

    Stores the hyperparameters shared by every loss (epsilon threshold and log-weight)
    and implements the generic pipeline steps ``rectify_epsilon``, ``reduce_loss`` and
    ``log_loss``. ``remove_nans``, ``compute_loss`` and ``__call__`` raise
    ``NotImplementedError`` here and must be overridden by subclasses, which also set
    ``loss_name``.
    """

    loss_name: str

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float | list[float] = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Store the shared loss hyperparameters.

        Args:
            data_module: give losses access to data for computing data-specific loss params
            epsilon: loss values below epsilon will be zeroed out; either a float or a
                list of floats
            log_weight: final weight in front of the loss term in the objective function
                is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted and ignored.

        """
        super().__init__()
        self.data_module = data_module
        # epsilon can either be a float or a list of floats
        self.epsilon = torch.tensor(epsilon, dtype=torch.float)
        self.log_weight = torch.tensor(log_weight, dtype=torch.float)
        # maps the name accepted by `reduce_loss` to the torch reduction function
        self.reduce_methods_dict = {"mean": torch.mean, "sum": torch.sum}

    @property
    def weight(self) -> Float[torch.Tensor, ""]:
        """Scalar loss weight computed as ``1 / (2 * exp(log_weight))``.

        Returns:
            Positive scalar weight tensor.

        """
        # exponentiating the configured log-weight keeps the weight positive
        return 1.0 / (2.0 * torch.exp(self.log_weight))

    def remove_nans(self, **kwargs: Any) -> Any:
        """Remove NaN entries from inputs before computing the loss.

        Subclasses must override this method to implement the appropriate NaN-masking
        strategy.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.

        """
        raise NotImplementedError

    def compute_loss(self, **kwargs: Any) -> torch.Tensor:
        """Compute the element-wise loss between targets and predictions.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.

        """
        raise NotImplementedError

    def rectify_epsilon(self, loss: torch.Tensor) -> torch.Tensor:
        """Apply the epsilon-insensitive rectification ``relu(loss - epsilon)``.

        Args:
            loss: element-wise loss tensor.

        Returns:
            Tensor in which values at or below ``self.epsilon`` become zero and larger
            values are reduced by ``self.epsilon``.

        """
        return F.relu(loss - self.epsilon)

    def reduce_loss(self, loss: torch.Tensor, method: str = "mean") -> Float[torch.Tensor, ""]:
        """Reduce an element-wise loss tensor to a scalar.

        Args:
            loss: element-wise loss tensor.
            method: reduction method; currently ``"mean"`` or ``"sum"``.

        Returns:
            Scalar loss tensor. The mean of an empty tensor is reported as zero.

        Raises:
            KeyError: if ``method`` is not a key of ``self.reduce_methods_dict``.

        """
        if method == "mean" and loss.numel() == 0:
            # torch.mean of an empty tensor is NaN, which would poison the objective
            # when every element was filtered out; the empty sum is a zero scalar that
            # keeps the dtype, device and autograd graph of `loss`.
            return loss.sum()
        return self.reduce_methods_dict[method](loss)

    def log_loss(
        self,
        loss: torch.Tensor,
        stage: Literal["train", "val", "test"] | None,
    ) -> list[dict]:
        """Build a list of logging dicts for the scalar loss and its weight.

        Args:
            loss: scalar loss value to log.
            stage: training stage prefix for the log key, or ``None`` to skip stage
                prefixing.

        Returns:
            Two dicts: the loss entry (``"name"``, ``"value"`` and ``"prog_bar": True``)
            followed by the weight entry (``"name"`` and ``"value"``).

        """
        # omit the prefix entirely for stage=None instead of emitting a "None_" key
        prefix = "" if stage is None else f"{stage}_"
        loss_dict = {
            "name": f"{prefix}{self.loss_name}_loss",
            "value": loss,
            "prog_bar": True,
        }
        weight_dict = {
            "name": f"{self.loss_name}_weight",
            "value": self.weight,
        }
        return [loss_dict, weight_dict]

    def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Execute the full loss pipeline and return a scalar loss plus logging dicts.

        The standard pipeline is:
        remove_nans → compute_loss → rectify_epsilon → reduce_loss → log_loss.
        Subclasses must override this method to supply the correct arguments to each
        step and return ``(scalar_loss, logs)``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.

        """
        raise NotImplementedError
class HeatmapLoss(Loss):
    """Parent class for different heatmap losses (MSE, Wasserstein, etc)."""

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize HeatmapLoss.

        Args:
            data_module: data module providing access to datasets; passed to the parent
                class.
            log_weight: final weight in front of the loss term in the objective function
                is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted and ignored.

        """
        super().__init__(data_module=data_module, log_weight=log_weight)

    def remove_nans(
        self,
        targets: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
    ) -> tuple[
        Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
        Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
    ]:
        """Drop every (frame, keypoint) heatmap whose target is entirely zero.

        An all-zero target heatmap marks a NaN/unlabeled keypoint.

        Args:
            targets: ground-truth heatmaps.
            predictions: predicted heatmaps.

        Returns:
            Tuple of ``(clean_targets, clean_predictions)`` holding only the entries
            whose target heatmap has at least one non-zero pixel.

        """
        # one row of pixels per (frame, keypoint) pair
        pixels = targets.reshape(targets.shape[0], targets.shape[1], -1)
        # keep a heatmap as soon as any of its target pixels differs from zero
        keep = torch.any(pixels != 0.0, dim=-1)
        return targets[keep], predictions[keep]

    def compute_loss(self, **kwargs: Any) -> torch.Tensor:
        """Compute element-wise divergence between target and predicted heatmaps.

        Subclasses must override this method with the specific divergence measure.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.

        """
        raise NotImplementedError

    def __call__(
        self,
        heatmaps_targ: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        heatmaps_pred: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the heatmap loss.

        Pipeline: remove_nans → compute_loss → reduce_loss (mean) → log_loss.

        Args:
            heatmaps_targ: ground-truth heatmaps.
            heatmaps_pred: predicted heatmaps.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.

        """
        valid_targ, valid_pred = self.remove_nans(
            targets=heatmaps_targ, predictions=heatmaps_pred
        )
        per_element = self.compute_loss(targets=valid_targ, predictions=valid_pred)
        scalar_loss = self.reduce_loss(per_element, method="mean")
        return scalar_loss, self.log_loss(loss=scalar_loss, stage=stage)
class HeatmapMSELoss(HeatmapLoss):
    """MSE loss between heatmaps."""

    loss_name = "heatmap_mse"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize HeatmapMSELoss.

        Args:
            data_module: data module providing access to datasets; passed to the parent
                class.
            log_weight: final weight in front of the loss term in the objective function
                is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted and ignored.

        """
        super().__init__(data_module=data_module, log_weight=log_weight)

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "batch_x_num_keypoints heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "batch_x_num_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "batch_x_num_keypoints heatmap_height heatmap_width"]:
        """Compute the pixel-wise squared error scaled by the heatmap area.

        Args:
            targets: ground-truth heatmaps.
            predictions: model-predicted heatmaps.

        Returns:
            Element-wise squared error multiplied by ``heatmap_height * heatmap_width``.

        """
        height, width = targets.shape[1], targets.shape[2]
        squared_error = F.mse_loss(targets, predictions, reduction="none")
        # multiply by number of pixels in heatmap to standardize loss range
        return squared_error * height * width
class HeatmapKLLoss(HeatmapLoss):
    """Kullback-Leibler loss between heatmaps."""

    loss_name = "heatmap_kl"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize HeatmapKLLoss.

        Args:
            data_module: data module providing access to datasets; passed to the parent
                class.
            log_weight: final weight in front of the loss term in the objective function
                is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted and ignored.

        """
        super().__init__(data_module=data_module, log_weight=log_weight)
        # kornia divergence applied in `compute_loss`
        self.loss = kl_div_loss_2d

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "num_valid_keypoints"]:
        """Compute per-keypoint KL divergence between target and predicted heatmaps.

        Args:
            targets: ground-truth heatmaps.
            predictions: model-predicted heatmaps.

        Returns:
            per-keypoint KL divergence values.

        """
        # add a leading singleton dimension for the kornia call and offset both inputs
        # by a small constant
        batched_pred = predictions.unsqueeze(0) + 1e-10
        batched_targ = targets.unsqueeze(0) + 1e-10
        divergence = self.loss(batched_pred, batched_targ, reduction="none")
        # strip the singleton dimension again
        return divergence[0]
class HeatmapJSLoss(HeatmapLoss):
    """Jensen-Shannon loss between heatmaps."""

    loss_name = "heatmap_js"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize HeatmapJSLoss.

        Args:
            data_module: data module providing access to datasets; passed to the parent
                class.
            log_weight: final weight in front of the loss term in the objective function
                is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted and ignored.

        """
        super().__init__(data_module=data_module, log_weight=log_weight)
        # kornia divergence applied in `compute_loss`
        self.loss = js_div_loss_2d

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "num_valid_keypoints"]:
        """Compute per-keypoint Jensen-Shannon divergence between target and predicted heatmaps.

        Args:
            targets: ground-truth heatmaps.
            predictions: model-predicted heatmaps.

        Returns:
            per-keypoint JS divergence values.

        """
        # add a leading singleton dimension for the kornia call and offset both inputs
        # by a small constant
        batched_pred = predictions.unsqueeze(0) + 1e-10
        batched_targ = targets.unsqueeze(0) + 1e-10
        divergence = self.loss(batched_pred, batched_targ, reduction="none")
        # strip the singleton dimension again
        return divergence[0]
class PCALoss(Loss):
    """Penalize predictions that fall outside a low-dimensional subspace."""

    # define all valid loss names as class constants
    LOSS_NAME_MULTIVIEW = "pca_multiview"
    LOSS_NAME_SINGLEVIEW = "pca_singleview"

    def __init__(
        self,
        loss_name: Literal["pca_singleview", "pca_multiview"],
        components_to_keep: int | float = 0.95,
        empirical_epsilon_percentile: float = 99.0,
        epsilon: float | None = None,
        empirical_epsilon_multiplier: float = 1.0,
        mirrored_column_matches: ListConfig | list | None = None,
        columns_for_singleview_pca: ListConfig | list | None = None,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        device: str | torch.device = _DEFAULT_TORCH_DEVICE,
        centering_method: Literal["mean", "median"] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize PCALoss.

        Fits a :class:`KeypointPCA` object on the training data and uses the resulting
        low-dimensional subspace to penalize out-of-subspace predictions at training time.

        Args:
            loss_name: ``"pca_singleview"`` penalizes single-camera predictions;
                ``"pca_multiview"`` penalizes predictions that are inconsistent across
                views.
            components_to_keep: passed to :class:`KeypointPCA`; see its docstring for
                details.
            empirical_epsilon_percentile: percentile of the training-data reprojection
                error used to set epsilon when ``epsilon`` is ``None``; in ``[0, 100]``.
            epsilon: if not ``None``, use this fixed epsilon value and ignore
                ``empirical_epsilon_percentile``.
            empirical_epsilon_multiplier: scalar multiplier applied to the empirically
                computed epsilon before use.
            mirrored_column_matches: required for ``"pca_multiview"``; see
                :class:`KeypointPCA` for details.
            columns_for_singleview_pca: subset of keypoint indices to use for singleview
                PCA; ``None`` uses all keypoints.
            data_module: data module used by :class:`KeypointPCA` to extract training
                data; required.
            log_weight: final weight in front of the loss term in the objective function
                is computed as ``1.0 / (2.0 * exp(log_weight))``.
            device: device on which PCA parameters are stored and loss is computed.
            centering_method: if not ``None``, subtract the per-frame keypoint centroid
                before fitting PCA. ``"mean"`` uses the arithmetic mean; ``"median"``
                uses the median.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``loss_name`` is invalid, if ``mirrored_column_matches`` is
                missing for the multiview loss, or if ``data_module`` is ``None``.

        """
        super().__init__(data_module=data_module, log_weight=log_weight)
        self.device = device

        # validate against class constants
        if loss_name not in (self.LOSS_NAME_MULTIVIEW, self.LOSS_NAME_SINGLEVIEW):
            raise ValueError(f"Invalid loss_name: {loss_name}")
        self.loss_name = loss_name

        if loss_name == self.LOSS_NAME_MULTIVIEW and mirrored_column_matches is None:
            raise ValueError("must provide mirrored_column_matches in data config")

        # raise instead of assert: assertions are stripped under `python -O`, and the
        # other constructor checks already report bad arguments with ValueError
        if data_module is None:
            raise ValueError("PCALoss requires a data_module to fit PCA")

        # the current data_module contains datasets that are loaded using augmentations.
        # the current solution is to pass the data module to KeypointPCA, which then
        # passes it to DataExtractor; we will also pass a "no_augmentation" arg to
        # DataExtractor which will rebuild the data module with only resizing
        # augmentations, then extract the data.

        # initialize keypoint pca module
        # this module will fit pca on training data, and will define the error metric
        # and function to be used in model training.
        self.pca = KeypointPCA(
            loss_type=self.loss_name,
            data_module=data_module,
            components_to_keep=components_to_keep,
            empirical_epsilon_percentile=empirical_epsilon_percentile,
            mirrored_column_matches=mirrored_column_matches,
            columns_for_singleview_pca=columns_for_singleview_pca,
            device=device,
            centering_method=centering_method,
        )
        # compute all the parameters needed for the loss
        self.pca()

        # select epsilon based on constructor inputs
        if epsilon is not None:
            self.epsilon = torch.tensor(epsilon, dtype=torch.float, device=self.device)
            print(f"Using absolute epsilon={epsilon:.2f} for pca loss; empirical epsilon ignored")
        else:
            # empirically compute epsilon, already converted to tensor
            self.epsilon = self.pca.parameters["epsilon"] * empirical_epsilon_multiplier
            print(
                f"Using empirical epsilon={float(self.pca.parameters['epsilon']):.3f}"
                f" * multiplier={float(empirical_epsilon_multiplier):.3f}"
                f" -> total={float(self.epsilon):.3f} for {self.loss_name} loss",
            )

    def remove_nans(self, **kwargs: Any) -> Any:
        """No-op for PCALoss; NaN handling is performed inside :meth:`compute_loss`."""
        return None

    def compute_loss(
        self,
        predictions: Float[torch.Tensor, "num_samples sample_dim"],
    ) -> Float[torch.Tensor, "num_samples _"]:
        """Compute per-sample PCA reprojection error.

        Args:
            predictions: predicted keypoint coordinates, shape
                ``(num_samples, sample_dim)``.

        Returns:
            Reprojection error as returned by ``self.pca.compute_reprojection_error``.

        """
        # internal sanity check: predictions must live on the device the PCA
        # parameters were placed on
        assert predictions.device == torch.device(self.device), (
            predictions.device,
            torch.device(self.device),
        )
        return self.pca.compute_reprojection_error(data_arr=predictions)

    def __call__(
        self,
        keypoints_pred: torch.Tensor,
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the PCA loss for a batch of predicted keypoints.

        Pipeline: format data → compute_loss → rectify_epsilon → reduce_loss (mean)
        → log_loss.

        Args:
            keypoints_pred: predicted keypoint coordinates.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.

        """
        # internal sanity check: see `compute_loss`
        assert keypoints_pred.device == torch.device(self.device), (
            keypoints_pred.device,
            torch.device(self.device),
        )
        # let the PCA object arrange the predictions into the layout it was fit on
        keypoints_pred = self.pca._format_data(data_arr=keypoints_pred)
        elementwise_loss = self.compute_loss(predictions=keypoints_pred)
        epsilon_insensitive_loss = self.rectify_epsilon(loss=elementwise_loss)
        scalar_loss = self.reduce_loss(epsilon_insensitive_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs
class TemporalLoss(Loss):
    """Penalize temporal differences for each target.

    Motion model: x_t = x_(t-1) + e_t, e_t ~ N(0, s)
    """

    loss_name = "temporal"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float | list[float] = 0.0,
        prob_threshold: float = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize TemporalLoss.

        Args:
            data_module: data module providing access to datasets; passed to the parent
                class.
            epsilon: loss values below this threshold are zeroed out. May be a scalar or
                a list with one value per keypoint.
            prob_threshold: predictions whose confidence is below this value are
                excluded from the loss computation.
            log_weight: final weight in front of the loss term in the objective function
                is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted and ignored.

        """
        super().__init__(data_module=data_module, epsilon=epsilon, log_weight=log_weight)
        self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float)

    def rectify_epsilon(
        self, loss: Float[torch.Tensor, "batch_minus_one num_keypoints"]
    ) -> Float[torch.Tensor, "batch_minus_one num_keypoints"]:
        """Rectify supporting a list of epsilons, one per bodypart.

        Args:
            loss: temporal difference losses of shape ``(batch-1, num_keypoints)``.

        Returns:
            ``relu(loss - epsilon)`` with the same shape as ``loss``.

        """
        # self.epsilon is a tensor initialized in the parent class; a scalar epsilon or
        # one value per keypoint both broadcast across the rows of `loss`, so no
        # explicit repeat is needed
        epsilon = self.epsilon.to(loss.device)
        return F.relu(loss - epsilon)

    def remove_nans(
        self,
        loss: Float[torch.Tensor, "batch_minus_one num_keypoints"],
        confidences: Float[torch.Tensor, "batch num_keypoints"],
    ) -> Float[torch.Tensor, "batch_minus_one num_keypoints"]:
        """Zero out temporal difference losses where either neighboring frame is low-confidence.

        Args:
            loss: temporal difference losses of shape ``(batch-1, num_keypoints)``.
            confidences: per-frame confidence scores of shape ``(batch, num_keypoints)``.

        Returns:
            New tensor (``loss`` itself is left untouched) with entries zeroed where the
            confidence of either frame falls below ``self.prob_threshold``.

        """
        # get rid of unsupervised targets with extremely uncertain predictions or likely
        # occlusions
        low_confidence = confidences < self.prob_threshold
        # difference i spans frames i and i + 1: ignore it when either one is flagged.
        # slicing replaces the per-frame python loop and also copes with < 2 frames.
        either_low = (low_confidence[:-1] | low_confidence[1:]).to(loss.device)
        return loss.masked_fill(either_low, 0.0)

    def compute_loss(
        self,
        predictions: Float[torch.Tensor, "batch two_x_num_keypoints"],
    ) -> Float[torch.Tensor, "batch_minus_one num_keypoints"]:
        """Compute per-keypoint L2 temporal differences between consecutive frames.

        Args:
            predictions: predicted (x, y) keypoints of shape ``(batch, 2*num_keypoints)``.

        Returns:
            L2 norm of frame-to-frame differences, shape ``(batch-1, num_keypoints)``.

        """
        # shape: (batch - 1, 2 * num_keypoints)
        diffs = torch.diff(predictions, dim=0)
        # shape: (batch - 1, num_keypoints, 2)
        xy_diffs = torch.reshape(diffs, (diffs.shape[0], -1, 2))
        # shape: (batch - 1, num_keypoints)
        return torch.linalg.norm(xy_diffs, ord=2, dim=2)

    def __call__(
        self,
        keypoints_pred: Float[torch.Tensor, "batch two_x_num_keypoints"],
        confidences: Float[torch.Tensor, "batch num_keypoints"] | None = None,
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the temporal loss for a batch of predicted keypoints.

        Args:
            keypoints_pred: predicted (x, y) keypoints of shape
                ``(batch, 2*num_keypoints)``.
            confidences: per-frame confidence scores; if provided, differences touching
                a low-confidence frame are zeroed out.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.

        """
        elementwise_loss = self.compute_loss(predictions=keypoints_pred)
        # masking happens after the differences are computed so that a difference is
        # dropped when either of its two frames is unreliable
        clean_loss = (
            self.remove_nans(loss=elementwise_loss, confidences=confidences)
            if confidences is not None
            else elementwise_loss
        )
        epsilon_insensitive_loss = self.rectify_epsilon(loss=clean_loss)
        scalar_loss = self.reduce_loss(epsilon_insensitive_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs
class TemporalHeatmapLoss(Loss):
    """Penalize temporal differences for each heatmap.

    Motion model: x_t = x_(t-1) + e_t, e_t ~ N(0, s)
    """

    LOSS_NAME_MSE = "temporal_heatmap_mse"
    LOSS_NAME_KL = "temporal_heatmap_kl"

    def __init__(
        self,
        loss_name: Literal["temporal_heatmap_mse", "temporal_heatmap_kl"],
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float | list[float] = 0.0,
        prob_threshold: float = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize TemporalHeatmapLoss.

        Args:
            loss_name: ``"temporal_heatmap_mse"`` uses pixel-wise MSE between consecutive
                heatmaps; ``"temporal_heatmap_kl"`` uses the KL divergence.
            data_module: data module providing access to datasets; passed to the parent
                class.
            epsilon: loss values below this threshold are zeroed out. May be a scalar or
                a list with one value per keypoint.
            prob_threshold: predictions whose confidence is below this value are
                excluded from the loss computation.
            log_weight: final weight in front of the loss term in the objective function
                is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``loss_name`` is not one of the two supported names.

        """
        super().__init__(data_module=data_module, epsilon=epsilon, log_weight=log_weight)
        if loss_name not in (self.LOSS_NAME_MSE, self.LOSS_NAME_KL):
            raise ValueError(f"Invalid loss_name: {loss_name}")
        self.loss_name = loss_name
        # only the KL variant needs an external divergence function; the name was
        # validated above, so no further fallback branch is required
        self.hmloss = kl_div_loss_2d if loss_name == self.LOSS_NAME_KL else None
        self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float)

    def rectify_epsilon(
        self, loss: Float[torch.Tensor, "batch_minus_one num_valid_keypoints"]
    ) -> Float[torch.Tensor, "batch_minus_one num_valid_keypoints"]:
        """Rectify supporting a list of epsilons, one per bodypart.

        Args:
            loss: temporal difference losses of shape ``(batch-1, num_keypoints)``.

        Returns:
            ``relu(loss - epsilon)`` with the same shape as ``loss``.

        """
        # self.epsilon is a tensor initialized in the parent class; a scalar epsilon or
        # one value per keypoint both broadcast across the rows of `loss`, so no
        # explicit repeat is needed
        epsilon = self.epsilon.to(loss.device)
        return F.relu(loss - epsilon)

    def remove_nans(
        self,
        confidences: Float[torch.Tensor, "batch num_keypoints"],
        loss: Float[torch.Tensor, "batch_minus_one num_keypoints"],
    ) -> Float[torch.Tensor, "batch_minus_one num_keypoints"]:
        """Zero out heatmap temporal difference losses where adjacent frames are low-confidence.

        Args:
            confidences: per-frame confidence scores of shape ``(batch, num_keypoints)``.
            loss: temporal difference losses of shape ``(batch-1, num_keypoints)``.

        Returns:
            New tensor (``loss`` itself is left untouched) with entries zeroed where the
            confidence of either frame falls below ``self.prob_threshold``.

        """
        # get rid of unsupervised targets with extremely uncertain predictions or likely
        # occlusions
        low_confidence = confidences < self.prob_threshold
        # difference i spans frames i and i + 1: ignore it when either one is flagged
        either_low = (low_confidence[:-1] | low_confidence[1:]).to(loss.device)
        # out-of-place fill so the caller's tensor is never modified
        return loss.masked_fill(either_low, 0.0)

    def compute_loss(
        self,
        predictions: Float[torch.Tensor, "batch num_valid_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "batch_minus_one num_valid_keypoints"]:
        """Compute per-keypoint temporal heatmap differences between consecutive frames.

        Args:
            predictions: predicted heatmaps of shape
                ``(batch, num_keypoints, heatmap_height, heatmap_width)``.

        Returns:
            Per-keypoint temporal divergence of shape ``(batch-1, num_keypoints)``.

        Raises:
            NotImplementedError: if ``self.loss_name`` is not a supported name.

        """
        num_frames, num_keypoints = predictions.shape[0], predictions.shape[1]
        if num_frames < 2:
            # no consecutive pair of frames exists
            return torch.zeros((0, num_keypoints), device=predictions.device)

        # pair every heatmap with the matching heatmap of the following frame
        earlier, later = predictions[:-1], predictions[1:]
        if self.loss_name == self.LOSS_NAME_MSE:
            per_pixel = F.mse_loss(earlier, later, reduction="none")
            # average over the pixels of each heatmap
            return per_pixel.flatten(start_dim=2).mean(dim=-1)
        if self.loss_name == self.LOSS_NAME_KL:
            assert self.hmloss is not None
            # small constant offset applied to both inputs before the divergence
            return self.hmloss(earlier + 1e-10, later + 1e-10, reduction="none")
        raise NotImplementedError

    def __call__(
        self,
        heatmaps_pred: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        confidences: Float[torch.Tensor, "batch num_keypoints"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the temporal heatmap loss for a batch of predicted heatmaps.

        Args:
            heatmaps_pred: predicted heatmaps of shape
                ``(batch, num_keypoints, heatmap_height, heatmap_width)``.
            confidences: per-frame confidence scores of shape ``(batch, num_keypoints)``.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.

        """
        elementwise_loss = self.compute_loss(predictions=heatmaps_pred)
        # remove nan after loss is computed to get rid of diff vals with a bad heatmap
        clean_loss = self.remove_nans(confidences=confidences, loss=elementwise_loss)
        epsilon_insensitive_loss = self.rectify_epsilon(loss=clean_loss)
        scalar_loss = self.reduce_loss(epsilon_insensitive_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs
[docs] class UnimodalLoss(Loss): """Encourage heatmaps to be unimodal using various measures.""" LOSS_NAME_MSE = "unimodal_mse" LOSS_NAME_KL = "unimodal_kl" LOSS_NAME_JS = "unimodal_js"
[docs] def __init__( self, loss_name: Literal["unimodal_mse", "unimodal_kl", "unimodal_js"], original_image_height: int, original_image_width: int, downsampled_image_height: int, downsampled_image_width: int, data_module: BaseDataModule | UnlabeledDataModule | None = None, prob_threshold: float = 0.0, log_weight: float = 0.0, uniform_heatmaps: bool = False, **kwargs: Any, ) -> None: """Initialize UnimodalLoss. Generates an ideal unimodal heatmap from each predicted keypoint coordinate and penalizes the difference between that ideal heatmap and the network's predicted heatmap. Args: loss_name: divergence measure to use. ``"unimodal_mse"`` uses pixel-wise MSE; ``"unimodal_kl"`` uses KL divergence; ``"unimodal_js"`` uses Jensen-Shannon divergence. original_image_height: height of the full-resolution input image in pixels, used when generating ideal heatmaps. original_image_width: width of the full-resolution input image in pixels. downsampled_image_height: height of the heatmap output (after backbone downsampling). downsampled_image_width: width of the heatmap output. data_module: data module providing access to datasets; passed to the parent class. prob_threshold: predictions whose confidence is below this value are excluded from the loss computation. log_weight: final weight in front of the loss term in the objective function is computed as ``1.0 / (2.0 * exp(log_weight))``. uniform_heatmaps: if ``True``, generate uniform (flat) target heatmaps for NaN ground truth keypoints instead of ignoring them in the loss. 
""" super().__init__(data_module=data_module, log_weight=log_weight) if loss_name not in (self.LOSS_NAME_MSE, self.LOSS_NAME_KL, self.LOSS_NAME_JS): raise ValueError(f"Invalid loss_name: {loss_name}") self.loss_name = loss_name self.original_image_height = original_image_height self.original_image_width = original_image_width self.downsampled_image_height = downsampled_image_height self.downsampled_image_width = downsampled_image_width self.uniform_heatmaps = uniform_heatmaps self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float) if self.loss_name == "unimodal_mse": self.loss = None elif self.loss_name == "unimodal_kl": self.loss = kl_div_loss_2d elif self.loss_name == "unimodal_js": self.loss = js_div_loss_2d else: raise NotImplementedError
[docs] def remove_nans( self, targets: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"], predictions: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"], confidences: Float[torch.Tensor, "batch num_keypoints"], ) -> tuple[ Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"], Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"], ]: """Remove nans from targets and predictions. Args: targets: (batch, num_keypoints, heatmap_height, heatmap_width) predictions: (batch, num_keypoints, heatmap_height, heatmap_width) confidences: (batch, num_keypoints) Returns: clean targets: concatenated across different images and keypoints clean predictions: concatenated across different images and keypoints """ # use confidences to get rid of unsupervised targets with likely occlusions idxs_ignore = confidences < self.prob_threshold return targets[~idxs_ignore], predictions[~idxs_ignore]
[docs] def compute_loss( self, targets: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"], predictions: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"], ) -> torch.Tensor: """Compute per-element divergence between ideal unimodal targets and predicted heatmaps. Args: targets: ideal unimodal heatmaps derived from predicted keypoint coordinates. predictions: predicted heatmaps from the network. Returns: Element-wise loss tensor. """ if self.loss_name == "unimodal_mse": return F.mse_loss(targets, predictions, reduction="none") elif self.loss_name == "unimodal_kl": assert self.loss is not None return self.loss( predictions.unsqueeze(0) + 1e-10, targets.unsqueeze(0) + 1e-10, reduction="none", ) elif self.loss_name == "unimodal_js": assert self.loss is not None return self.loss( predictions.unsqueeze(0) + 1e-10, targets.unsqueeze(0) + 1e-10, reduction="none", ) else: raise NotImplementedError
def __call__(
    self,
    keypoints_pred_augmented: Float[torch.Tensor, "batch two_x_num_keypoints"],
    heatmaps_pred: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
    confidences: Float[torch.Tensor, "batch num_keypoints"],
    stage: Literal["train", "val", "test"] | None = None,
    **kwargs: Any,
) -> tuple[Float[torch.Tensor, ""], list[dict]]:
    """Compute unimodal loss.

    Predicted keypoints are rendered into ideal heatmaps, low-confidence keypoints
    are discarded, and the remaining predicted heatmaps are compared against the
    ideal ones.

    Args:
        keypoints_pred_augmented: these are in the augmented image space
        heatmaps_pred: also in the augmented space, matching the
            keypoints_pred_augmented
        confidences: per-keypoint confidences used to filter out keypoints
        stage: passed through to ``self.log_loss``.
        **kwargs: ignored extra keyword arguments.

    Returns:
        Tuple of the mean-reduced loss and whatever ``self.log_loss`` returns.

    """
    # (batch, 2 * num_keypoints) -> (batch, num_keypoints, 2)
    batch_size = keypoints_pred_augmented.shape[0]
    keypoints_xy = keypoints_pred_augmented.reshape(batch_size, -1, 2)

    # render the ideal unimodal heatmap for every predicted keypoint; gradient
    # tracking is not requested for this step
    heatmap_shape = (self.downsampled_image_height, self.downsampled_image_width)
    ideal_heatmaps = generate_heatmaps(
        keypoints=keypoints_xy,
        height=self.original_image_height,
        width=self.original_image_width,
        output_shape=heatmap_shape,
        uniform_heatmaps=self.uniform_heatmaps,
    )

    # keep only keypoints that pass the confidence filter
    kept_targets, kept_predictions = self.remove_nans(
        targets=ideal_heatmaps,
        predictions=heatmaps_pred,
        confidences=confidences,
    )

    unreduced_loss = self.compute_loss(targets=kept_targets, predictions=kept_predictions)
    scalar_loss = self.reduce_loss(unreduced_loss, method="mean")
    logs = self.log_loss(loss=scalar_loss, stage=stage)
    return scalar_loss, logs
class RegressionMSELoss(Loss):
    """Squared-error loss between labeled and predicted keypoint coordinates."""

    loss_name = "regression"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Set up the loss by forwarding its hyperparameters to the parent class.

        Args:
            data_module: data module providing access to datasets; passed to the
                parent class.
            epsilon: loss values below this threshold are zeroed out.
            log_weight: final weight in front of the loss term in the objective
                function is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted for interface compatibility and ignored.

        """
        parent_args = {
            "data_module": data_module,
            "epsilon": epsilon,
            "log_weight": log_weight,
        }
        super().__init__(**parent_args)

    def remove_nans(
        self,
        targets: Float[torch.Tensor, "batch two_x_num_keypoints"],
        predictions: Float[torch.Tensor, "batch two_x_num_keypoints"],
    ) -> tuple[
        Float[torch.Tensor, "num_valid_keypoints"],
        Float[torch.Tensor, "num_valid_keypoints"],
    ]:
        """Keep only the coordinate entries whose target value is not NaN.

        Args:
            targets: ground-truth (x, y) keypoints; NaN entries indicate unlabeled
                keypoints.
            predictions: predicted (x, y) keypoints.

        Returns:
            Tuple of flat ``(targets, predictions)`` tensors restricted to the
            positions where the target is labeled.

        """
        # the mask is built from the targets only and applied to both tensors
        is_labeled = ~torch.isnan(targets)
        return targets.masked_select(is_labeled), predictions.masked_select(is_labeled)

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "batch_x_two_x_num_keypoints"],
        predictions: Float[torch.Tensor, "batch_x_two_x_num_keypoints"],
    ) -> Float[torch.Tensor, "batch_x_two_x_num_keypoints"]:
        """Return the unreduced squared error for every coordinate.

        Args:
            targets: ground-truth coordinate values.
            predictions: predicted coordinate values.

        Returns:
            Element-wise squared error tensor.

        """
        return F.mse_loss(targets, predictions, reduction="none")

    def __call__(
        self,
        keypoints_targ: Float[torch.Tensor, "batch two_x_num_keypoints"],
        keypoints_pred: Float[torch.Tensor, "batch two_x_num_keypoints"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Run the full loss pipeline: filter, compare, reduce, log.

        Args:
            keypoints_targ: ground-truth (x, y) keypoints; NaN entries are ignored.
            keypoints_pred: predicted (x, y) keypoints.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.

        """
        labeled_targets, labeled_predictions = self.remove_nans(
            targets=keypoints_targ,
            predictions=keypoints_pred,
        )
        per_element = self.compute_loss(
            targets=labeled_targets,
            predictions=labeled_predictions,
        )
        scalar_loss = self.reduce_loss(per_element, method="mean")
        return scalar_loss, self.log_loss(loss=scalar_loss, stage=stage)
class RegressionRMSELoss(RegressionMSELoss):
    """Root MSE loss between ground truth and predicted coordinates."""

    loss_name = "rmse"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize RegressionRMSELoss.

        Args:
            data_module: data module providing access to datasets; passed to the
                parent class.
            epsilon: loss values below this threshold are zeroed out.
            log_weight: final weight in front of the loss term in the objective
                function is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted for interface compatibility and ignored.

        """
        super().__init__(data_module=data_module, epsilon=epsilon, log_weight=log_weight)

    def remove_nans(
        self,
        targets: Float[torch.Tensor, "batch two_x_num_keypoints"],
        predictions: Float[torch.Tensor, "batch two_x_num_keypoints"],
    ) -> tuple[
        Float[torch.Tensor, "two_x_num_valid_keypoints"],
        Float[torch.Tensor, "two_x_num_valid_keypoints"],
    ]:
        """Drop whole (x, y) keypoints whose target has a NaN in either coordinate.

        ``compute_loss`` regroups the flat output into (x, y) pairs, so coordinates
        must be removed in pairs. Masking element-wise (as the parent class does)
        would break the pairing whenever only one coordinate of a keypoint is NaN.
        When x and y are always NaN together the result is identical to the
        parent implementation.

        Args:
            targets: ground-truth (x, y) keypoints; NaN entries indicate unlabeled
                keypoints.
            predictions: predicted (x, y) keypoints.

        Returns:
            Tuple of flat ``(clean_targets, clean_predictions)`` tensors laid out as
            ``x0, y0, x1, y1, ...`` for the retained keypoints.

        """
        target_pairs = targets.reshape(-1, 2)
        prediction_pairs = predictions.reshape(-1, 2)
        # a keypoint is usable only if both of its target coordinates are present
        is_labeled = ~torch.isnan(target_pairs).any(dim=1)
        return (
            target_pairs[is_labeled].reshape(-1),
            prediction_pairs[is_labeled].reshape(-1),
        )

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "batch_x_two_x_num_keypoints"],
        predictions: Float[torch.Tensor, "batch_x_two_x_num_keypoints"],
    ) -> Float[torch.Tensor, "batch_x_num_keypoints"]:
        """Compute the per-keypoint root-mean-square coordinate error.

        For each keypoint this is ``sqrt(mean(dx**2, dy**2))``, i.e. the Euclidean
        distance between target and prediction divided by ``sqrt(2)``.

        Args:
            targets: ground-truth (x, y) keypoint coordinates, flattened.
            predictions: predicted (x, y) keypoint coordinates, flattened.

        Returns:
            One RMSE value per keypoint.

        """
        # regroup the flat coordinates into one (x, y) row per keypoint
        targs = targets.reshape(-1, 2)
        preds = predictions.reshape(-1, 2)
        mean_squared_error = torch.mean(F.mse_loss(targs, preds, reduction="none"), dim=1)
        return torch.sqrt(mean_squared_error)
class PairwiseProjectionsLoss(Loss):
    """Penalize projections from each pair of cameras into 3D world space."""

    loss_name = "supervised_pairwise_projections"

    def __init__(self, log_weight: float = 0.0, **kwargs: Any) -> None:
        """Initialize PairwiseProjectionsLoss.

        Args:
            log_weight: final weight in front of the loss term in the objective
                function is computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: accepted for interface compatibility and ignored.

        """
        super().__init__(log_weight=log_weight)

    def remove_nans(
        self,
        loss: Float[torch.Tensor, "batch cam_pairs num_keypoints"],
    ) -> Float[torch.Tensor, "valid_losses"]:
        """Select only valid (non-NaN) loss entries.

        Args:
            loss: per-pair per-keypoint loss tensor; NaN indicates a missing keypoint.

        Returns:
            Flat tensor of valid loss values, or a zero scalar if none are valid.

        """
        is_valid = ~torch.isnan(loss)
        valid_losses = torch.masked_select(loss, is_valid)
        if valid_losses.numel() > 0:
            return valid_losses
        # Nothing valid: return a zero that is still derived from `loss`.
        # torch.where is used instead of multiplying by zero because nan * 0.0 is nan.
        return torch.where(is_valid, loss, torch.zeros_like(loss)).sum()

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "batch num_keypoints 3"],
        predictions: Float[torch.Tensor, "batch cam_pairs num_keypoints 3"],
    ) -> Float[torch.Tensor, "batch cam_pairs num_keypoints"]:
        """Compute L2 distance between 3D target and per-camera-pair predicted 3D keypoints.

        Args:
            targets: ground-truth 3D keypoints of shape ``(batch, num_keypoints, 3)``.
            predictions: predicted 3D points from pairwise triangulation, shape
                ``(batch, cam_pairs, num_keypoints, 3)``.

        Returns:
            Per-pair per-keypoint L2 distances; NaN where targets or predictions
            are missing.

        """
        # a keypoint is missing if any of its three coordinates is NaN
        nan_targets = torch.isnan(targets).any(dim=-1)  # [batch, num_keypoints]
        nan_predictions = torch.isnan(predictions).any(dim=-1)  # [batch, cam_pairs, num_keypoints]

        # broadcast the target mask over the camera-pair axis and combine
        combined_nan_mask = nan_targets.unsqueeze(1) | nan_predictions

        # replace missing entries with zeros so that no NaN enters the norm
        clean_targets = torch.where(
            nan_targets.unsqueeze(-1),  # [batch, num_keypoints, 1]
            torch.zeros_like(targets),
            targets,
        )
        clean_predictions = torch.where(
            combined_nan_mask.unsqueeze(-1),  # [batch, cam_pairs, num_keypoints, 1]
            torch.zeros_like(predictions),
            predictions,
        )

        distances = torch.linalg.norm(
            clean_targets.unsqueeze(1) - clean_predictions, ord=2, dim=-1
        )

        # mark the originally-missing entries as NaN so remove_nans can drop them
        return torch.where(
            combined_nan_mask,
            torch.full_like(distances, float("nan")),
            distances,
        )

    def __call__(
        self,
        keypoints_targ_3d: Float[torch.Tensor, "batch num_keypoints 3"],
        keypoints_pred_3d: Float[torch.Tensor, "batch cam_pairs num_keypoints 3"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the pairwise projections loss.

        Args:
            keypoints_targ_3d: ground-truth 3D keypoints of shape
                ``(batch, num_keypoints, 3)``.
            keypoints_pred_3d: predicted 3D keypoints per camera pair, shape
                ``(batch, cam_pairs, num_keypoints, 3)``.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.

        Raises:
            ValueError: if either ``keypoints_targ_3d`` or ``keypoints_pred_3d``
                is ``None``.

        """
        # check if 3D keypoints are available
        if keypoints_targ_3d is None or keypoints_pred_3d is None:
            raise ValueError(
                f"3D keypoints not available for {stage} stage. "
                "Camera params file is required but not found; "
                "Turn off supervised_pairwise_projections loss to avoid this error."
            )

        elementwise_loss = self.compute_loss(
            targets=keypoints_targ_3d,
            predictions=keypoints_pred_3d,
        )
        clean_loss = self.remove_nans(loss=elementwise_loss)
        scalar_loss = self.reduce_loss(clean_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs
class ReprojectionHeatmapLoss(Loss):
    """Penalize error between predicted 2D->3D->2D->heatmap and ground truth heatmap."""

    loss_name = "supervised_reprojection_heatmap_mse"

    def __init__(
        self,
        original_image_height: int,
        original_image_width: int,
        downsampled_image_height: int,
        downsampled_image_width: int,
        log_weight: float = 0.0,
        uniform_heatmaps: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize ReprojectionHeatmapLoss.

        Converts 2D reprojected keypoints (obtained by projecting 3D triangulated
        predictions back into each camera's image plane) into heatmaps and compares
        them with the ground truth heatmaps using pixel-wise MSE.

        Args:
            original_image_height: height of the full-resolution input image in pixels.
            original_image_width: width of the full-resolution input image in pixels.
            downsampled_image_height: height of the heatmap output.
            downsampled_image_width: width of the heatmap output.
            log_weight: final weight in front of the loss term in the objective
                function is computed as ``1.0 / (2.0 * exp(log_weight))``.
            uniform_heatmaps: forwarded to ``generate_heatmaps`` when the reprojected
                keypoints are rendered into heatmaps.
            **kwargs: accepted for interface compatibility and ignored.

        """
        super().__init__(log_weight=log_weight)
        self.original_image_height = original_image_height
        self.original_image_width = original_image_width
        self.downsampled_image_height = downsampled_image_height
        self.downsampled_image_width = downsampled_image_width
        self.uniform_heatmaps = uniform_heatmaps

    def remove_nans(
        self,
        loss: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        targets: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "valid_losses"]:
        """Select only valid (non-zero-target) loss entries.

        Args:
            loss: element-wise MSE loss tensor.
            targets: ground-truth heatmaps; all-zero heatmaps indicate unlabeled
                keypoints.

        Returns:
            Flat tensor of valid loss values, or a zero scalar if none are valid.

        """
        # a keypoint is valid when its target heatmap has at least one non-zero pixel
        flat_targets = targets.reshape(targets.shape[0], targets.shape[1], -1)
        valid_keypoints = ~torch.all(flat_targets == 0.0, dim=-1)  # [batch, num_keypoints]

        # broadcast the per-keypoint mask over the two pixel axes
        valid_mask = valid_keypoints.unsqueeze(-1).unsqueeze(-1).expand_as(loss)

        valid_losses = torch.masked_select(loss, valid_mask)
        if valid_losses.numel() > 0:
            return valid_losses
        # Nothing valid: return a zero that is still derived from `loss`
        return torch.where(valid_mask, loss, torch.zeros_like(loss)).sum()

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "... heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "... heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "... heatmap_height heatmap_width"]:
        """Compute pixel-wise MSE between reprojected and ground-truth heatmaps.

        The heatmap size is read from the last two dimensions, so the scaling is
        correct both for the ``(batch, num_keypoints, height, width)`` tensors that
        ``__call__`` passes in and for ``(n, height, width)`` tensors.

        Args:
            targets: ground-truth heatmaps.
            predictions: heatmaps generated from reprojected 2D keypoints.

        Returns:
            Element-wise MSE scaled by the number of heatmap pixels.

        """
        h, w = targets.shape[-2:]
        # multiply by number of pixels in heatmap to standardize loss range
        return F.mse_loss(targets, predictions, reduction="none") * h * w

    def __call__(
        self,
        heatmaps_targ: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        keypoints_pred_2d_reprojected: Float[torch.Tensor, "batch num_keypoints 2"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the reprojection heatmap loss.

        Args:
            heatmaps_targ: ground-truth heatmaps.
            keypoints_pred_2d_reprojected: 2D keypoints obtained by projecting
                triangulated 3D predictions back into each camera, shape
                ``(batch, num_keypoints, 2)``.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.

        Raises:
            ValueError: if ``keypoints_pred_2d_reprojected`` is ``None``.

        """
        # check if reprojected keypoints are available
        if keypoints_pred_2d_reprojected is None:
            raise ValueError(
                f"Reprojected keypoints not available for {stage} stage. "
                "Camera params file is required but not found; "
                "Turn off supervised_reprojection_heatmap loss to avoid this error."
            )

        # create heatmaps from 2d reprojections; gradients are kept so the loss
        # can propagate back to the reprojected keypoints
        heatmaps_pred = generate_heatmaps(
            keypoints=keypoints_pred_2d_reprojected,
            height=self.original_image_height,
            width=self.original_image_width,
            output_shape=(self.downsampled_image_height, self.downsampled_image_width),
            uniform_heatmaps=self.uniform_heatmaps,
            keep_gradients=True,
        )

        elementwise_loss = self.compute_loss(targets=heatmaps_targ, predictions=heatmaps_pred)
        clean_loss = self.remove_nans(loss=elementwise_loss, targets=heatmaps_targ)
        scalar_loss = self.reduce_loss(clean_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs