# Source code for lightning_pose.losses.factory

"""High-level loss class that orchestrates the individual losses."""

from typing import Any, Literal

import lightning.pytorch as pl
import numpy as np
import torch
from jaxtyping import Float
from omegaconf import DictConfig, ListConfig, OmegaConf

from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.losses.losses import (
    HeatmapJSLoss,
    HeatmapKLLoss,
    HeatmapMSELoss,
    Loss,
    PairwiseProjectionsLoss,
    PCALoss,
    RegressionMSELoss,
    ReprojectionHeatmapLoss,
    TemporalHeatmapLoss,
    TemporalLoss,
    UnimodalLoss,
)

# declare the public API explicitly so that sphinx autodoc tooling (autoapi) documents
# only these names and ignores the names imported into this module
__all__ = [
    'LossFactory',
    'get_loss_classes',
    'get_loss_factories',
]


def get_loss_classes() -> dict[str, type[Loss]]:
    """Build the registry of every known loss, keyed by its loss-name string.

    Several classes are registered under more than one name (one per variant):
    ``PCALoss`` (multiview / singleview), ``TemporalHeatmapLoss`` (MSE / KL) and
    ``UnimodalLoss`` (MSE / KL / JS).

    Returns:
        dict whose keys are loss names and whose values are the loss classes.
    """
    name_class_pairs = [
        (RegressionMSELoss.loss_name, RegressionMSELoss),
        (HeatmapMSELoss.loss_name, HeatmapMSELoss),
        (HeatmapKLLoss.loss_name, HeatmapKLLoss),
        (HeatmapJSLoss.loss_name, HeatmapJSLoss),
        (PCALoss.LOSS_NAME_MULTIVIEW, PCALoss),
        (PCALoss.LOSS_NAME_SINGLEVIEW, PCALoss),
        (TemporalLoss.loss_name, TemporalLoss),
        (TemporalHeatmapLoss.LOSS_NAME_MSE, TemporalHeatmapLoss),
        (TemporalHeatmapLoss.LOSS_NAME_KL, TemporalHeatmapLoss),
        (UnimodalLoss.LOSS_NAME_MSE, UnimodalLoss),
        (UnimodalLoss.LOSS_NAME_KL, UnimodalLoss),
        (UnimodalLoss.LOSS_NAME_JS, UnimodalLoss),
        (PairwiseProjectionsLoss.loss_name, PairwiseProjectionsLoss),
        (ReprojectionHeatmapLoss.loss_name, ReprojectionHeatmapLoss),
    ]
    # dict() keeps insertion order and, like a literal, lets a later duplicate key win
    return dict(name_class_pairs)
def _heatmap_dims(cfg: DictConfig | ListConfig) -> dict[str, int]:
    """Return the original and downsampled image dims expected by heatmap-based losses.

    The downsampled dims are the resized image dims integer-divided by
    ``2 ** cfg.data.downsample_factor`` (factor defaults to 2 when unset).
    """
    height_og = cfg.data.image_resize_dims.height
    width_og = cfg.data.image_resize_dims.width
    scale = 2 ** cfg.data.get('downsample_factor', 2)
    return {
        'original_image_height': height_og,
        'original_image_width': width_og,
        'downsampled_image_height': int(height_og // scale),
        'downsampled_image_width': int(width_og // scale),
    }


def _supervised_loss_params(cfg: DictConfig | ListConfig) -> dict[str, dict]:
    """Collect constructor params for the supervised losses implied by ``cfg.model``."""
    params: dict[str, dict] = {}
    model_type = cfg.model.model_type

    if 'heatmap' not in model_type:
        # non-heatmap models use the supervised loss registered under the model type name
        params[model_type] = {'log_weight': 0.0}
        return params

    # log_weight=0.0 → effective weight = 1/2
    params['heatmap_' + cfg.model.heatmap_loss_type] = {'log_weight': 0.0}

    # optional geometry-based supervised losses: only for multiview models that have a
    # camera params file configured, and only when the user set a log_weight for them
    if 'multiview' in model_type and cfg.data.get('camera_params_file'):
        log_weight_sp = cfg.losses.get(
            'supervised_pairwise_projections', {}
        ).get('log_weight')
        if log_weight_sp is not None:
            print('Adding supervised pairwise projection loss')
            params['supervised_pairwise_projections'] = {'log_weight': log_weight_sp}

        log_weight_hr = cfg.losses.get(
            'supervised_reprojection_heatmap_mse', {}
        ).get('log_weight')
        if log_weight_hr is not None:
            print('Adding supervised reprojection heatmap loss')
            params['supervised_reprojection_heatmap_mse'] = {
                'log_weight': log_weight_hr,
                **_heatmap_dims(cfg),
            }
    return params


def _pca_multiview_mirrored_column_matches(cfg: DictConfig | ListConfig) -> Any:
    """Return ``cfg.data.mirrored_column_matches`` adapted for the multiview PCA loss.

    With more than one view configured and a flat list of ints, the list is treated as
    single-view indices and replicated once per view, offset by
    ``view_index * cfg.data.num_keypoints``. In every other case the configured value
    is returned unchanged.

    Raises:
        ValueError: if several views are configured but the matches are empty
            (previously this surfaced as an opaque ``IndexError``).
    """
    matches = cfg.data.mirrored_column_matches
    view_names = cfg.data.get('view_names', None)
    if not (view_names and len(view_names) > 1):
        return matches
    if len(matches) == 0:
        raise ValueError(
            'cfg.data.mirrored_column_matches must be non-empty to use the '
            'pca_multiview loss with multiple views'
        )
    if not isinstance(matches[0], int):
        # presumably already given per view; pass through as-is
        return matches
    num_keypoints = cfg.data.num_keypoints
    base_columns = np.array(matches, dtype=int)
    return [
        (view_idx * num_keypoints + base_columns).tolist()
        for view_idx in range(len(view_names))
    ]


def _unsupervised_loss_params(
    cfg: DictConfig | ListConfig,
    cfg_loss_dict: Any,
) -> dict[str, dict]:
    """Collect constructor params for each loss listed in ``cfg.model.losses_to_use``.

    Each entry starts from the matching ``cfg.losses`` section (as a plain dict), is
    tagged with ``'loss_name'``, and receives loss-specific extras.

    Raises:
        NotImplementedError: unimodal/temporal_heatmap loss with a regression model, or
            pca_singleview with more than one view.
        Exception: any unimodal/temporal_heatmap loss (pending removal of the deprecated
            ``cfg.data.image_resize_dims`` dependency).
        ValueError: see ``_pca_multiview_mirrored_column_matches``.
    """
    params: dict[str, dict] = {}
    if cfg.model.losses_to_use is None:
        return params

    for loss_name in cfg.model.losses_to_use:
        loss_params = cfg_loss_dict[loss_name]
        loss_params['loss_name'] = loss_name
        params[loss_name] = loss_params

        if loss_name.startswith(('unimodal', 'temporal_heatmap')):
            if cfg.model.model_type == 'regression':
                raise NotImplementedError(
                    'unimodal loss can only be used with classes inheriting from '
                    'HeatmapTracker. \nYou specified a RegressionTracker model.'
                )
            raise Exception(
                'need to update unimodal and/or temporal heatmap loss to not use '
                'cfg.data.image_resize_dims, which has been deprecated.'
            )
            # NOTE: unreachable until the TODO above is resolved; kept as the intended
            # parameterization once these losses no longer rely on image_resize_dims
            loss_params.update(_heatmap_dims(cfg))
            if loss_name.startswith('unimodal'):
                loss_params['uniform_heatmaps'] = cfg.training.get(
                    'uniform_heatmaps_for_nan_keypoints', False
                )
        elif loss_name == 'pca_multiview':
            loss_params['mirrored_column_matches'] = (
                _pca_multiview_mirrored_column_matches(cfg)
            )
        elif loss_name == 'pca_singleview':
            if cfg.data.get('view_names', None) and len(cfg.data.view_names) > 1:
                raise NotImplementedError(
                    'The Pose PCA loss is currently not implemented for multiview data.'
                )
            loss_params['columns_for_singleview_pca'] = cfg.data.get(
                'columns_for_singleview_pca', None
            )
    return params


def get_loss_factories(
    cfg: DictConfig | ListConfig,
    data_module: BaseDataModule | UnlabeledDataModule,
) -> dict:
    """Create supervised and unsupervised loss factories from a hydra config.

    Args:
        cfg: hydra config carrying model, data, and loss parameters.
        data_module: data module passed to data-dependent losses such as PCA.

    Returns:
        dict with keys ``'supervised'`` and ``'unsupervised'``, each mapping to a
        :class:`LossFactory` instance.

    Raises:
        NotImplementedError: for loss/model combinations that are not supported.
        Exception: if a unimodal or temporal_heatmap loss is requested.
        ValueError: if pca_multiview is requested with multiple views but empty
            ``cfg.data.mirrored_column_matches``.
    """
    cfg_loss_dict = OmegaConf.to_object(cfg.losses)
    assert cfg_loss_dict is not None

    supervised_params = _supervised_loss_params(cfg)
    unsupervised_params = _unsupervised_loss_params(cfg, cfg_loss_dict)

    return {
        'supervised': LossFactory(
            losses_params_dict=supervised_params,
            data_module=data_module,
        ),
        'unsupervised': LossFactory(
            losses_params_dict=unsupervised_params,
            data_module=data_module,
        ),
    }
class LossFactory(pl.LightningModule):
    """Factory object that contains an object for each specified loss.

    Every entry of ``losses_params_dict`` is instantiated with the class registered
    under the same name by :func:`get_loss_classes`. Calling the factory evaluates all
    instances and returns their weighted (and optionally annealed) sum.
    """

    # losses whose weighted value is never multiplied by ``anneal_weight``
    _UNANNEALED_LOSSES = frozenset({'heatmap_mse', 'heatmap_kl', 'heatmap_js'})

    def __init__(
        self,
        losses_params_dict: dict[str, dict],
        data_module: BaseDataModule | UnlabeledDataModule | None,
    ) -> None:
        """Initialize LossFactory and create all specified loss instances.

        Args:
            losses_params_dict: mapping from loss name to a dict of keyword arguments
                that will be passed to the corresponding loss class constructor.
            data_module: data module passed to each loss; required for data-dependent
                losses such as PCA.
        """
        super().__init__()
        self.losses_params_dict = losses_params_dict
        self.data_module = data_module
        # initialize loss classes
        self._initialize_loss_instances()

    def _initialize_loss_instances(self) -> None:
        """Instantiate each loss class from ``self.losses_params_dict`` and store them.

        Raises:
            KeyError: if a loss name is not registered in :func:`get_loss_classes`; the
                message lists the registered names.
        """
        self.loss_instance_dict: dict[str, Loss] = {}
        loss_classes_dict = get_loss_classes()
        for loss_name, params in self.losses_params_dict.items():
            try:
                loss_cls = loss_classes_dict[loss_name]
            except KeyError:
                raise KeyError(
                    f'Unknown loss {loss_name!r}; registered losses are '
                    f'{sorted(loss_classes_dict)}'
                ) from None
            self.loss_instance_dict[loss_name] = loss_cls(
                data_module=self.data_module, **params
            )

    def __call__(
        self,
        stage: Literal['train', 'val', 'test'] | None = None,
        anneal_weight: float | torch.Tensor | None = 1.0,
        **kwargs: Any,
    ) -> tuple[Float[torch.Tensor, ''], list[dict]]:
        """Compute the total weighted loss and collect logging entries for all losses.

        Args:
            stage: forwarded to each loss instance and used as the prefix of the
                ``'{stage}_{loss_name}_loss_weighted'`` log entries added here (these
                entries are still produced, prefixed ``None``, when stage is None).
            anneal_weight: scalar multiplier applied to every loss except
                'heatmap_mse', 'heatmap_kl' and 'heatmap_js'; ``None`` disables
                annealing for all losses.
            **kwargs: tensors forwarded to each individual loss (e.g.,
                ``heatmaps_targ``, ``keypoints_pred``).

        Returns:
            Tuple of:
                - scalar total loss tensor.
                - list of logging dicts with ``'name'`` and ``'value'`` keys: the
                  entries returned by each loss, followed by its weighted-loss entry.
        """
        tot_loss: Float[torch.Tensor, ''] = torch.tensor(0.0)
        log_list_all: list[dict] = []
        for loss_name, loss_instance in self.loss_instance_dict.items():
            # kwargs options:
            # - heatmaps_targ
            # - heatmaps_pred
            # - keypoints_targ
            # - keypoints_pred
            #
            # if a Loss class needs to manipulate other objects (e.g. image embedding),
            # the model's `training_step` method must supply that tensor to the loss
            # factory using the correct keyword argument (defined by the new Loss
            # class's `__call__` method)
            curr_loss, log_list = loss_instance(stage=stage, **kwargs)
            current_weighted_loss = loss_instance.weight * curr_loss

            if anneal_weight is None or loss_name in self._UNANNEALED_LOSSES:
                anneal_weight_ = 1.0
            else:
                anneal_weight_ = anneal_weight
            scaled = anneal_weight_ * current_weighted_loss

            # move accumulator to the loss device (it starts on CPU)
            tot_loss = tot_loss.to(scaled.device) + scaled

            # copy entries into our own list instead of mutating the list the loss
            # instance returned; unweighted losses are logged by the loss instance
            log_list_all.extend(log_list)
            log_list_all.append(
                {
                    'name': f'{stage}_{loss_name}_loss_weighted',
                    'value': current_weighted_loss,
                }
            )
        return tot_loss, log_list_all