"""Supervised and unsupervised losses implemented in pytorch.
The lightning pose package defines each loss as its own class; an initialized loss
object, in addition to computing the loss, stores hyperparameters related to the loss
(weight in the final objective function, epsilon-insensitivity parameter, etc.)
A separate LossFactory class (defined in lightning_pose.losses.factory) collects all
losses for a given model and orchestrates their execution, logging, etc.
The general flow of each loss class is as follows:
- input: predicted and ground truth data
- step 0: remove ground truth samples containing nans if desired
- step 1: compute loss for each batch element/keypoint/etc
- step 2: epsilon-insensitivity: compute relu(loss - epsilon), which zeroes any element with
  loss < epsilon (not every loss applies this step)
- step 3: reduce loss (usually mean)
- step 4: log values to a dict
- step 5: return loss
"""
import os
from typing import Any, Literal
import torch
from jaxtyping import Float
from kornia.losses import js_div_loss_2d, kl_div_loss_2d
from omegaconf import ListConfig
from torch.nn import functional as F
from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.utils import generate_heatmaps
from lightning_pose.utils.pca import KeypointPCA
# Explicit public API; this also keeps sphinx-autoapidoc from documenting the imported names.
__all__ = [
    "Loss",
    "HeatmapLoss",
    "HeatmapMSELoss",
    "HeatmapKLLoss",
    "HeatmapJSLoss",
    "PCALoss",
    "TemporalLoss",
    "TemporalHeatmapLoss",
    "UnimodalLoss",
    "RegressionMSELoss",
    "RegressionRMSELoss",
    "PairwiseProjectionsLoss",
    "ReprojectionHeatmapLoss",
]

# Default device string (used e.g. as PCALoss's default ``device``).
# With multiple GPUs, the LOCAL_RANK environment variable holds the DDP local rank, which is
# also the cuda device index for the current process; a single-process run falls back to 0.
if torch.cuda.is_available():
    _DEFAULT_TORCH_DEVICE = f"cuda:{int(os.environ.get('LOCAL_RANK', '0'))}"
else:
    _DEFAULT_TORCH_DEVICE = "cpu"
[docs]
class Loss:
    """Parent class for all losses.

    Stores the hyperparameters shared by every loss (``epsilon``, ``log_weight``) and provides
    the generic pipeline helpers (epsilon rectification, reduction, logging). Subclasses set
    ``loss_name`` and override :meth:`remove_nans`, :meth:`compute_loss` and :meth:`__call__`.
    """

    loss_name: str

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float | list[float] = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            data_module: give losses access to data for computing data-specific loss params
            epsilon: loss values below epsilon will be zeroed out; a scalar or a list of
                floats, stored as a float tensor
            log_weight: final weight in front of the loss term in the objective function is
                computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.
        """
        super().__init__()
        self.data_module = data_module
        # epsilon can either be a float or a list of floats
        self.epsilon = torch.tensor(epsilon, dtype=torch.float)
        self.log_weight = torch.tensor(log_weight, dtype=torch.float)
        self.reduce_methods_dict = {"mean": torch.mean, "sum": torch.sum}

    @property
    def weight(self) -> Float[torch.Tensor, ""]:
        """Scalar loss weight computed as ``1 / (2 * exp(log_weight))``.

        Returns:
            Positive scalar weight tensor.
        """
        # log_weight plays the role of log(sigma^2), so weight = 1 / (2 * sigma^2);
        # exponentiating the stored parameter guarantees the weight is positive.
        return 1.0 / (2.0 * torch.exp(self.log_weight))

    def remove_nans(self, **kwargs: Any) -> Any:
        """Remove NaN entries from inputs before computing the loss.

        Subclasses must override this method to implement the appropriate NaN-masking strategy.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def compute_loss(self, **kwargs: Any) -> torch.Tensor:
        """Compute the element-wise loss between targets and predictions.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def rectify_epsilon(self, loss: torch.Tensor) -> torch.Tensor:
        """Apply epsilon-insensitivity: ``relu(loss - epsilon)``.

        Values below epsilon become zero; values above it are shifted down by epsilon.

        Args:
            loss: element-wise loss tensor; must broadcast against ``self.epsilon``.

        Returns:
            Rectified loss tensor on the same device as ``loss``.
        """
        # Move epsilon to the loss device. A 0-dim CPU tensor can be mixed with CUDA tensors,
        # but a multi-element (e.g. per-keypoint) epsilon would raise a device mismatch.
        epsilon = torch.as_tensor(self.epsilon, device=loss.device)
        return F.relu(loss - epsilon)

    def reduce_loss(self, loss: torch.Tensor, method: str = "mean") -> Float[torch.Tensor, ""]:
        """Reduce an element-wise loss tensor to a scalar.

        Args:
            loss: element-wise loss tensor.
            method: reduction method; a key of ``self.reduce_methods_dict``
                (``"mean"`` or ``"sum"``).

        Returns:
            Scalar loss tensor.

        Raises:
            KeyError: if ``method`` is not a known reduction.
        """
        return self.reduce_methods_dict[method](loss)

    def log_loss(
        self,
        loss: torch.Tensor,
        stage: Literal["train", "val", "test"] | None,
    ) -> list[dict]:
        """Build a list of logging dicts for the scalar loss and its weight.

        Args:
            loss: scalar loss value to log.
            stage: training stage prefix for the log key, or ``None`` to skip stage prefixing.

        Returns:
            Two dicts with ``"name"`` and ``"value"`` keys. The first holds the loss, is named
            ``"{stage}_{loss_name}_loss"`` (``"{loss_name}_loss"`` when ``stage`` is ``None``)
            and is flagged for the progress bar. The second holds the weight and is named
            ``"{loss_name}_weight"``.
        """
        # omit the prefix when no stage is given instead of logging under a literal "None_" key
        if stage is None:
            loss_key = f"{self.loss_name}_loss"
        else:
            loss_key = f"{stage}_{self.loss_name}_loss"
        loss_dict = {
            "name": loss_key,
            "value": loss,
            "prog_bar": True,
        }
        weight_dict = {
            "name": f"{self.loss_name}_weight",
            "value": self.weight,
        }
        return [loss_dict, weight_dict]

    def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Execute the full loss pipeline and return a scalar loss plus logging dicts.

        The standard pipeline is:
        remove_nans -> compute_loss -> rectify_epsilon -> reduce_loss -> log_loss.
        Subclasses must override this method to supply the correct arguments to each step.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError
[docs]
class HeatmapLoss(Loss):
    """Parent class for different heatmap losses (MSE, Wasserstein, etc)."""

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Set up a heatmap loss.

        Args:
            data_module: forwarded to the parent class.
            log_weight: the loss term is weighted by ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.
        """
        # epsilon is not forwarded (parent default is used); __call__ below never applies
        # epsilon rectification anyway
        super().__init__(data_module=data_module, log_weight=log_weight)

    def remove_nans(
        self,
        targets: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
    ) -> tuple[
        Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
        Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
    ]:
        """Drop (image, keypoint) heatmaps whose target is identically zero.

        An all-zero target heatmap marks an unlabeled keypoint. The surviving heatmaps from
        all images are stacked along a single leading dimension.

        Args:
            targets: ground-truth heatmaps.
            predictions: predicted heatmaps.

        Returns:
            ``(clean_targets, clean_predictions)`` restricted to keypoints whose target has
            at least one non-zero pixel.
        """
        per_keypoint_pixels = targets.reshape(targets.shape[0], targets.shape[1], -1)
        has_signal = torch.any(per_keypoint_pixels != 0.0, dim=-1)
        return targets[has_signal], predictions[has_signal]

    def compute_loss(self, **kwargs: Any) -> torch.Tensor:
        """Element-wise divergence between target and predicted heatmaps.

        Concrete heatmap losses override this with their specific measure.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def __call__(
        self,
        heatmaps_targ: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        heatmaps_pred: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Run remove_nans -> compute_loss -> mean reduction -> logging.

        No epsilon rectification is applied for heatmap losses.

        Args:
            heatmaps_targ: ground-truth heatmaps.
            heatmaps_pred: predicted heatmaps.
            stage: training stage used in the log key.
            **kwargs: ignored extra keyword arguments.

        Returns:
            ``(scalar_loss, logs)``.
        """
        valid_targets, valid_predictions = self.remove_nans(
            targets=heatmaps_targ, predictions=heatmaps_pred
        )
        per_element = self.compute_loss(targets=valid_targets, predictions=valid_predictions)
        total = self.reduce_loss(per_element, method="mean")
        return total, self.log_loss(loss=total, stage=stage)
[docs]
class HeatmapMSELoss(HeatmapLoss):
    """MSE loss between heatmaps."""

    loss_name = "heatmap_mse"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize HeatmapMSELoss.

        Args:
            data_module: data module providing access to datasets; passed to the parent class.
            log_weight: final weight in front of the loss term in the objective function is
                computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.
        """
        super().__init__(data_module=data_module, log_weight=log_weight)

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "batch_x_num_keypoints heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "batch_x_num_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "batch_x_num_keypoints heatmap_height heatmap_width"]:
        """Compute pixel-wise MSE between target and predicted heatmaps.

        Multiplies by the number of heatmap pixels (h * w) to keep the loss magnitude
        consistent across different heatmap resolutions.

        Args:
            targets: ground-truth heatmaps; the last two dims are (height, width). Any number
                of leading dims is accepted, e.g. ``(batch*num_keypoints, H, W)`` or
                ``(batch, num_keypoints, H, W)``.
            predictions: model-predicted heatmaps, same shape as ``targets``.

        Returns:
            Element-wise squared error scaled by heatmap area, same shape as the inputs.
        """
        # read the spatial size from the trailing dims so 4-D inputs are not mis-scaled by
        # the keypoint count
        height, width = targets.shape[-2:]
        # multiply by number of pixels in heatmap to standardize loss range
        return F.mse_loss(targets, predictions, reduction="none") * height * width
[docs]
class HeatmapKLLoss(HeatmapLoss):
    """Kullback-Leibler loss between heatmaps."""

    loss_name = "heatmap_kl"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Set up the KL heatmap loss.

        Args:
            data_module: forwarded to the parent class.
            log_weight: the loss term is weighted by ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.
        """
        super().__init__(data_module=data_module, log_weight=log_weight)
        # kornia 2-D KL divergence, called as self.loss(pred, target, reduction=...)
        self.loss = kl_div_loss_2d

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "num_valid_keypoints"]:
        """Per-keypoint KL divergence between predicted and target heatmaps.

        Both inputs receive a leading singleton batch dim and a 1e-10 offset (presumably to
        keep the logarithms inside the divergence finite) before calling kornia.

        Args:
            targets: ground-truth heatmaps.
            predictions: model-predicted heatmaps.

        Returns:
            One KL value per valid keypoint.
        """
        offset = 1e-10
        batched_pred = predictions.unsqueeze(0) + offset
        batched_targ = targets.unsqueeze(0) + offset
        per_keypoint = self.loss(batched_pred, batched_targ, reduction="none")
        # strip the singleton batch dim added above
        return per_keypoint[0]
[docs]
class HeatmapJSLoss(HeatmapLoss):
    """Jensen-Shannon loss between heatmaps."""

    loss_name = "heatmap_js"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Set up the Jensen-Shannon heatmap loss.

        Args:
            data_module: forwarded to the parent class.
            log_weight: the loss term is weighted by ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.
        """
        super().__init__(data_module=data_module, log_weight=log_weight)
        # kornia 2-D JS divergence, called as self.loss(pred, target, reduction=...)
        self.loss = js_div_loss_2d

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "num_valid_keypoints"]:
        """Per-keypoint Jensen-Shannon divergence between predicted and target heatmaps.

        Both inputs receive a leading singleton batch dim and a 1e-10 offset (presumably to
        keep the logarithms inside the divergence finite) before calling kornia.

        Args:
            targets: ground-truth heatmaps.
            predictions: model-predicted heatmaps.

        Returns:
            One JS value per valid keypoint.
        """
        offset = 1e-10
        batched_pred = predictions.unsqueeze(0) + offset
        batched_targ = targets.unsqueeze(0) + offset
        per_keypoint = self.loss(batched_pred, batched_targ, reduction="none")
        # strip the singleton batch dim added above
        return per_keypoint[0]
[docs]
class PCALoss(Loss):
    """Penalize predictions that fall outside a low-dimensional subspace."""

    # define all valid loss names as class constants
    LOSS_NAME_MULTIVIEW = "pca_multiview"
    LOSS_NAME_SINGLEVIEW = "pca_singleview"

    def __init__(
        self,
        loss_name: Literal["pca_singleview", "pca_multiview"],
        components_to_keep: int | float = 0.95,
        empirical_epsilon_percentile: float = 99.0,
        epsilon: float | None = None,
        empirical_epsilon_multiplier: float = 1.0,
        mirrored_column_matches: ListConfig | list | None = None,
        columns_for_singleview_pca: ListConfig | list | None = None,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        log_weight: float = 0.0,
        device: str | torch.device = _DEFAULT_TORCH_DEVICE,
        centering_method: Literal["mean", "median"] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize PCALoss.

        Fits a :class:`KeypointPCA` object on the training data and uses the resulting
        low-dimensional subspace to penalize out-of-subspace predictions at training time.

        Args:
            loss_name: ``"pca_singleview"`` penalizes single-camera predictions;
                ``"pca_multiview"`` penalizes predictions that are inconsistent across views.
            components_to_keep: passed to :class:`KeypointPCA`; see its docstring for details.
            empirical_epsilon_percentile: percentile of the training-data reprojection error
                used to set epsilon when ``epsilon`` is ``None``; in ``[0, 100]``.
            epsilon: if not ``None``, use this fixed epsilon value and ignore
                ``empirical_epsilon_percentile``.
            empirical_epsilon_multiplier: scalar multiplier applied to the empirically computed
                epsilon before use.
            mirrored_column_matches: required for ``"pca_multiview"``; see :class:`KeypointPCA`
                for details.
            columns_for_singleview_pca: subset of keypoint indices to use for singleview PCA;
                ``None`` uses all keypoints.
            data_module: data module used by :class:`KeypointPCA` to extract training data;
                required.
            log_weight: final weight in front of the loss term in the objective function is
                computed as ``1.0 / (2.0 * exp(log_weight))``.
            device: device on which PCA parameters are stored and loss is computed.
            centering_method: if not ``None``, subtract the per-frame keypoint centroid before
                fitting PCA. ``"mean"`` uses the arithmetic mean; ``"median"`` uses the median.
            **kwargs: ignored extra keyword arguments.

        Raises:
            ValueError: if ``loss_name`` is not a supported name, if
                ``mirrored_column_matches`` is ``None`` for ``"pca_multiview"``, or if
                ``data_module`` is ``None``.
        """
        super().__init__(data_module=data_module, log_weight=log_weight)
        self.device = device
        # validate against class constants
        if loss_name not in (self.LOSS_NAME_MULTIVIEW, self.LOSS_NAME_SINGLEVIEW):
            raise ValueError(f"Invalid loss_name: {loss_name}")
        self.loss_name = loss_name
        if loss_name == self.LOSS_NAME_MULTIVIEW and mirrored_column_matches is None:
            raise ValueError("must provide mirrored_column_matches in data config")
        # KeypointPCA fits on the training data, so a data module is mandatory. An explicit
        # raise (rather than assert) keeps this check active under `python -O`.
        if data_module is None:
            raise ValueError("PCALoss requires a data_module to fit PCA")
        # the current data_module contains datasets that are loaded using augmentations. the
        # current solution is to pass the data module to KeypointPCA, which then passes it to
        # DataExtractor; we will also pass a "no_augmentation" arg to DataExtractor which will
        # rebuild the data module with only resizing augmentations, then extract the data.
        # initialize keypoint pca module
        # this module will fit pca on training data, and will define the error metric
        # and function to be used in model training.
        self.pca = KeypointPCA(
            loss_type=self.loss_name,
            data_module=data_module,
            components_to_keep=components_to_keep,
            empirical_epsilon_percentile=empirical_epsilon_percentile,
            mirrored_column_matches=mirrored_column_matches,
            columns_for_singleview_pca=columns_for_singleview_pca,
            device=device,
            centering_method=centering_method,
        )
        # compute all the parameters needed for the loss
        self.pca()
        # select epsilon based on constructor inputs
        if epsilon is not None:
            self.epsilon = torch.tensor(epsilon, dtype=torch.float, device=self.device)
            print(f"Using absolute epsilon={epsilon:.2f} for pca loss; empirical epsilon ignored")
        else:
            # empirically compute epsilon, already converted to tensor
            self.epsilon = self.pca.parameters["epsilon"] * empirical_epsilon_multiplier
            print(
                f"Using empirical epsilon={float(self.pca.parameters['epsilon']):.3f}"
                f" * multiplier={float(empirical_epsilon_multiplier):.3f}"
                f" -> total={float(self.epsilon):.3f} for {self.loss_name} loss",
            )

    def remove_nans(self, **kwargs: Any) -> Any:
        """No-op for PCALoss; returns ``None``.

        :meth:`__call__` does not call it. Any NaN handling, if present, would live inside
        :class:`KeypointPCA`, which is not visible here.
        """
        pass

    def compute_loss(
        self,
        predictions: Float[torch.Tensor, "num_samples sample_dim"],
    ) -> Float[torch.Tensor, "num_samples _"]:
        """Compute per-sample PCA reprojection error.

        Args:
            predictions: predicted keypoint coordinates, shape ``(num_samples, sample_dim)``,
                already formatted by ``KeypointPCA._format_data``.

        Returns:
            Reprojection error as returned by ``KeypointPCA.compute_reprojection_error``.
        """
        # internal invariant: PCA parameters and predictions must share a device
        assert predictions.device == torch.device(self.device), (
            predictions.device,
            torch.device(self.device),
        )
        # compute either reprojection error or projection onto discarded evecs.
        # they will vary in the last dim, hence -1.
        return self.pca.compute_reprojection_error(data_arr=predictions)

    def __call__(
        self,
        keypoints_pred: torch.Tensor,
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the PCA loss for a batch of predicted keypoints.

        Pipeline: format data -> reprojection error -> epsilon rectification -> mean -> log.

        Args:
            keypoints_pred: predicted keypoint coordinates, on ``self.device``.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.
        """
        assert keypoints_pred.device == torch.device(self.device), (
            keypoints_pred.device,
            torch.device(self.device),
        )
        keypoints_pred = self.pca._format_data(data_arr=keypoints_pred)
        elementwise_loss = self.compute_loss(predictions=keypoints_pred)
        epsilon_insensitive_loss = self.rectify_epsilon(loss=elementwise_loss)
        scalar_loss = self.reduce_loss(epsilon_insensitive_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs
[docs]
class TemporalLoss(Loss):
    """Penalize temporal differences for each target.

    Motion model: x_t = x_(t-1) + e_t, e_t ~ N(0, s)
    """

    loss_name = "temporal"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float | list[float] = 0.0,
        prob_threshold: float = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize TemporalLoss.

        Args:
            data_module: data module providing access to datasets; passed to the parent class.
            epsilon: loss values below this threshold are zeroed out. May be a scalar or a list
                with one value per keypoint.
            prob_threshold: predictions whose confidence is below this value are excluded from
                the loss computation.
            log_weight: final weight in front of the loss term in the objective function is
                computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.
        """
        super().__init__(data_module=data_module, epsilon=epsilon, log_weight=log_weight)
        self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float)

    def rectify_epsilon(
        self, loss: Float[torch.Tensor, "batch_minus_one num_keypoints"]
    ) -> Float[torch.Tensor, "batch_minus_one num_keypoints"]:
        """Apply ``relu(loss - epsilon)`` with a scalar or per-keypoint epsilon.

        Args:
            loss: temporal losses of shape ``(batch-1, num_keypoints)``.

        Returns:
            Rectified losses, same shape as ``loss``.
        """
        # a scalar epsilon broadcasts to every entry; a (num_keypoints,) epsilon broadcasts
        # along the trailing keypoint dimension, so no explicit repeat is needed
        epsilon = self.epsilon.to(loss.device)
        return F.relu(loss - epsilon)

    def remove_nans(
        self,
        loss: Float[torch.Tensor, "batch_minus_one num_keypoints"],
        confidences: Float[torch.Tensor, "batch num_keypoints"],
    ) -> Float[torch.Tensor, "batch_minus_one num_keypoints"]:
        """Zero out temporal difference losses where either neighboring frame is low-confidence.

        Args:
            loss: temporal difference losses of shape ``(batch-1, num_keypoints)``.
            confidences: per-frame confidence scores of shape ``(batch, num_keypoints)``.

        Returns:
            New loss tensor (the input is not modified) with entries zeroed where the
            confidence of frame t or frame t+1 falls below ``self.prob_threshold``.
        """
        # get rid of unsupervised targets with extremely uncertain predictions or likely occlusions
        low_confidence = confidences < self.prob_threshold
        # difference t spans frames t and t+1; ignore it if either frame is low-confidence
        pair_ignore = (low_confidence[:-1] | low_confidence[1:]).to(loss.device)
        return loss.masked_fill(pair_ignore, 0.0)

    def compute_loss(
        self,
        predictions: Float[torch.Tensor, "batch two_x_num_keypoints"],
    ) -> Float[torch.Tensor, "batch_minus_one num_keypoints"]:
        """Compute per-keypoint L2 temporal differences between consecutive frames.

        Args:
            predictions: predicted (x, y) keypoints of shape ``(batch, 2*num_keypoints)``.

        Returns:
            L2 norm of frame-to-frame differences, shape ``(batch-1, num_keypoints)``.
        """
        # shape: (batch - 1, 2 * num_keypoints)
        diffs = torch.diff(predictions, dim=0)
        # shape: (batch - 1, num_keypoints, 2)
        diffs_xy = torch.reshape(diffs, (diffs.shape[0], -1, 2))
        # shape: (batch - 1, num_keypoints)
        return torch.linalg.norm(diffs_xy, ord=2, dim=2)

    def __call__(
        self,
        keypoints_pred: Float[torch.Tensor, "batch two_x_num_keypoints"],
        confidences: Float[torch.Tensor, "batch num_keypoints"] | None = None,
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the temporal loss for a batch of predicted keypoints.

        Args:
            keypoints_pred: predicted (x, y) keypoints of shape ``(batch, 2*num_keypoints)``.
            confidences: per-frame confidence scores; if provided, low-confidence frames are
                masked out.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.
        """
        elementwise_loss = self.compute_loss(predictions=keypoints_pred)
        # masking happens after the diff so that a bad frame removes both of its differences
        clean_loss = (
            self.remove_nans(loss=elementwise_loss, confidences=confidences)
            if confidences is not None
            else elementwise_loss
        )
        epsilon_insensitive_loss = self.rectify_epsilon(loss=clean_loss)
        scalar_loss = self.reduce_loss(epsilon_insensitive_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs
[docs]
class TemporalHeatmapLoss(Loss):
    """Penalize temporal differences for each heatmap.

    Motion model: x_t = x_(t-1) + e_t, e_t ~ N(0, s)
    """

    LOSS_NAME_MSE = "temporal_heatmap_mse"
    LOSS_NAME_KL = "temporal_heatmap_kl"

    def __init__(
        self,
        loss_name: Literal["temporal_heatmap_mse", "temporal_heatmap_kl"],
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float | list[float] = 0.0,
        prob_threshold: float = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize TemporalHeatmapLoss.

        Args:
            loss_name: ``"temporal_heatmap_mse"`` uses pixel-wise MSE between consecutive
                heatmaps; ``"temporal_heatmap_kl"`` uses the KL divergence.
            data_module: data module providing access to datasets; passed to the parent class.
            epsilon: loss values below this threshold are zeroed out. May be a scalar or a list
                with one value per keypoint.
            prob_threshold: predictions whose confidence is below this value are excluded from
                the loss computation.
            log_weight: final weight in front of the loss term in the objective function is
                computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.

        Raises:
            ValueError: if ``loss_name`` is not one of the supported names.
        """
        super().__init__(data_module=data_module, epsilon=epsilon, log_weight=log_weight)
        if loss_name not in (self.LOSS_NAME_MSE, self.LOSS_NAME_KL):
            raise ValueError(f"Invalid loss_name: {loss_name}")
        self.loss_name = loss_name
        # MSE is computed directly with F.mse_loss; KL uses kornia's 2-D heatmap divergence
        self.hmloss = kl_div_loss_2d if loss_name == self.LOSS_NAME_KL else None
        self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float)

    def rectify_epsilon(
        self, loss: Float[torch.Tensor, "batch_minus_one num_valid_keypoints"]
    ) -> Float[torch.Tensor, "batch_minus_one num_valid_keypoints"]:
        """Apply ``relu(loss - epsilon)`` with a scalar or per-keypoint epsilon.

        Args:
            loss: temporal losses of shape ``(batch-1, num_keypoints)``.

        Returns:
            Rectified losses, same shape as ``loss``.
        """
        # a scalar epsilon broadcasts to every entry; a (num_keypoints,) epsilon broadcasts
        # along the trailing keypoint dimension, so no explicit repeat is needed
        epsilon = self.epsilon.to(loss.device)
        return F.relu(loss - epsilon)

    def remove_nans(
        self,
        confidences: Float[torch.Tensor, "batch num_keypoints"],
        loss: Float[torch.Tensor, "batch_minus_one num_keypoints"],
    ) -> Float[torch.Tensor, "batch_minus_one num_keypoints"]:
        """Zero out heatmap temporal difference losses where adjacent frames are low-confidence.

        Args:
            confidences: per-frame confidence scores of shape ``(batch, num_keypoints)``.
            loss: temporal difference losses of shape ``(batch-1, num_keypoints)``.

        Returns:
            New loss tensor (the input is not modified in place) with entries zeroed where the
            confidence of frame t or frame t+1 falls below ``self.prob_threshold``.
        """
        # get rid of unsupervised targets with extremely uncertain predictions or likely occlusions
        low_confidence = confidences < self.prob_threshold
        # difference t spans frames t and t+1; ignore it if either frame is low-confidence
        pair_ignore = (low_confidence[:-1] | low_confidence[1:]).to(loss.device)
        return loss.masked_fill(pair_ignore, 0.0)

    def compute_loss(
        self,
        predictions: Float[torch.Tensor, "batch num_valid_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "batch_minus_one num_valid_keypoints"]:
        """Compute per-keypoint temporal heatmap differences between consecutive frames.

        Args:
            predictions: predicted heatmaps of shape
                ``(batch, num_keypoints, heatmap_height, heatmap_width)``.

        Returns:
            Per-keypoint temporal divergence of shape ``(batch-1, num_keypoints)``: mean
            squared pixel difference for MSE, kornia KL divergence for KL.
        """
        num_frames, num_keypoints = predictions.shape[0], predictions.shape[1]
        # allocated with the default dtype on the predictions' device, then filled in place
        diffs = torch.zeros((num_frames - 1, num_keypoints), device=predictions.device)
        if num_frames < 2:
            # no consecutive pairs; also avoids calling the losses on an empty batch
            return diffs
        prev_frames = predictions[:-1]
        next_frames = predictions[1:]
        if self.loss_name == self.LOSS_NAME_MSE:
            squared_err = F.mse_loss(prev_frames, next_frames, reduction="none")
            diffs[:] = squared_err.reshape(num_frames - 1, num_keypoints, -1).mean(dim=-1)
        elif self.loss_name == self.LOSS_NAME_KL:
            assert self.hmloss is not None
            # offset presumably keeps the logarithms inside the divergence finite
            diffs[:] = self.hmloss(
                prev_frames + 1e-10,
                next_frames + 1e-10,
                reduction="none",
            )
        return diffs

    def __call__(
        self,
        heatmaps_pred: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        confidences: Float[torch.Tensor, "batch num_keypoints"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the temporal heatmap loss for a batch of predicted heatmaps.

        Args:
            heatmaps_pred: predicted heatmaps of shape
                ``(batch, num_keypoints, heatmap_height, heatmap_width)``.
            confidences: per-frame confidence scores of shape ``(batch, num_keypoints)``.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.
        """
        elementwise_loss = self.compute_loss(predictions=heatmaps_pred)
        # remove nan after loss is computed to get rid of diff vals with a bad heatmap
        clean_loss = self.remove_nans(confidences=confidences, loss=elementwise_loss)
        epsilon_insensitive_loss = self.rectify_epsilon(loss=clean_loss)
        scalar_loss = self.reduce_loss(epsilon_insensitive_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs
[docs]
class UnimodalLoss(Loss):
"""Encourage heatmaps to be unimodal using various measures."""
LOSS_NAME_MSE = "unimodal_mse"
LOSS_NAME_KL = "unimodal_kl"
LOSS_NAME_JS = "unimodal_js"
[docs]
def __init__(
    self,
    loss_name: Literal["unimodal_mse", "unimodal_kl", "unimodal_js"],
    original_image_height: int,
    original_image_width: int,
    downsampled_image_height: int,
    downsampled_image_width: int,
    data_module: BaseDataModule | UnlabeledDataModule | None = None,
    prob_threshold: float = 0.0,
    log_weight: float = 0.0,
    uniform_heatmaps: bool = False,
    **kwargs: Any,
) -> None:
    """Set up the unimodal heatmap loss.

    An ideal unimodal heatmap is generated from each predicted keypoint coordinate, and the
    network's predicted heatmap is penalized for differing from it.

    Args:
        loss_name: which divergence to use: ``"unimodal_mse"`` (pixel-wise MSE),
            ``"unimodal_kl"`` (KL divergence) or ``"unimodal_js"`` (Jensen-Shannon).
        original_image_height: full-resolution input image height in pixels, used when
            generating ideal heatmaps.
        original_image_width: full-resolution input image width in pixels.
        downsampled_image_height: height of the heatmap output (after backbone downsampling).
        downsampled_image_width: width of the heatmap output.
        data_module: forwarded to the parent class.
        prob_threshold: heatmaps whose confidence is below this value are excluded.
        log_weight: the loss term is weighted by ``1.0 / (2.0 * exp(log_weight))``.
        uniform_heatmaps: forwarded to ``generate_heatmaps`` when building ideal heatmaps.
        **kwargs: ignored extra keyword arguments.

    Raises:
        ValueError: if ``loss_name`` is not one of the supported names.
    """
    super().__init__(data_module=data_module, log_weight=log_weight)
    # supported names -> divergence callable; MSE is computed inline in compute_loss
    divergence_by_name = {
        self.LOSS_NAME_MSE: None,
        self.LOSS_NAME_KL: kl_div_loss_2d,
        self.LOSS_NAME_JS: js_div_loss_2d,
    }
    if loss_name not in (self.LOSS_NAME_MSE, self.LOSS_NAME_KL, self.LOSS_NAME_JS):
        raise ValueError(f"Invalid loss_name: {loss_name}")
    self.loss_name = loss_name
    # image and heatmap geometry, consumed when building ideal heatmaps in __call__
    self.original_image_height = original_image_height
    self.original_image_width = original_image_width
    self.downsampled_image_height = downsampled_image_height
    self.downsampled_image_width = downsampled_image_width
    self.uniform_heatmaps = uniform_heatmaps
    self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float)
    self.loss = divergence_by_name[self.loss_name]
[docs]
def remove_nans(
    self,
    targets: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
    predictions: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
    confidences: Float[torch.Tensor, "batch num_keypoints"],
) -> tuple[
    Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
    Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
]:
    """Keep only the (image, keypoint) heatmaps whose confidence is not below threshold.

    Args:
        targets: (batch, num_keypoints, heatmap_height, heatmap_width)
        predictions: (batch, num_keypoints, heatmap_height, heatmap_width)
        confidences: (batch, num_keypoints)

    Returns:
        clean targets: kept heatmaps stacked across images and keypoints
        clean predictions: kept heatmaps stacked across images and keypoints
    """
    # drop likely occlusions; written as "not below" so NaN confidences are kept
    visible = torch.logical_not(confidences < self.prob_threshold)
    return targets[visible], predictions[visible]
[docs]
def compute_loss(
    self,
    targets: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
    predictions: Float[torch.Tensor, "num_valid_keypoints heatmap_height heatmap_width"],
) -> torch.Tensor:
    """Measure how far predicted heatmaps are from ideal unimodal targets.

    Args:
        targets: ideal unimodal heatmaps rendered from the predicted keypoint coordinates.
        predictions: heatmaps produced by the network.

    Returns:
        Unreduced loss. For ``unimodal_mse`` this is the element-wise squared error (same
        shape as the inputs). For ``unimodal_kl`` / ``unimodal_js`` it is the kornia
        divergence called with ``reduction="none"`` on a singleton batch; presumably one
        value per heatmap, i.e. (1, num_valid_keypoints).

    Raises:
        NotImplementedError: if ``self.loss_name`` is not one of the supported names.
    """
    if self.loss_name == "unimodal_mse":
        return F.mse_loss(targets, predictions, reduction="none")
    if self.loss_name not in ("unimodal_kl", "unimodal_js"):
        raise NotImplementedError
    # the KL and JS variants differ only in which kornia function __init__ stored in self.loss
    assert self.loss is not None
    # small offset, presumably so the divergence never takes log(0) of an exactly-zero pixel
    eps = 1e-10
    # kornia's 2D divergence losses take (B, N, H, W) input; add a leading batch dim of 1
    batched_predictions = predictions.unsqueeze(0) + eps
    batched_targets = targets.unsqueeze(0) + eps
    return self.loss(batched_predictions, batched_targets, reduction="none")
[docs]
def __call__(
    self,
    keypoints_pred_augmented: Float[torch.Tensor, "batch two_x_num_keypoints"],
    heatmaps_pred: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
    confidences: Float[torch.Tensor, "batch num_keypoints"],
    stage: Literal["train", "val", "test"] | None = None,
    **kwargs: Any,
) -> tuple[Float[torch.Tensor, ""], list[dict]]:
    """Compute unimodal loss.

    The predicted keypoints are rendered into ideal unimodal heatmaps, which act as targets
    for the predicted heatmaps. Keypoints whose confidence is below ``prob_threshold`` are
    excluded before the divergence is computed.

    Args:
        keypoints_pred_augmented: predicted keypoints in the augmented image space,
            flattened as consecutive (x, y) pairs per keypoint.
        heatmaps_pred: predicted heatmaps, also in the augmented space and matching
            ``keypoints_pred_augmented``.
        confidences: per-keypoint confidences used to drop likely-occluded keypoints.
        stage: training stage for logging.
        **kwargs: ignored extra keyword arguments.

    Returns:
        Tuple of scalar loss and list of logging dicts. If no keypoint passes the confidence
        threshold, the loss is a zero that stays attached to the graph of ``heatmaps_pred``
        rather than the NaN produced by averaging an empty tensor.
    """
    # turn keypoint predictions into unimodal heatmaps
    keypoints_pred = keypoints_pred_augmented.reshape(keypoints_pred_augmented.shape[0], -1, 2)
    heatmaps_ideal = generate_heatmaps(  # this process doesn't compute gradients
        keypoints=keypoints_pred,
        height=self.original_image_height,
        width=self.original_image_width,
        output_shape=(self.downsampled_image_height, self.downsampled_image_width),
        uniform_heatmaps=self.uniform_heatmaps,
    )
    # remove invisible keypoints according to confidences
    clean_targets, clean_predictions = self.remove_nans(
        targets=heatmaps_ideal, predictions=heatmaps_pred, confidences=confidences
    )
    if clean_predictions.numel() == 0:
        # every keypoint was filtered out; a torch mean over zero elements is NaN. Return a
        # zero that preserves gradients; torch.where avoids nan * 0.0 issues.
        elementwise_loss = torch.where(
            torch.zeros_like(heatmaps_pred, dtype=torch.bool),
            heatmaps_pred,
            torch.zeros_like(heatmaps_pred),
        ).sum()
    else:
        # compute loss just on the valid heatmaps
        elementwise_loss = self.compute_loss(
            targets=clean_targets, predictions=clean_predictions
        )
    scalar_loss = self.reduce_loss(elementwise_loss, method="mean")
    logs = self.log_loss(loss=scalar_loss, stage=stage)
    return scalar_loss, logs
[docs]
class RegressionMSELoss(Loss):
    """MSE loss between ground truth and predicted coordinates.

    Coordinates whose target is NaN (unlabeled) are excluded before averaging.
    """

    loss_name = "regression"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize RegressionMSELoss.

        Args:
            data_module: data module providing access to datasets; passed to the parent class.
            epsilon: loss values below this threshold are zeroed out.
            log_weight: final weight in front of the loss term in the objective function is
                computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.
        """
        super().__init__(data_module=data_module, epsilon=epsilon, log_weight=log_weight)

    def remove_nans(
        self,
        targets: Float[torch.Tensor, "batch two_x_num_keypoints"],
        predictions: Float[torch.Tensor, "batch two_x_num_keypoints"],
    ) -> tuple[
        Float[torch.Tensor, "num_valid_keypoints"],
        Float[torch.Tensor, "num_valid_keypoints"],
    ]:
        """Mask out NaN coordinate entries from targets and predictions.

        Masking is element-wise: an entry is kept wherever the corresponding target is not
        NaN. Predictions are masked with the same pattern but are not themselves inspected.

        Args:
            targets: ground-truth (x, y) keypoints; NaN entries indicate unlabeled keypoints.
            predictions: predicted (x, y) keypoints.

        Returns:
            Tuple of 1D ``(clean_targets, clean_predictions)`` with NaN-target positions
            removed; both are empty if every target is NaN.
        """
        labeled = ~torch.isnan(targets)
        targets_masked = torch.masked_select(targets, labeled)
        predictions_masked = torch.masked_select(predictions, labeled)
        return targets_masked, predictions_masked

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "batch_x_two_x_num_keypoints"],
        predictions: Float[torch.Tensor, "batch_x_two_x_num_keypoints"],
    ) -> Float[torch.Tensor, "batch_x_two_x_num_keypoints"]:
        """Compute element-wise MSE between target and predicted coordinates.

        Args:
            targets: ground-truth coordinate values.
            predictions: predicted coordinate values.

        Returns:
            Element-wise squared error tensor.
        """
        return F.mse_loss(targets, predictions, reduction="none")

    def __call__(
        self,
        keypoints_targ: Float[torch.Tensor, "batch two_x_num_keypoints"],
        keypoints_pred: Float[torch.Tensor, "batch two_x_num_keypoints"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the regression loss for a batch of predicted keypoints.

        Args:
            keypoints_targ: ground-truth (x, y) keypoints; NaN entries are ignored.
            keypoints_pred: predicted (x, y) keypoints.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts. When the batch contains no
            labeled coordinates, the loss is a zero that stays attached to the graph of
            ``keypoints_pred`` instead of the NaN that averaging an empty tensor yields.
        """
        clean_targets, clean_predictions = self.remove_nans(
            targets=keypoints_targ, predictions=keypoints_pred
        )
        if clean_targets.numel() == 0:
            # every target is NaN; a torch mean over zero elements is NaN. Return a zero that
            # preserves gradients; torch.where avoids nan * 0.0 if predictions contain NaNs.
            elementwise_loss = torch.where(
                torch.zeros_like(keypoints_pred, dtype=torch.bool),
                keypoints_pred,
                torch.zeros_like(keypoints_pred),
            ).sum()
        else:
            elementwise_loss = self.compute_loss(
                targets=clean_targets, predictions=clean_predictions
            )
        scalar_loss = self.reduce_loss(elementwise_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs
[docs]
class RegressionRMSELoss(RegressionMSELoss):
    """Root MSE loss between ground truth and predicted coordinates."""

    loss_name = "rmse"

    def __init__(
        self,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        epsilon: float = 0.0,
        log_weight: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Initialize RegressionRMSELoss.

        Args:
            data_module: data module providing access to datasets; passed to the parent class.
            epsilon: loss values below this threshold are zeroed out.
            log_weight: final weight in front of the loss term in the objective function is
                computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.
        """
        super().__init__(data_module=data_module, epsilon=epsilon, log_weight=log_weight)

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "batch_x_two_x_num_keypoints"],
        predictions: Float[torch.Tensor, "batch_x_two_x_num_keypoints"],
    ) -> Float[torch.Tensor, "batch_x_num_keypoints"]:
        """Compute the per-keypoint root-mean-square error over the (x, y) coordinates.

        For each keypoint this is ``sqrt((dx**2 + dy**2) / 2)``. The squared errors are
        averaged, not summed, over the two coordinates, so the result equals the Euclidean
        pixel distance divided by ``sqrt(2)``.

        The inputs are regrouped into consecutive (x, y) pairs. This assumes each keypoint's x
        and y were kept or dropped together by the element-wise NaN masking; otherwise the
        pairing is misaligned. NOTE(review): confirm labels never have a single NaN coordinate.

        Args:
            targets: ground-truth (x, y) keypoint coordinates, flattened.
            predictions: predicted (x, y) keypoint coordinates, flattened.

        Returns:
            Per-keypoint RMSE over the two coordinates.
        """
        targs = targets.reshape(-1, 2)
        preds = predictions.reshape(-1, 2)
        # mean over the two coordinates of each keypoint
        mse = torch.mean(F.mse_loss(targs, preds, reduction="none"), dim=1)
        # d sqrt(x)/dx is infinite at x == 0, which becomes a NaN gradient (inf * 0) when a
        # prediction matches its target exactly. Route exact zeros around the sqrt: the value is
        # unchanged (sqrt(0) == 0), the gradient there is 0, and NaN losses still propagate.
        is_zero = mse == 0
        safe_mse = torch.where(is_zero, torch.ones_like(mse), mse)
        return torch.where(is_zero, torch.zeros_like(mse), torch.sqrt(safe_mse))
[docs]
class PairwiseProjectionsLoss(Loss):
    """Penalize projections from each pair of cameras into 3D world space."""

    loss_name = "supervised_pairwise_projections"

    def __init__(self, log_weight: float = 0.0, **kwargs: Any) -> None:
        """Initialize PairwiseProjectionsLoss.

        Args:
            log_weight: final weight in front of the loss term in the objective function is
                computed as ``1.0 / (2.0 * exp(log_weight))``.
            **kwargs: ignored extra keyword arguments.
        """
        super().__init__(log_weight=log_weight)

    def remove_nans(
        self,
        loss: Float[torch.Tensor, "batch cam_pairs num_keypoints"],
    ) -> Float[torch.Tensor, "valid_losses"]:
        """Select only valid (non-NaN) loss entries.

        Args:
            loss: per-pair per-keypoint loss tensor; NaN indicates a missing keypoint.

        Returns:
            Flat tensor of valid loss values, or a zero scalar if none are valid.
        """
        mask = ~torch.isnan(loss)
        valid_losses = torch.masked_select(loss, mask)
        if valid_losses.numel() == 0:
            # No valid losses, return zero that preserves gradients
            # Use torch.where to avoid nan*0.0 issues
            dummy_loss = torch.where(mask, loss, torch.zeros_like(loss))
            return dummy_loss.sum()  # This will be 0.0 and preserve gradients
        return valid_losses

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "batch num_keypoints 3"],
        predictions: Float[torch.Tensor, "batch cam_pairs num_keypoints 3"],
    ) -> Float[torch.Tensor, "batch cam_pairs num_keypoints"]:
        """Compute L2 distance between 3D target and per-camera-pair predicted 3D keypoints.

        Args:
            targets: ground-truth 3D keypoints of shape ``(batch, num_keypoints, 3)``.
            predictions: predicted 3D points from pairwise triangulation, shape
                ``(batch, cam_pairs, num_keypoints, 3)``.

        Returns:
            Per-pair per-keypoint L2 distances; NaN where targets or predictions are missing.
        """
        # a keypoint is missing if any of its 3 coordinates is NaN
        nan_targets = torch.isnan(targets).any(dim=-1)  # [batch, num_keypoints]
        nan_predictions = torch.isnan(predictions).any(dim=-1)  # [batch, cam_pairs, num_keypoints]
        # an entry is invalid if either its target or its prediction is missing
        combined_nan_mask = nan_targets.unsqueeze(1) | nan_predictions  # [batch, cam_pairs, num_keypoints]
        # replace NaNs with constant zeros before the norm so no NaN enters the graph;
        # zeros_like creates fresh tensors that carry no gradient
        clean_targets = torch.where(
            nan_targets.unsqueeze(-1),  # [batch, num_keypoints, 1]
            torch.zeros_like(targets),
            targets,
        )
        clean_predictions = torch.where(
            combined_nan_mask.unsqueeze(-1),  # [batch, cam_pairs, num_keypoints, 1]
            torch.zeros_like(predictions),
            predictions,
        )
        loss = torch.linalg.norm(clean_targets.unsqueeze(1) - clean_predictions, ord=2, dim=-1)
        # re-mark invalid entries as NaN so remove_nans can drop them
        return torch.where(combined_nan_mask, loss.new_full((), float("nan")), loss)

    def __call__(
        self,
        keypoints_targ_3d: Float[torch.Tensor, "batch num_keypoints 3"],
        keypoints_pred_3d: Float[torch.Tensor, "batch cam_pairs num_keypoints 3"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the pairwise projections loss.

        Args:
            keypoints_targ_3d: ground-truth 3D keypoints of shape ``(batch, num_keypoints, 3)``.
            keypoints_pred_3d: predicted 3D keypoints per camera pair, shape
                ``(batch, cam_pairs, num_keypoints, 3)``.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.

        Raises:
            ValueError: if either ``keypoints_targ_3d`` or ``keypoints_pred_3d`` is ``None``.
        """
        # check if 3D keypoints are available
        if keypoints_targ_3d is None or keypoints_pred_3d is None:
            raise ValueError(
                f"3D keypoints not available for {stage} stage. "
                "Camera params file is required but not found; "
                "Turn off supervised_pairwise_projections loss to avoid this error."
            )
        elementwise_loss = self.compute_loss(
            targets=keypoints_targ_3d,
            predictions=keypoints_pred_3d,
        )
        clean_loss = self.remove_nans(loss=elementwise_loss)
        scalar_loss = self.reduce_loss(clean_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs
[docs]
class ReprojectionHeatmapLoss(Loss):
    """Penalize error between predicted 2D->3D->2D->heatmap and ground truth heatmap."""

    loss_name = "supervised_reprojection_heatmap_mse"

    def __init__(
        self,
        original_image_height: int,
        original_image_width: int,
        downsampled_image_height: int,
        downsampled_image_width: int,
        log_weight: float = 0.0,
        uniform_heatmaps: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize ReprojectionHeatmapLoss.

        Converts 2D reprojected keypoints (obtained by projecting 3D triangulated predictions
        back into each camera's image plane) into heatmaps and compares them with the ground
        truth heatmaps using pixel-wise MSE.

        Args:
            original_image_height: height of the full-resolution input image in pixels.
            original_image_width: width of the full-resolution input image in pixels.
            downsampled_image_height: height of the heatmap output (after backbone downsampling).
            downsampled_image_width: width of the heatmap output.
            log_weight: final weight in front of the loss term in the objective function is
                computed as ``1.0 / (2.0 * exp(log_weight))``.
            uniform_heatmaps: if ``True``, generate uniform (flat) target heatmaps for NaN
                ground truth keypoints instead of ignoring them in the loss.
            **kwargs: ignored extra keyword arguments.
        """
        super().__init__(log_weight=log_weight)
        self.original_image_height = original_image_height
        self.original_image_width = original_image_width
        self.downsampled_image_height = downsampled_image_height
        self.downsampled_image_width = downsampled_image_width
        self.uniform_heatmaps = uniform_heatmaps

    def remove_nans(
        self,
        loss: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        targets: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "valid_losses"]:
        """Select only valid (non-zero-target) loss entries.

        Args:
            loss: element-wise MSE loss tensor.
            targets: ground-truth heatmaps; all-zero heatmaps indicate unlabeled keypoints.

        Returns:
            Flat tensor of valid loss values, or a zero scalar if none are valid.
        """
        # a keypoint is valid unless its whole target heatmap is exactly zero
        squeezed_targets = targets.reshape(targets.shape[0], targets.shape[1], -1)
        valid_keypoints = ~torch.all(squeezed_targets == 0.0, dim=-1)  # [batch, num_keypoints]
        # broadcast the per-keypoint mask over the heatmap pixels
        valid_mask = valid_keypoints.unsqueeze(-1).unsqueeze(-1)  # [batch, num_keypoints, 1, 1]
        valid_mask = valid_mask.expand_as(loss)  # [batch, num_keypoints, h, w]
        valid_losses = torch.masked_select(loss, valid_mask)
        if valid_losses.numel() == 0:
            # No valid losses, return zero that preserves gradients
            dummy_loss = torch.where(valid_mask, loss, torch.zeros_like(loss))
            return dummy_loss.sum()  # This will be 0.0 and preserve gradients
        return valid_losses

    def compute_loss(
        self,
        targets: Float[torch.Tensor, "... heatmap_height heatmap_width"],
        predictions: Float[torch.Tensor, "... heatmap_height heatmap_width"],
    ) -> Float[torch.Tensor, "... heatmap_height heatmap_width"]:
        """Compute pixel-wise MSE between reprojected and ground-truth heatmaps.

        Args:
            targets: ground-truth heatmaps whose last two dimensions are (height, width);
                ``__call__`` passes (batch, num_keypoints, height, width) tensors.
            predictions: heatmaps generated from reprojected 2D keypoints, same shape.

        Returns:
            Element-wise MSE multiplied by the number of pixels per heatmap (height * width).
        """
        # take the spatial size from the trailing dims; the tensors arriving from __call__ are
        # 4D, so shape[1] / shape[2] would be num_keypoints / height, not height / width
        h, w = targets.shape[-2:]
        # multiply by number of pixels in heatmap to standardize loss range
        return F.mse_loss(targets, predictions, reduction="none") * h * w

    def __call__(
        self,
        heatmaps_targ: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"],
        keypoints_pred_2d_reprojected: Float[torch.Tensor, "batch num_keypoints 2"],
        stage: Literal["train", "val", "test"] | None = None,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ""], list[dict]]:
        """Compute the reprojection heatmap loss.

        Args:
            heatmaps_targ: ground-truth heatmaps.
            keypoints_pred_2d_reprojected: 2D keypoints obtained by projecting triangulated 3D
                predictions back into each camera, shape ``(batch, num_keypoints, 2)``.
            stage: training stage for logging.
            **kwargs: ignored extra keyword arguments.

        Returns:
            Tuple of scalar loss and list of logging dicts.

        Raises:
            ValueError: if ``keypoints_pred_2d_reprojected`` is ``None``.
        """
        # check if reprojected keypoints are available
        if keypoints_pred_2d_reprojected is None:
            raise ValueError(
                f"Reprojected keypoints not available for {stage} stage. "
                "Camera params file is required but not found; "
                "Turn off supervised_reprojection_heatmap loss to avoid this error."
            )
        # create heatmaps from 2d reprojections; keep_gradients so the loss can train through them
        heatmaps_pred = generate_heatmaps(
            keypoints=keypoints_pred_2d_reprojected,
            height=self.original_image_height,
            width=self.original_image_width,
            output_shape=(self.downsampled_image_height, self.downsampled_image_width),
            uniform_heatmaps=self.uniform_heatmaps,
            keep_gradients=True,
        )
        elementwise_loss = self.compute_loss(targets=heatmaps_targ, predictions=heatmaps_pred)
        clean_loss = self.remove_nans(loss=elementwise_loss, targets=heatmaps_targ)
        scalar_loss = self.reduce_loss(clean_loss, method="mean")
        logs = self.log_loss(loss=scalar_loss, stage=stage)
        return scalar_loss, logs