lightning_pose.losses

lightning_pose.losses.factory Module

High-level loss class that orchestrates the individual losses.

Classes

LossFactory(losses_params_dict, data_module)

Factory object that contains an object for each specified loss.
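As a rough illustration of what "contains an object for each specified loss" can mean in practice, the sketch below keeps a dictionary of loss callables, runs each one, and records the individual and summed values. The class name `SketchLossFactory`, its constructor argument, and the log keys are hypothetical and are not the signature of `LossFactory`, which takes `losses_params_dict` and `data_module`.

```python
class SketchLossFactory:
    """Hypothetical stand-in for an orchestrating factory (not the library class)."""

    def __init__(self, losses):
        # losses: mapping from a loss name to a callable that returns a float
        self.losses = dict(losses)

    def __call__(self, **kwargs):
        logs = {}
        total = 0.0
        for name, loss_fn in self.losses.items():
            value = loss_fn(**kwargs)  # every loss receives the same keyword inputs
            logs[name] = value         # keep the per-loss value for logging
            total += value
        logs["total"] = total
        return total, logs


# Two toy losses that accept the same keyword arguments.
def abs_error(predictions, targets):
    return sum(abs(p - t) for p, t in zip(predictions, targets))


def sq_error(predictions, targets):
    return sum((p - t) ** 2 for p, t in zip(predictions, targets))


factory = SketchLossFactory({"abs": abs_error, "sq": sq_error})
total, logs = factory(predictions=[1.0, 2.0], targets=[0.0, 4.0])
```

Keeping the per-loss values next to the total makes it easy to see which term dominates the objective during training.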

lightning_pose.losses.helpers Module

Helper functions for losses.

Functions

convert_dict_values_to_tensors(param_dict, ...)

Classes

EmpiricalEpsilon(percentile)

Find percentile value of a given loss tensor.
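To make the idea of a percentile-based epsilon concrete, here is a minimal sketch in plain Python. It assumes linear interpolation between sorted values; the helper name `percentile_of_losses` is hypothetical, and `EmpiricalEpsilon` operates on tensors and may use a different interpolation convention.

```python
def percentile_of_losses(values, q):
    """Return the q-th percentile (0-100) of a list of loss values.

    Uses linear interpolation between the two nearest sorted values
    (an assumption made for this sketch).
    """
    if not values:
        raise ValueError("values must be non-empty")
    ordered = sorted(values)
    rank = (len(ordered) - 1) * q / 100.0  # fractional index into the sorted list
    lower = int(rank)                      # floor, since rank is non-negative
    upper = min(lower + 1, len(ordered) - 1)
    fraction = rank - lower
    return ordered[lower] + fraction * (ordered[upper] - ordered[lower])


losses = [1.0, 2.0, 3.0, 4.0, 5.0]
epsilon = percentile_of_losses(losses, 50)
```

A value computed this way could then serve as the epsilon-insensitivity threshold, so that roughly the lowest-loss fraction of samples contributes nothing to the objective.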

lightning_pose.losses.losses Module

Supervised and unsupervised losses implemented in PyTorch.

The Lightning Pose package defines each loss as its own class; an initialized loss object, in addition to computing the loss, stores hyperparameters related to the loss (weight in the final objective function, epsilon-insensitivity parameter, etc.).

A separate LossFactory class (defined in lightning_pose.losses.factory) collects all losses for a given model and orchestrates their execution, logging, etc.

The general flow of each loss class is as follows:

- input: predicted and ground truth data
- step 0: remove ground truth samples containing NaNs, if desired
- step 1: compute loss for each batch element/keypoint/etc.
- step 2: epsilon-insensitivity: set loss to zero for any batch element with loss < epsilon
- step 3: reduce loss (usually mean)
- step 4: log values to a dict
- step 5: return weighted loss
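The steps above can be sketched as a single function. This is a dependency-free illustration using plain Python lists rather than PyTorch tensors. The function name `sketch_loss`, the Euclidean per-sample loss, and the log keys are assumptions made for the example, not the library's implementation.

```python
import math


def sketch_loss(predictions, targets, epsilon=0.0, weight=1.0, logs=None):
    """Illustrative flow only: filter, compute, threshold, reduce, log, weight."""
    # step 0: drop samples whose ground truth contains a NaN
    pairs = [
        (p, t) for p, t in zip(predictions, targets)
        if not any(math.isnan(v) for v in t)
    ]
    # step 1: per-sample loss (Euclidean distance, chosen for illustration)
    elementwise = [math.dist(p, t) for p, t in pairs]
    # step 2: epsilon-insensitivity, zero out any loss below epsilon
    elementwise = [loss if loss >= epsilon else 0.0 for loss in elementwise]
    # step 3: reduce with a mean (0.0 if every sample was filtered out)
    reduced = sum(elementwise) / len(elementwise) if elementwise else 0.0
    # step 4: log the unweighted value and the weight
    if logs is not None:
        logs["sketch_loss"] = reduced
        logs["sketch_weight"] = weight
    # step 5: return the weighted loss
    return weight * reduced


logs = {}
value = sketch_loss(
    predictions=[(0.0, 0.0), (3.0, 4.0), (1.0, 1.0)],
    targets=[(0.0, 0.1), (0.0, 0.0), (float("nan"), 1.0)],
    epsilon=0.5,
    weight=2.0,
    logs=logs,
)
```

In this example the third sample is dropped for its NaN target, the first falls below epsilon and is zeroed, and the remaining distance of 5.0 is averaged over the two kept samples before weighting.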

Functions

get_loss_classes()

Get a dict with all the loss classes.

Classes

Loss([data_module, epsilon, log_weight])

Parent class for all losses.

HeatmapLoss([data_module, log_weight])

Parent class for different heatmap losses (MSE, Wasserstein, etc).

HeatmapMSELoss([data_module, log_weight])

MSE loss between heatmaps.

HeatmapKLLoss([data_module, log_weight])

Kullback-Leibler loss between heatmaps.

HeatmapJSLoss([data_module, log_weight])

Jensen-Shannon loss between heatmaps.

PCALoss(loss_name[, components_to_keep, ...])

Penalize predictions that fall outside a low-dimensional subspace.

TemporalLoss([data_module, epsilon, ...])

Penalize temporal differences for each target.

TemporalHeatmapLoss(loss_name[, ...])

Penalize temporal differences for each heatmap.

UnimodalLoss(loss_name, ...[, data_module, ...])

Encourage heatmaps to be unimodal using various measures.

RegressionMSELoss([data_module, epsilon, ...])

MSE loss between ground truth and predicted coordinates.

RegressionRMSELoss([data_module, epsilon, ...])

Root MSE loss between ground truth and predicted coordinates.
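Going by their names, `HeatmapKLLoss` and `HeatmapJSLoss` compare heatmaps using the Kullback-Leibler and Jensen-Shannon divergences respectively. The sketch below shows how the two quantities differ on flattened, normalized heatmaps. It uses plain Python lists and hypothetical helper names, and the small `eps` stabilizer is an assumption. The library's tensor-based implementations may differ.

```python
import math


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def js_divergence(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


# Two flattened "heatmaps", each normalized to sum to 1.
p = [0.5, 0.5]
q = [0.9, 0.1]

kl_pq = kl_divergence(p, q)
kl_qp = kl_divergence(q, p)
js_pq = js_divergence(p, q)
js_qp = js_divergence(q, p)
```

KL divergence is asymmetric in its arguments, whereas the Jensen-Shannon divergence is symmetric and bounded above by ln 2, which can make it the more forgiving choice when the predicted and target heatmaps barely overlap.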