lightning_pose.losses
lightning_pose.losses.factory Module
Factory utilities for building and combining losses from a Hydra config.
Three components work together:
- get_loss_classes(): returns the registry mapping loss-name strings to classes.
- get_loss_factories(): reads cfg.losses and cfg.model.losses_to_use, assembles per-loss parameter dicts, and returns a {'supervised': LossFactory, 'unsupervised': LossFactory} dict ready to be passed to a model constructor.
- LossFactory: a LightningModule that holds instantiated loss objects and computes the total weighted loss in its __call__ method.
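The pattern described above (a name-to-class registry plus a factory that instantiates each requested loss and sums the weighted results) can be sketched without any framework. Everything here is hypothetical and for illustration only: `ToyAbsLoss`, `ToySquaredLoss`, `toy_loss_classes` and `ToyLossFactory` are not part of lightning_pose, and plain Python lists stand in for tensors. The real `LossFactory` is a `LightningModule` and its exact call signature is not shown here.

```python
from typing import Dict, List, Tuple


class ToyAbsLoss:
    """Hypothetical loss object that stores its own hyperparameter (weight)."""

    loss_name = "toy_abs"

    def __init__(self, weight: float = 1.0) -> None:
        self.weight = weight

    def __call__(self, preds: List[float], targets: List[float]) -> float:
        # unweighted mean absolute error
        return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


class ToySquaredLoss(ToyAbsLoss):
    """Hypothetical second loss; reuses the weight-storing constructor."""

    loss_name = "toy_sq"

    def __call__(self, preds: List[float], targets: List[float]) -> float:
        # unweighted mean squared error
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def toy_loss_classes() -> Dict[str, type]:
    # registry mapping config strings to classes
    return {cls.loss_name: cls for cls in (ToyAbsLoss, ToySquaredLoss)}


class ToyLossFactory:
    """Instantiate the requested losses and compute their weighted sum."""

    def __init__(self, params: Dict[str, dict]) -> None:
        registry = toy_loss_classes()
        self.losses = {name: registry[name](**kwargs) for name, kwargs in params.items()}

    def __call__(
        self, preds: List[float], targets: List[float]
    ) -> Tuple[float, Dict[str, float]]:
        total = 0.0
        log: Dict[str, float] = {}
        for name, loss in self.losses.items():
            value = loss(preds, targets)
            log[name] = value  # log the unweighted value per loss
            total += loss.weight * value  # weighting happens in the factory
        return total, log


factory = ToyLossFactory({"toy_abs": {"weight": 1.0}, "toy_sq": {"weight": 0.5}})
total, log = factory([1.0, 2.0], [0.0, 4.0])  # total == 2.75
```

Keeping the weight on each loss object, while letting the factory do the summation, mirrors the split described above: loss objects own their hyperparameters, and the factory orchestrates them.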
Adding a new loss:
1. Define the class in losses/losses.py, inheriting from Loss. Set a loss_name: str class attribute (single name), or multiple LOSS_NAME_*: str class attributes when one class serves several config strings (e.g. PCALoss).
2. Import the class at the top of this file and add one entry per name to the dict returned by get_loss_classes().
3. If the loss requires parameters from cfg.losses (weight, epsilon, etc.), add a corresponding block in get_loss_factories() that reads those values and adds them to the params dict for that loss name.
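The three steps can be illustrated with a self-contained sketch. This is an assumption-laden illustration, not lightning_pose code: `LossStub`, `MyNewLoss`, `MyMultiNameLoss`, `sketch_loss_classes` and `sketch_params` are hypothetical names, plain dicts stand in for the Hydra config (`cfg.losses`, `cfg.model.losses_to_use`), and the real `get_loss_factories()` may read different keys.

```python
from typing import Dict, List


class LossStub:
    """Hypothetical stand-in for the Loss base class; stores hyperparameters."""

    def __init__(self, weight: float = 1.0, epsilon: float = 0.0) -> None:
        self.weight = weight
        self.epsilon = epsilon


class MyNewLoss(LossStub):
    # step 1, single-name case: one config string maps to this class
    loss_name: str = "my_new_loss"


class MyMultiNameLoss(LossStub):
    # step 1, multi-name case: several config strings map to the same class
    LOSS_NAME_A: str = "my_multi_a"
    LOSS_NAME_B: str = "my_multi_b"


def sketch_loss_classes() -> Dict[str, type]:
    # step 2: one registry entry per name
    return {
        MyNewLoss.loss_name: MyNewLoss,
        MyMultiNameLoss.LOSS_NAME_A: MyMultiNameLoss,
        MyMultiNameLoss.LOSS_NAME_B: MyMultiNameLoss,
    }


def sketch_params(cfg_losses: Dict[str, dict], losses_to_use: List[str]) -> Dict[str, dict]:
    # step 3: read per-loss values from a config-like dict into a params dict
    params: Dict[str, dict] = {}
    for name in losses_to_use:
        block = cfg_losses.get(name, {})
        if name == MyNewLoss.loss_name:
            # the new loss needs both weight and epsilon from its config block
            params[name] = {"weight": block["weight"], "epsilon": block["epsilon"]}
        else:
            # other losses only read a weight, defaulting to 1.0
            params[name] = {"weight": block.get("weight", 1.0)}
    return params


cfg_losses = {"my_new_loss": {"weight": 2.0, "epsilon": 0.1}, "my_multi_a": {"weight": 0.5}}
params = sketch_params(cfg_losses, ["my_new_loss", "my_multi_a", "my_multi_b"])
classes = sketch_loss_classes()
losses = {name: classes[name](**kw) for name, kw in params.items()}
```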
lightning_pose.losses.losses Module
Supervised and unsupervised losses implemented in pytorch.
The Lightning Pose package defines each loss as its own class; an initialized loss object, in addition to computing the loss, stores hyperparameters related to the loss (weight in the final objective function, epsilon-insensitivity parameter, etc.).
A separate LossFactory class (defined in lightning_pose.losses.factory) collects all losses for a given model and orchestrates their execution, logging, etc.
The general flow of each loss class is as follows:
- input: predicted and ground truth data
- step 0: remove ground truth samples containing NaNs, if desired
- step 1: compute loss for each batch element/keypoint/etc.
- step 2: epsilon-insensitivity: set loss to zero for any batch element with loss < epsilon
- step 3: reduce loss (usually mean)
- step 4: log values to a dict
- step 5: return loss
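The six steps can be sketched in plain Python. This is a minimal illustration under stated assumptions: inputs are lists of 2D keypoint coordinates rather than PyTorch tensors, the per-element loss is squared Euclidean distance, and the function name `keypoint_sq_error_loss` and the logged keys are hypothetical, not taken from lightning_pose.

```python
import math
from typing import Dict, List, Tuple

Point = Tuple[float, float]


def keypoint_sq_error_loss(
    preds: List[Point], targets: List[Point], epsilon: float = 0.0
) -> Tuple[float, Dict[str, float]]:
    # input: predicted and ground-truth (x, y) coordinates, one pair per keypoint
    # step 0: drop ground-truth samples that contain NaNs
    pairs = [(p, t) for p, t in zip(preds, targets) if not any(math.isnan(v) for v in t)]
    # step 1: per-element loss (squared Euclidean distance)
    per_elem = [(p[0] - t[0]) ** 2 + (p[1] - t[1]) ** 2 for p, t in pairs]
    # step 2: epsilon-insensitivity, zero out elements whose loss is below epsilon
    per_elem = [0.0 if v < epsilon else v for v in per_elem]
    # step 3: reduce with the mean (0.0 if every target was NaN)
    loss = sum(per_elem) / len(per_elem) if per_elem else 0.0
    # step 4: log values to a dict
    log = {"loss": loss, "n_valid": float(len(per_elem))}
    # step 5: return the loss (plus the log)
    return loss, log


nan = float("nan")
loss, log = keypoint_sq_error_loss(
    preds=[(0.0, 0.0), (1.0, 1.0), (3.0, 4.0)],
    targets=[(0.0, 0.5), (nan, nan), (0.0, 0.0)],
    epsilon=1.0,
)
# the NaN target is dropped, the 0.25 error falls below epsilon, leaving (0 + 25) / 2
```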
Classes
- Parent class for all losses.
- Parent class for different heatmap losses (MSE, Wasserstein, etc.).
- MSE loss between heatmaps.
- Kullback-Leibler loss between heatmaps.
- Jensen-Shannon loss between heatmaps.
- Penalize predictions that fall outside a low-dimensional subspace.
- Penalize temporal differences for each target.
- Penalize temporal differences for each heatmap.
- Encourage heatmaps to be unimodal using various measures.
- MSE loss between ground truth and predicted coordinates.
- Root MSE loss between ground truth and predicted coordinates.
- Penalize projections from each pair of cameras into 3D world space.
- Penalize error between predicted 2D->3D->2D->heatmap and ground truth heatmap.
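The MSE, Kullback-Leibler, and Jensen-Shannon heatmap losses listed above can be illustrated with the standard definitions, KL(p || q) = sum_i p_i log(p_i / q_i) and JS(p, q) = (KL(p || m) + KL(q || m)) / 2 with m = (p + q) / 2. This sketch assumes each heatmap is flattened into a list and normalized to sum to one; the function names and the `eps` stabilizer are hypothetical, and which argument is the prediction versus the ground truth in the lightning_pose classes is not specified here.

```python
import math
from typing import List


def normalize(h: List[float]) -> List[float]:
    # turn a flattened, nonnegative heatmap into a probability distribution
    total = sum(h)
    return [v / total for v in h]


def heatmap_mse(p: List[float], q: List[float]) -> float:
    # mean squared difference over all pixels
    return sum((pi - qi) ** 2 for pi, qi in zip(p, q)) / len(p)


def kl_div(p: List[float], q: List[float], eps: float = 1e-10) -> float:
    # KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute 0,
    # and eps keeps the division finite where q_i == 0
    return sum(pi * math.log(pi / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def js_div(p: List[float], q: List[float]) -> float:
    # symmetric and bounded by log(2): average KL of each input to their midpoint
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)


target = normalize([0.0, 1.0, 2.0, 1.0])
pred = normalize([1.0, 2.0, 1.0, 0.0])
```

Unlike KL, the JS divergence stays finite when one heatmap has zero mass where the other does not, which is one reason to offer both.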