lightning_pose.losses

lightning_pose.losses.factory Module

Factory utilities for building and combining losses from a Hydra config.

Three components work together:

  • get_loss_classes() — returns the registry mapping loss-name strings to classes.

  • get_loss_factories() — reads cfg.losses and cfg.model.losses_to_use, assembles per-loss parameter dicts, and returns a {'supervised': LossFactory, 'unsupervised': LossFactory} dict ready to be passed to a model constructor.

  • LossFactory — a LightningModule that holds instantiated loss objects and computes the total weighted loss in its __call__ method.
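The components above can be sketched as a minimal, framework-free example of how a registry and a factory combine into a total weighted loss. All names here (`get_loss_classes_sketch`, `LossFactorySketch`, the `mse`/`mae` functions, the `pred`/`target` keyword arguments) are hypothetical. Plain Python lists stand in for PyTorch tensors. The real `LossFactory` is a LightningModule, and its exact call signature is not reproduced here.

```python
from typing import Callable, Dict, List, Tuple


def mse(pred: List[float], target: List[float]) -> float:
    """Hypothetical loss: mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred: List[float], target: List[float]) -> float:
    """Hypothetical loss: mean absolute error."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def get_loss_classes_sketch() -> Dict[str, Callable[..., float]]:
    """Hypothetical registry mapping loss-name strings to loss callables."""
    return {"mse": mse, "mae": mae}


class LossFactorySketch:
    """Holds the requested losses and sums them with per-loss weights."""

    def __init__(self, params: Dict[str, Dict[str, float]]):
        registry = get_loss_classes_sketch()
        # one (name, loss function, weight) triple per requested loss
        self.losses: List[Tuple[str, Callable[..., float], float]] = [
            (name, registry[name], p.get("weight", 1.0)) for name, p in params.items()
        ]

    def __call__(self, **kwargs) -> Tuple[float, Dict[str, float]]:
        total = 0.0
        log: Dict[str, float] = {}
        for name, loss_fn, weight in self.losses:
            value = loss_fn(**kwargs)
            log[f"{name}_loss"] = value  # unweighted value, kept for logging
            total += weight * value
        log["total_loss"] = total
        return total, log
```

For example, `LossFactorySketch({"mse": {"weight": 2.0}, "mae": {"weight": 0.5}})(pred=[1.0, 2.0], target=[1.0, 4.0])` returns a total of 4.5 (2.0 * 2.0 + 0.5 * 1.0). Logging each unweighted term separately makes it easy to see which loss dominates the objective.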

Adding a new loss:

  1. Define the class in losses/losses.py, inheriting from Loss. Set a loss_name: str class attribute (single name) or multiple LOSS_NAME_*: str class attributes when one class serves several config strings (e.g. PCALoss).

  2. Import the class at the top of this file and add one entry per name to the dict returned by get_loss_classes().

  3. If the loss requires parameters from cfg.losses (weight, epsilon, etc.), add a corresponding block in get_loss_factories() that reads those values and adds them to the params dict for that loss name.
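The three steps can be illustrated with the following sketch, under stated assumptions. `Loss` is a simplified stand-in for the package's base class. `MyNewLoss`, `MultiNameLoss`, their name strings, `get_loss_classes_sketch`, `build_params_sketch`, and `instantiate_losses` are all hypothetical. A plain nested dict stands in for the Hydra config, and the `weight`/`epsilon` keys are assumed.

```python
from typing import Dict, List


class Loss:
    """Simplified stand-in for the package's Loss base class (assumption)."""

    def __init__(self, weight: float = 1.0, epsilon: float = 0.0):
        self.weight = weight
        self.epsilon = epsilon


class MyNewLoss(Loss):
    # step 1: a single config string for this class (hypothetical name)
    loss_name: str = "my_new_loss"


class MultiNameLoss(Loss):
    # step 1 (variant): one class serving several config strings, as PCALoss does
    LOSS_NAME_A: str = "multi_a"
    LOSS_NAME_B: str = "multi_b"


def get_loss_classes_sketch() -> Dict[str, type]:
    # step 2: one registry entry per config string
    return {
        MyNewLoss.loss_name: MyNewLoss,
        MultiNameLoss.LOSS_NAME_A: MultiNameLoss,
        MultiNameLoss.LOSS_NAME_B: MultiNameLoss,
    }


def build_params_sketch(cfg: dict) -> Dict[str, dict]:
    # step 3: read per-loss hyperparameters for every loss in losses_to_use
    params: Dict[str, dict] = {}
    for name in cfg["model"]["losses_to_use"]:
        loss_cfg = cfg["losses"].get(name, {})
        params[name] = {
            "weight": loss_cfg.get("weight", 1.0),
            "epsilon": loss_cfg.get("epsilon", 0.0),
        }
    return params


def instantiate_losses(cfg: dict) -> List[Loss]:
    registry = get_loss_classes_sketch()
    return [registry[name](**p) for name, p in build_params_sketch(cfg).items()]
```

Keying the registry by config string rather than by class is what allows one class to be registered under several names.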

lightning_pose.losses.losses Module

Supervised and unsupervised losses implemented in PyTorch.

The Lightning Pose package defines each loss as its own class. In addition to computing the loss, an initialized loss object stores the hyperparameters related to that loss (its weight in the final objective function, the epsilon-insensitivity parameter, etc.).

A separate LossFactory class (defined in lightning_pose.losses.factory) collects all losses for a given model and orchestrates their execution, logging, etc.

The general flow of each loss class is as follows:

  • input: predicted and ground truth data

  • step 0: remove ground truth samples containing NaNs, if desired

  • step 1: compute the loss for each batch element/keypoint/etc.

  • step 2: epsilon-insensitivity: set the loss to zero for any batch element with loss < epsilon

  • step 3: reduce the loss (usually with a mean)

  • step 4: log values to a dict

  • step 5: return the loss
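The steps of this flow can be sketched as a single function. It is a minimal illustration, assuming a squared-Euclidean per-keypoint loss and one [x, y] pair per keypoint. Plain lists stand in for tensors, and the function name `flow_sketch` and its log keys are hypothetical.

```python
import math
from typing import Dict, List, Tuple


def flow_sketch(
    preds: List[List[float]],
    targets: List[List[float]],
    epsilon: float = 0.0,
    remove_nans: bool = True,
) -> Tuple[float, Dict[str, float]]:
    """Steps 0-5 for a per-keypoint squared-error loss (illustrative only)."""
    # step 0: drop ground truth entries that contain NaNs
    pairs = [
        (p, t)
        for p, t in zip(preds, targets)
        if not (remove_nans and any(math.isnan(v) for v in t))
    ]
    # step 1: per-element loss (squared Euclidean distance)
    elementwise = [sum((pi - ti) ** 2 for pi, ti in zip(p, t)) for p, t in pairs]
    # step 2: epsilon-insensitivity, zero out any loss below epsilon
    elementwise = [v if v >= epsilon else 0.0 for v in elementwise]
    # step 3: reduce with a mean (guarding against an all-NaN batch)
    loss = sum(elementwise) / len(elementwise) if elementwise else 0.0
    # step 4: log values to a dict
    log = {"loss": loss, "n_valid": float(len(elementwise))}
    # step 5: return the loss
    return loss, log
```

Zeroing small errors in step 2 means the loss stops pushing on elements that are already within the tolerated error.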

Classes

Loss

Parent class for all losses.

HeatmapLoss

Parent class for different heatmap losses (MSE, Wasserstein, etc.).

HeatmapMSELoss

MSE loss between heatmaps.

HeatmapKLLoss

Kullback-Leibler loss between heatmaps.

HeatmapJSLoss

Jensen-Shannon loss between heatmaps.
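The three heatmap comparisons (MSE, KL, JS) can be sketched as follows. Heatmaps are flattened to 1D lists and normalized to sum to 1 so that they can be treated as discrete distributions. The `EPS` floor, the KL argument order, and all function names are assumptions, not the package's API.

```python
import math
from typing import List

EPS = 1e-10  # assumed numerical floor to avoid log(0) and division by zero


def normalize(h: List[float]) -> List[float]:
    """Scale a flattened heatmap so it sums to 1 (a discrete distribution)."""
    total = sum(h)
    return [v / total for v in h]


def heatmap_mse(pred: List[float], target: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def heatmap_kl(p: List[float], q: List[float]) -> float:
    """KL(p || q); terms with p == 0 contribute 0."""
    return sum(pi * math.log((pi + EPS) / (qi + EPS)) for pi, qi in zip(p, q))


def heatmap_js(p: List[float], q: List[float]) -> float:
    """Jensen-Shannon divergence: average KL to the midpoint distribution."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * heatmap_kl(p, m) + 0.5 * heatmap_kl(q, m)
```

KL is asymmetric, so the order of its arguments matters. JS is symmetric and stays finite even when the two heatmaps have disjoint support. For two non-overlapping heatmaps it approaches log(2).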

PCALoss

Penalize predictions that fall outside a low-dimensional subspace.
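The subspace penalty can be sketched in plain Python, under stated assumptions. Each prediction is a flattened coordinate vector, and the PCA mean and orthonormal principal axes have already been fitted (for example, on labeled data). The error is the distance between the vector and its projection onto the subspace. The function names and the exact error definition are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def pca_reconstruction_error(x: Vector, mean: Vector, components: List[Vector]) -> float:
    """Distance between x and its projection onto the PCA subspace.

    components: orthonormal principal axes (one per row), assumed precomputed.
    """
    centered = [xi - mi for xi, mi in zip(x, mean)]
    # coordinates of x in the low-dimensional subspace
    coeffs = [sum(ci * vi for ci, vi in zip(c, centered)) for c in components]
    # map those coordinates back to the original space
    recon = [
        mean[d] + sum(a * c[d] for a, c in zip(coeffs, components))
        for d in range(len(x))
    ]
    return math.sqrt(sum((xi - ri) ** 2 for xi, ri in zip(x, recon)))


def pca_loss(batch: List[Vector], mean: Vector, components: List[Vector],
             epsilon: float = 0.0) -> float:
    errors = [pca_reconstruction_error(x, mean, components) for x in batch]
    # epsilon-insensitivity: ignore small departures from the subspace
    errors = [e if e >= epsilon else 0.0 for e in errors]
    return sum(errors) / len(errors)
```

Predictions that lie inside the subspace reconstruct exactly and cost nothing. Only the component orthogonal to the subspace is penalized.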

TemporalLoss

Penalize temporal differences for each target.
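A minimal sketch of the temporal penalty, assuming the predictions for consecutive frames arrive as a list of frames, each holding one [x, y] pair per keypoint. The loss is the mean Euclidean displacement between neighboring frames, with the epsilon-insensitivity described above. The function name is hypothetical.

```python
import math
from typing import List

Frame = List[List[float]]  # one [x, y] pair per keypoint


def temporal_loss(frames: List[Frame], epsilon: float = 0.0) -> float:
    """Mean per-keypoint displacement between consecutive frames (sketch)."""
    diffs: List[float] = []
    for prev, curr in zip(frames[:-1], frames[1:]):
        for (x0, y0), (x1, y1) in zip(prev, curr):
            d = math.hypot(x1 - x0, y1 - y0)
            # epsilon-insensitivity: small, plausible movements cost nothing
            diffs.append(d if d >= epsilon else 0.0)
    return sum(diffs) / len(diffs) if diffs else 0.0
```

Setting epsilon to roughly the largest displacement expected between frames lets genuine motion go unpenalized while large jumps, which often indicate a tracking error, are still penalized.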

TemporalHeatmapLoss

Penalize temporal differences for each heatmap.

UnimodalLoss

Encourage heatmaps to be unimodal using various measures.
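One possible measure of unimodality is sketched below. It assumes the heatmap is compared against an ideal Gaussian centered at its soft-argmax (probability-weighted mean) location, using MSE. This is an assumption: the docstring only says "various measures", and the package may use KL or JS instead of MSE or a different target shape. All names and the default `sigma` are hypothetical.

```python
import math
from typing import List, Tuple

Heatmap = List[List[float]]  # rows x cols


def soft_argmax(h: Heatmap) -> Tuple[float, float]:
    """Probability-weighted mean (x, y) location of a heatmap."""
    total = sum(sum(row) for row in h)
    y = sum(i * v for i, row in enumerate(h) for v in row) / total
    x = sum(j * v for row in h for j, v in enumerate(row)) / total
    return x, y


def gaussian_heatmap(x: float, y: float, rows: int, cols: int,
                     sigma: float = 1.0) -> Heatmap:
    """Normalized 2D Gaussian centered at (x, y)."""
    g = [[math.exp(-((j - x) ** 2 + (i - y) ** 2) / (2 * sigma ** 2))
          for j in range(cols)] for i in range(rows)]
    s = sum(sum(r) for r in g)
    return [[v / s for v in r] for r in g]


def unimodal_mse(h: Heatmap, sigma: float = 1.0) -> float:
    total = sum(sum(row) for row in h)
    p = [[v / total for v in row] for row in h]
    x, y = soft_argmax(p)
    ideal = gaussian_heatmap(x, y, len(p), len(p[0]), sigma)
    n = len(p) * len(p[0])
    return sum((a - b) ** 2 for ra, rb in zip(p, ideal) for a, b in zip(ra, rb)) / n
```

A heatmap with two separated peaks has its soft-argmax between them, far from either peak, so it differs strongly from the ideal Gaussian. A single Gaussian-shaped peak scores near zero.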

RegressionMSELoss

MSE loss between ground truth and predicted coordinates.

RegressionRMSELoss

Root MSE loss between ground truth and predicted coordinates.
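The two regression losses differ mainly in where the square root is taken. In the sketch below, the RMSE variant takes the root per keypoint (giving the Euclidean pixel error) before averaging. This is an assumption: taking the root of the overall mean is an equally common definition. Keypoints are given as one [x, y] pair each, and both function names are hypothetical.

```python
import math
from typing import List

Keypoints = List[List[float]]  # one [x, y] pair per keypoint


def regression_mse(pred: Keypoints, target: Keypoints) -> float:
    """Mean over keypoints of the squared Euclidean error."""
    sq = [(px - tx) ** 2 + (py - ty) ** 2
          for (px, py), (tx, ty) in zip(pred, target)]
    return sum(sq) / len(sq)


def regression_rmse(pred: Keypoints, target: Keypoints) -> float:
    """Square root per keypoint (Euclidean pixel error), then mean over keypoints."""
    dist = [math.hypot(px - tx, py - ty)
            for (px, py), (tx, ty) in zip(pred, target)]
    return sum(dist) / len(dist)
```

Because the root is applied per keypoint, the RMSE variant is expressed in pixels and grows linearly with the error, whereas MSE weights large errors quadratically.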

PairwiseProjectionsLoss

Penalize projections from each pair of cameras into 3D world space.

ReprojectionHeatmapLoss

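Penalize the error between the ground truth heatmap and a heatmap rendered from the predicted 2D keypoints after lifting them to 3D and reprojecting them to 2D (2D->3D->2D->heatmap).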
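The 3D->2D->heatmap part of this pipeline can be sketched as follows, under stated assumptions. The 2D->3D step (triangulation from multiple views) is omitted, and the 3D point is taken as given. The camera is modeled as a 3x4 pinhole projection matrix. The rendered heatmap is an unnormalized Gaussian compared with MSE. All function names and the default `sigma` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Matrix34 = Sequence[Sequence[float]]  # 3x4 camera projection matrix


def project(P: Matrix34, X: Sequence[float]) -> Tuple[float, float]:
    """Pinhole projection of a 3D point using homogeneous coordinates."""
    Xh = list(X) + [1.0]
    u, v, w = (sum(P[r][c] * Xh[c] for c in range(4)) for r in range(3))
    return u / w, v / w


def render_heatmap(x: float, y: float, rows: int, cols: int,
                   sigma: float = 1.0) -> List[List[float]]:
    """Unnormalized Gaussian heatmap centered at (x, y)."""
    return [[math.exp(-((j - x) ** 2 + (i - y) ** 2) / (2 * sigma ** 2))
             for j in range(cols)] for i in range(rows)]


def reprojection_heatmap_mse(P: Matrix34, X: Sequence[float],
                             gt_heatmap: List[List[float]],
                             sigma: float = 1.0) -> float:
    rows, cols = len(gt_heatmap), len(gt_heatmap[0])
    x, y = project(P, X)  # 3D -> 2D
    pred = render_heatmap(x, y, rows, cols, sigma)  # 2D -> heatmap
    n = rows * cols
    return sum((a - b) ** 2
               for ra, rb in zip(pred, gt_heatmap) for a, b in zip(ra, rb)) / n
```

With `P = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]` (an illustrative camera, not a calibrated one), the point `[4.0, 6.0, 2.0]` projects to `(2.0, 3.0)`. Comparing in heatmap space rather than in coordinate space lets the reprojected keypoints be supervised with the same kind of target as the network's own heatmap outputs.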