lightning_pose.losses
lightning_pose.losses.factory Module
High-level loss class that orchestrates the individual losses.
Classes
Factory object that contains an object for each specified loss.
lightning_pose.losses.losses Module
Supervised and unsupervised losses implemented in PyTorch.
The lightning pose package defines each loss as its own class. In addition to computing the loss, an initialized loss object stores that loss's hyperparameters (its weight in the final objective function, the epsilon-insensitivity parameter, etc.).
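A minimal sketch of this pattern in PyTorch. The classes `LossSketch` and `AbsErrorLossSketch` and their `weight`/`epsilon` attributes are hypothetical names chosen for illustration, not the lightning_pose API. The point is only that two instances of the same loss class can carry different hyperparameters and therefore contribute differently to the objective.

```python
import torch


class LossSketch:
    """Hypothetical parent class: each instance stores its own hyperparameters."""

    def __init__(self, weight: float = 1.0, epsilon: float = 0.0) -> None:
        self.weight = weight    # scale of this term in the final objective
        self.epsilon = epsilon  # per-element losses below this value are ignored

    def compute_elementwise(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # implemented by each concrete loss

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        elementwise = self.compute_elementwise(pred, target)
        # epsilon-insensitivity: zero out small per-element losses
        elementwise = torch.where(
            elementwise < self.epsilon, torch.zeros_like(elementwise), elementwise
        )
        return self.weight * elementwise.mean()


class AbsErrorLossSketch(LossSketch):
    """Hypothetical concrete loss: absolute error per element."""

    def compute_elementwise(self, pred, target):
        return (pred - target).abs()


pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.5, 5.0])
# same class, different hyperparameters -> different contributions to the objective
loose = AbsErrorLossSketch(weight=3.0, epsilon=1.0)
strict = AbsErrorLossSketch(weight=1.0, epsilon=0.0)
print(loose(pred, target).item(), strict(pred, target).item())
```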
A separate LossFactory class (defined in lightning_pose.losses.factory) collects all losses for a given model and orchestrates their execution, logging, etc.
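This orchestration role can be sketched as follows, assuming every loss object is a callable that returns an already-weighted scalar tensor. `LossFactorySketch`, `ScaledMSE` and the `*_loss` log keys are hypothetical and do not reproduce the real `LossFactory` interface.

```python
import torch


class ScaledMSE:
    """Hypothetical loss object: weighted mean squared error."""

    def __init__(self, weight: float = 1.0) -> None:
        self.weight = weight

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.weight * ((pred - target) ** 2).mean()


class LossFactorySketch:
    """Hypothetical orchestrator: runs every loss, sums them, and collects logs."""

    def __init__(self, losses: dict) -> None:
        self.losses = losses  # name -> initialized loss object

    def __call__(self, pred: torch.Tensor, target: torch.Tensor):
        total = torch.zeros(())
        log = {}
        for name, loss_fn in self.losses.items():
            value = loss_fn(pred, target)
            log[f"{name}_loss"] = value.item()  # plain float for logging
            total = total + value               # tensor sum keeps gradients
        log["total_loss"] = total.item()
        return total, log


factory = LossFactorySketch({"mse": ScaledMSE(1.0), "mse_x2": ScaledMSE(2.0)})
total, log = factory(torch.tensor([0.0, 2.0]), torch.tensor([0.0, 0.0]))
print(log)  # {'mse_loss': 2.0, 'mse_x2_loss': 4.0, 'total_loss': 6.0}
```

For brevity every loss here receives the same prediction/target pair; an orchestrator that mixes supervised and unsupervised terms would presumably route labeled and unlabeled data to different losses.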
The general flow of each loss class is as follows:
- input: predicted and ground truth data
- step 0: remove ground truth samples containing NaNs, if desired
- step 1: compute the loss for each batch element/keypoint/etc.
- step 2: epsilon-insensitivity: set the loss to zero for any batch element with loss < epsilon
- step 3: reduce the loss (usually with a mean)
- step 4: log values to a dict
- step 5: return the loss
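As an illustration, these steps can be strung together for a coordinate-based loss. This is a sketch under stated assumptions: the function name `coordinate_mse_sketch`, the `(batch, num_keypoints, 2)` tensor layout, and the use of squared Euclidean distance per keypoint are all hypothetical, and epsilon is applied per keypoint rather than per batch element.

```python
import torch


def coordinate_mse_sketch(pred, target, epsilon=0.0, remove_nans=True, log_prefix="mse"):
    """Hypothetical loss that follows steps 0-5.

    pred, target: tensors of shape (batch, num_keypoints, 2) holding (x, y) coordinates.
    """
    # step 0: drop keypoints whose ground truth contains NaNs
    if remove_nans:
        keep = ~torch.isnan(target).any(dim=-1)  # (batch, num_keypoints) boolean mask
        pred, target = pred[keep], target[keep]  # -> (num_valid, 2)
    # step 1: squared Euclidean error per remaining keypoint
    elementwise = ((pred - target) ** 2).sum(dim=-1)
    # step 2: epsilon-insensitivity (applied per keypoint in this sketch)
    elementwise = torch.where(
        elementwise < epsilon, torch.zeros_like(elementwise), elementwise
    )
    # step 3: reduce
    loss = elementwise.mean()
    # step 4: log a detached value
    logs = {f"{log_prefix}_loss": loss.item()}
    # step 5: return
    return loss, logs


pred = torch.tensor([[[0.0, 0.0], [1.0, 1.0]], [[2.0, 2.0], [0.0, 0.0]]])
nan = float("nan")
target = torch.tensor([[[0.0, 1.0], [nan, nan]], [[2.0, 4.0], [0.0, 0.0]]])
loss, logs = coordinate_mse_sketch(pred, target)
print(logs)
```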
Functions
Get a dict with all the loss classes.
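One hypothetical way to build such a dict is to walk the subclass tree of a common parent class. The class names below are placeholders, and the real function may simply return a hand-written mapping. Recursion matters here because an intermediate parent (such as a heatmap-loss parent class) hides its own subclasses from a single `__subclasses__()` call on the top-level parent.

```python
class LossBase:  # hypothetical parent of all losses
    pass


class HeatmapLossBase(LossBase):  # hypothetical intermediate parent
    pass


class HeatmapMSESketch(HeatmapLossBase):  # hypothetical concrete heatmap loss
    pass


class RegressionMSESketch(LossBase):  # hypothetical concrete coordinate loss
    pass


def get_loss_classes_sketch(base=LossBase):
    """Collect every direct or indirect subclass of `base` into a name -> class dict."""
    found = {}
    stack = list(base.__subclasses__())
    while stack:
        cls = stack.pop()
        found[cls.__name__] = cls
        stack.extend(cls.__subclasses__())  # descend into grandchildren
    return found


print(sorted(get_loss_classes_sketch()))
```

Whether intermediate parents belong in the dict is a design choice; this sketch includes them.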
Classes
Parent class for all losses.
Parent class for different heatmap losses (MSE, Wasserstein, etc.).
MSE loss between heatmaps.
Kullback-Leibler loss between heatmaps.
Jensen-Shannon loss between heatmaps.
Penalize predictions that fall outside a low-dimensional subspace.
Penalize temporal differences for each target.
Penalize temporal differences for each heatmap.
Encourage heatmaps to be unimodal using various measures.
MSE loss between ground truth and predicted coordinates.
Root MSE loss between ground truth and predicted coordinates.
Penalize projections from each pair of cameras into 3D world space.
Penalize error between predicted 2D->3D->2D->heatmap and ground truth heatmap.
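The three heatmap comparisons listed above (MSE, Kullback-Leibler, Jensen-Shannon) can be sketched per keypoint as below. The function name, the `(batch, num_keypoints, height, width)` layout, normalizing each heatmap to sum to 1, the small `eps` inside the logarithms, and the KL direction (target relative to prediction) are assumptions, not details taken from lightning_pose.

```python
import torch


def heatmap_divergences_sketch(pred, target, eps=1e-10):
    """Hypothetical per-keypoint MSE, KL and JS between heatmaps.

    pred, target: (batch, num_keypoints, height, width), non-negative.
    Returns three tensors of shape (batch, num_keypoints).
    """
    b, k = pred.shape[:2]
    p = pred.reshape(b, k, -1)
    q = target.reshape(b, k, -1)
    mse = ((p - q) ** 2).mean(dim=-1)
    # treat each heatmap as a probability distribution over pixels
    p = p / p.sum(dim=-1, keepdim=True)
    q = q / q.sum(dim=-1, keepdim=True)

    def kl(a, c):
        # KL(a || c) summed over pixels; eps keeps log() finite at zero-valued pixels
        return (a * (torch.log(a + eps) - torch.log(c + eps))).sum(dim=-1)

    kl_loss = kl(q, p)  # KL(target || prediction), direction assumed
    m = 0.5 * (p + q)   # mixture distribution for Jensen-Shannon
    js = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    return mse, kl_loss, js


pred = torch.tensor([[[[1.0, 0.0]]]])    # all mass on pixel 0
target = torch.tensor([[[[0.0, 1.0]]]])  # all mass on pixel 1
mse, kl_loss, js = heatmap_divergences_sketch(pred, target)
print(mse.item(), kl_loss.item(), js.item())
```

One reason to prefer Jensen-Shannon here: with natural logarithms it is bounded above by ln 2 (reached for non-overlapping heatmaps, as in the example), whereas KL grows without bound as the predicted mass on the target's support goes to zero.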
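For the low-dimensional subspace penalty, here is a sketch under the assumption that the subspace is fitted with PCA (via SVD) on training data and that each prediction is penalized by its reconstruction error. The function and argument names are hypothetical, and how predictions are flattened into vectors is left open.

```python
import torch


def pca_subspace_loss_sketch(vectors, train_data, n_components, epsilon=0.0):
    """Hypothetical PCA loss: distance between each prediction and its
    projection onto a subspace fitted on training data.

    vectors: (batch, dim) predicted vectors (e.g. flattened coordinates).
    train_data: (num_samples, dim) data used to fit the subspace.
    """
    mean = train_data.mean(dim=0)
    # rows of vh are principal directions of the centered training data
    _, _, vh = torch.linalg.svd(train_data - mean, full_matrices=False)
    components = vh[:n_components]  # (n_components, dim), orthonormal rows
    # project onto the subspace and map back to the original space
    reconstructed = (vectors - mean) @ components.T @ components + mean
    # per-sample reconstruction error (Euclidean norm)
    error = torch.linalg.norm(vectors - reconstructed, dim=-1)
    # epsilon-insensitivity: tolerate small departures from the subspace
    error = torch.where(error < epsilon, torch.zeros_like(error), error)
    return error.mean()


# training points lie on the line y = x, so a 1-D subspace explains them fully
train = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
on_line = pca_subspace_loss_sketch(torch.tensor([[5.0, 5.0]]), train, n_components=1)
off_line = pca_subspace_loss_sketch(torch.tensor([[1.0, 0.0]]), train, n_components=1)
print(on_line.item(), off_line.item())
```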
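The temporal-difference penalty on coordinates ("for each target") can be sketched as follows, assuming predictions for consecutive frames of a single clip and an epsilon below which frame-to-frame motion is treated as plausible. The function name and tensor layout are hypothetical.

```python
import torch


def temporal_difference_loss_sketch(preds, epsilon=0.0):
    """Hypothetical temporal loss: penalize jumps between consecutive frames.

    preds: (num_frames, num_keypoints, 2) predicted coordinates for one clip.
    """
    # Euclidean distance moved by each keypoint between frame t and t+1
    jumps = torch.linalg.norm(preds[1:] - preds[:-1], dim=-1)  # (num_frames - 1, num_keypoints)
    # jumps smaller than epsilon are not penalized
    jumps = torch.where(jumps < epsilon, torch.zeros_like(jumps), jumps)
    return jumps.mean()


# one keypoint: jumps by (3, 4) between frames 0 and 1, then stays still
preds = torch.tensor([[[0.0, 0.0]], [[3.0, 4.0]], [[3.0, 4.0]]])
print(temporal_difference_loss_sketch(preds).item())
```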