lightning_pose.losses

lightning_pose.losses.factory Module

High-level loss class that orchestrates the individual losses.

Classes

LossFactory

Factory object that contains an object for each specified loss.
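A minimal sketch of what such a factory might look like, assuming each loss is a callable that returns a weighted scalar and a dict of logged values. `SimpleLossFactory` and the `name/key` logging scheme are hypothetical illustrations, not the actual LossFactory API.

```python
from typing import Callable, Dict, Tuple

# Hypothetical loss signature: (predictions, targets) -> (weighted loss, log dict)
LossFn = Callable[[list, list], Tuple[float, Dict[str, float]]]


class SimpleLossFactory:
    """Hypothetical factory that owns one callable per requested loss."""

    def __init__(self, losses: Dict[str, LossFn]):
        self.losses = losses

    def __call__(self, predictions, targets):
        total = 0.0
        log: Dict[str, float] = {}
        for name, loss_fn in self.losses.items():
            value, loss_log = loss_fn(predictions, targets)
            total += value
            # prefix each logged quantity with the loss name to avoid collisions
            for key, val in loss_log.items():
                log[f"{name}/{key}"] = val
        log["total_loss"] = total
        return total, log
```

Keeping each loss as its own object lets the factory stay agnostic about what a loss computes; it only sums the weighted results and merges the logs.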

lightning_pose.losses.losses Module

Supervised and unsupervised losses implemented in PyTorch.

The Lightning Pose package defines each loss as its own class. In addition to computing the loss, an initialized loss object stores the hyperparameters related to that loss (its weight in the final objective function, the epsilon-insensitivity parameter, etc.).

A separate LossFactory class (defined in lightning_pose.losses.factory) collects all losses for a given model and orchestrates their execution, logging, etc.

The general flow of each loss class is as follows:

- input: predicted and ground truth data
- step 0: remove ground truth samples containing NaNs, if desired
- step 1: compute the loss for each batch element/keypoint/etc.
- step 2: epsilon-insensitivity: set the loss to zero for any batch element with loss < epsilon
- step 3: reduce the loss (usually with a mean)
- step 4: log values to a dict
- step 5: return the weighted loss
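The step 0 to step 5 flow can be sketched as a single loss class. This is a hypothetical NumPy stand-in (`SketchMSELoss` is not part of the package), assuming predictions and targets are `(batch, num_keypoints * 2)` arrays and that the per-element loss is an MSE.

```python
import numpy as np


class SketchMSELoss:
    """Hypothetical loss following the step 0-5 flow; not the package's class."""

    def __init__(self, weight: float = 1.0, epsilon: float = 0.0):
        self.weight = weight    # weight in the final objective function
        self.epsilon = epsilon  # epsilon-insensitivity threshold

    def __call__(self, predictions: np.ndarray, targets: np.ndarray):
        # step 0: drop batch elements whose ground truth contains NaNs
        keep = ~np.isnan(targets).any(axis=1)
        predictions, targets = predictions[keep], targets[keep]
        # step 1: loss for each batch element (mean over coordinates)
        elementwise = np.mean((predictions - targets) ** 2, axis=1)
        # step 2: epsilon-insensitivity, zero out losses below epsilon
        elementwise = np.where(elementwise < self.epsilon, 0.0, elementwise)
        # step 3: reduce with a mean (0.0 if every element was removed)
        loss = float(np.mean(elementwise)) if elementwise.size else 0.0
        # step 4: log values to a dict
        logs = {"loss": loss, "weight": self.weight}
        # step 5: return the weighted loss
        return self.weight * loss, logs
```

With `weight=2.0` and `epsilon=0.5`, a batch whose surviving per-element losses are `1.0` and `0.005` reduces to `0.5` (the small one is zeroed) and returns `1.0`.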

Functions

get_loss_classes()

Get a dict with all the loss classes.
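One plausible way to build such a dict is to map a name attribute on each subclass to the class itself, so that a config string can be turned into a loss class. The `loss_name` attribute and all class names below are hypothetical placeholders, not the package's definitions.

```python
from typing import Dict, Type


class Loss:
    """Placeholder parent class (illustration only)."""
    loss_name = "base"


class MSELossSketch(Loss):
    loss_name = "mse_sketch"


class TemporalLossSketch(Loss):
    loss_name = "temporal_sketch"


def get_loss_classes_sketch() -> Dict[str, Type[Loss]]:
    # __subclasses__() lists only direct subclasses of Loss
    return {cls.loss_name: cls for cls in Loss.__subclasses__()}
```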

Classes

Loss

Parent class for all losses.

HeatmapLoss

Parent class for different heatmap losses (MSE, Wasserstein, etc.).

HeatmapMSELoss

MSE loss between heatmaps.
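A NumPy sketch of a per-keypoint heatmap MSE, assuming heatmaps of shape `(batch, keypoints, H, W)`. The function name is hypothetical, and the package's version may scale or mask the result differently.

```python
import numpy as np


def heatmap_mse(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Per-keypoint MSE between heatmaps of shape (batch, keypoints, H, W)."""
    # average the squared error over the two spatial dimensions
    return np.mean((pred - target) ** 2, axis=(2, 3))
```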

HeatmapKLLoss

Kullback-Leibler loss between heatmaps.
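To use a KL divergence, each heatmap must first be treated as a probability distribution over pixels. The sketch below normalizes each heatmap to sum to 1 and computes KL(target || pred), one common convention; the direction, the `eps` smoothing, and the function name are assumptions, and heatmaps are assumed to have positive total mass.

```python
import numpy as np


def heatmap_kl(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Per-keypoint KL(target || pred) for heatmaps of shape (B, K, H, W)."""
    # normalize each heatmap so it sums to 1 over the spatial dimensions
    p = target / target.sum(axis=(2, 3), keepdims=True)
    q = pred / pred.sum(axis=(2, 3), keepdims=True)
    # eps avoids log(0); pixels where p == 0 contribute 0
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=(2, 3))
```

For a one-hot 2x2 target against a uniform prediction this gives approximately log 4.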

HeatmapJSLoss

Jensen-Shannon loss between heatmaps.
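The name HeatmapJSLoss points to the Jensen-Shannon (JS) divergence, a symmetrized and bounded relative of KL: each distribution is compared against their average. A NumPy sketch under the same assumptions as a KL heatmap loss (normalized heatmaps of shape `(B, K, H, W)`, hypothetical function names, small `eps` smoothing):

```python
import numpy as np


def _kl(p: np.ndarray, q: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # KL(p || q) summed over the spatial dimensions
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=(2, 3))


def heatmap_js(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Per-keypoint Jensen-Shannon divergence for heatmaps of shape (B, K, H, W)."""
    p = target / target.sum(axis=(2, 3), keepdims=True)
    q = pred / pred.sum(axis=(2, 3), keepdims=True)
    m = 0.5 * (p + q)  # mixture distribution
    # symmetric in p and q, and bounded above by log(2) with natural logs
    return 0.5 * _kl(p, m) + 0.5 * _kl(q, m)
```

Unlike KL, the JS divergence stays finite when the two heatmaps have disjoint support, where it reaches its maximum of log 2.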

PCALoss

Penalize predictions that fall outside a low-dimensional subspace.
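The idea can be sketched as: fit PCA to trusted pose vectors, project each prediction onto the top components, and penalize the distance between the prediction and its reconstruction. This NumPy version computes one distance per row; the function names are hypothetical, and the package may instead compute the error per keypoint and apply an epsilon threshold.

```python
import numpy as np


def fit_pca(data: np.ndarray, n_components: int):
    """Fit PCA to rows of `data` (samples x features); return mean and basis."""
    mean = data.mean(axis=0)
    # right singular vectors of the centered data are the principal axes
    _, _, vt = np.linalg.svd(data - mean, full_matrices=False)
    return mean, vt[:n_components]  # basis shape: (n_components, features)


def pca_reprojection_error(preds: np.ndarray, mean: np.ndarray, basis: np.ndarray) -> np.ndarray:
    """Distance between each prediction and its projection onto the PCA subspace."""
    centered = preds - mean
    reconstructed = centered @ basis.T @ basis + mean
    return np.linalg.norm(preds - reconstructed, axis=1)
```

Predictions that lie in the low-dimensional subspace have zero error; components orthogonal to it are what the loss penalizes.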

TemporalLoss

Penalize temporal differences for each target.
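A sketch of a temporal penalty, assuming predictions for consecutive frames arrive as a `(frames, keypoints, 2)` array and that the difference is measured as the Euclidean distance a keypoint moves between frames. The function name and the exact epsilon handling are illustrative assumptions.

```python
import numpy as np


def temporal_loss(preds: np.ndarray, epsilon: float = 0.0) -> np.ndarray:
    """Per-keypoint jump size between consecutive frames.

    preds: (frames, keypoints, 2) array of predicted (x, y) coordinates.
    Returns a (frames - 1, keypoints) array.
    """
    # Euclidean distance each keypoint moves from frame t to frame t + 1
    jumps = np.linalg.norm(np.diff(preds, axis=0), axis=-1)
    # epsilon-insensitivity: small, plausible movements are not penalized
    return np.where(jumps < epsilon, 0.0, jumps)
```

For a keypoint at (0, 0), (3, 4), (3, 4.5) with `epsilon=1.0`, the jumps are 5.0 and 0.5, and the second is zeroed.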

TemporalHeatmapLoss

Penalize temporal differences for each heatmap.
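The heatmap analogue compares whole heatmaps between consecutive frames instead of coordinates. This sketch assumes a `(frames, keypoints, H, W)` array and uses the mean squared pixel difference as the distance; the real class may use a different measure (e.g. a divergence), and the function name is hypothetical.

```python
import numpy as np


def temporal_heatmap_loss(heatmaps: np.ndarray) -> np.ndarray:
    """Per-keypoint change between heatmaps of consecutive frames.

    heatmaps: (frames, keypoints, H, W). Returns a (frames - 1, keypoints) array.
    """
    # mean squared difference over pixels between frame t + 1 and frame t
    return np.mean(np.diff(heatmaps, axis=0) ** 2, axis=(2, 3))
```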

UnimodalLoss

Encourage heatmaps to be unimodal using various measures.
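One way to measure unimodality is to compare a predicted heatmap against an ideal single-peaked heatmap centered on its mode. The sketch below assumes an isotropic Gaussian placed at the argmax and an MSE comparison; the Gaussian width, the argmax centering, and both function names are illustrative assumptions rather than the package's documented measures.

```python
import numpy as np


def gaussian_heatmap(h: int, w: int, center, sigma: float = 1.0) -> np.ndarray:
    """Isotropic 2D Gaussian centered at (row, col), normalized to sum to 1."""
    rows, cols = np.mgrid[0:h, 0:w]
    g = np.exp(-((rows - center[0]) ** 2 + (cols - center[1]) ** 2) / (2 * sigma**2))
    return g / g.sum()


def unimodal_mse(heatmap: np.ndarray, sigma: float = 1.0) -> float:
    """MSE between a normalized (H, W) heatmap and a Gaussian at its peak."""
    h, w = heatmap.shape
    pred = heatmap / heatmap.sum()
    # location of the highest pixel (first one on ties)
    peak = np.unravel_index(np.argmax(pred), pred.shape)
    ideal = gaussian_heatmap(h, w, peak, sigma)
    return float(np.mean((pred - ideal) ** 2))
```

A heatmap that already is a single Gaussian scores near zero, while one that splits its mass between two peaks is penalized.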

RegressionMSELoss

MSE loss between ground truth and predicted coordinates.
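A NumPy sketch of a coordinate MSE that skips unlabeled coordinates, assuming missing ground truth is stored as NaN (as in the general flow, step 0). The function name is hypothetical.

```python
import numpy as np


def regression_mse(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error over all coordinates, ignoring NaN targets."""
    mask = ~np.isnan(target)  # unlabeled coordinates are NaN
    return float(np.mean((pred[mask] - target[mask]) ** 2))
```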

RegressionRMSELoss

Root MSE loss between ground truth and predicted coordinates.
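Using the textbook definition, RMSE is the square root of the MSE, which puts the loss back in the same units as the coordinates (e.g. pixels). The package may instead take the root per keypoint before averaging; this sketch with a hypothetical function name uses the plain definition and the same NaN masking assumption.

```python
import numpy as np


def regression_rmse(pred: np.ndarray, target: np.ndarray) -> float:
    """Square root of the MSE over all non-NaN coordinates."""
    mask = ~np.isnan(target)  # unlabeled coordinates are NaN
    return float(np.sqrt(np.mean((pred[mask] - target[mask]) ** 2)))
```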

PairwiseProjectionsLoss

Penalize projections from each pair of cameras into 3D world space.
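One way to read this loss: triangulate each keypoint into 3D from every pair of calibrated cameras, then penalize disagreement among the pairwise estimates. The sketch below assumes known 3x4 projection matrices, uses linear (DLT) triangulation, and measures disagreement as mean distance from the average estimate; all three choices and the function names are assumptions, not the package's documented method.

```python
import itertools

import numpy as np


def triangulate_dlt(P1, P2, x1, x2):
    """Linear (DLT) triangulation of one 3D point from two 3x4 camera matrices."""
    # each view gives two linear constraints: u * P[2] - P[0] and v * P[2] - P[1]
    A = np.stack([
        x1[0] * P1[2] - P1[0],
        x1[1] * P1[2] - P1[1],
        x2[0] * P2[2] - P2[0],
        x2[1] * P2[2] - P2[1],
    ])
    # homogeneous solution: right singular vector with the smallest singular value
    _, _, vt = np.linalg.svd(A)
    X = vt[-1]
    return X[:3] / X[3]


def pairwise_projection_spread(cams, points_2d) -> float:
    """Mean distance of each pairwise 3D estimate from their average."""
    estimates = np.array([
        triangulate_dlt(cams[i], cams[j], points_2d[i], points_2d[j])
        for i, j in itertools.combinations(range(len(cams)), 2)
    ])
    center = estimates.mean(axis=0)
    return float(np.mean(np.linalg.norm(estimates - center, axis=1)))
```

When the 2D predictions are consistent across views, every camera pair recovers the same 3D point and the spread is zero; an error in one view pulls only the pairs that include it, which shows up as spread.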