Source code for lightning_pose.losses.factory

"""High-level loss class that orchestrates the individual losses."""

from typing import Literal

import lightning.pytorch as pl
import torch
from torchtyping import TensorType

from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.losses.losses import get_loss_classes

# Declare LossFactory as the only public name so that sphinx-autoapidoc does not
# document the names imported above.
__all__ = [
    "LossFactory",
]


class LossFactory(pl.LightningModule):
    """Factory object that holds one loss instance per requested loss and combines them.

    Each key of ``losses_params_dict`` names a loss class in the registry returned by
    :func:`get_loss_classes`. The matching value is passed to that class's constructor
    as keyword arguments, together with ``data_module``.

    Calling the factory does the following for every configured loss:

    - evaluates the loss;
    - scales it by the loss instance's ``weight``;
    - for losses other than the heatmap losses, also scales it by an annealing weight.

    It returns the summed total together with a flat list of log entries.
    """

    # Losses whose weighted value is never multiplied by ``anneal_weight``.
    _UNANNEALED_LOSSES = frozenset({"heatmap_mse", "heatmap_kl", "heatmap_js"})

    def __init__(
        self,
        losses_params_dict: dict[str, dict],
        data_module: BaseDataModule | UnlabeledDataModule,
    ) -> None:
        """Store the configuration and build the loss instances.

        Args:
            losses_params_dict: maps a loss name (a key of ``get_loss_classes()``)
                to the keyword arguments for that loss class's constructor.
            data_module: passed as ``data_module=`` to every loss constructor.

        Raises:
            KeyError: if a requested loss name is not a registered loss.
        """
        super().__init__()
        self.losses_params_dict = losses_params_dict
        self.data_module = data_module
        # initialize loss classes
        self._initialize_loss_instances()

    def _initialize_loss_instances(self) -> None:
        """Instantiate one loss object per entry in ``self.losses_params_dict``.

        Results are stored in ``self.loss_instance_dict``, keyed by loss name.

        Raises:
            KeyError: if a requested loss name is not in the registry returned by
                ``get_loss_classes()``. The message lists the available names.
        """
        # NOTE(review): this is a plain dict. If the loss classes are nn.Module
        # subclasses, they are not registered as submodules of this LightningModule.
        # Confirm this is intended.
        self.loss_instance_dict = {}
        loss_classes_dict = get_loss_classes()
        for loss, params in self.losses_params_dict.items():
            try:
                loss_cls = loss_classes_dict[loss]
            except KeyError:
                available = ", ".join(map(str, loss_classes_dict))
                raise KeyError(
                    f"Unknown loss '{loss}'; available losses: {available}"
                ) from None
            self.loss_instance_dict[loss] = loss_cls(
                data_module=self.data_module, **params
            )

    def __call__(
        self,
        stage: Literal["train", "val", "test"] | None = None,
        anneal_weight: float | torch.Tensor | None = 1.0,
        **kwargs,
    ) -> tuple[TensorType[()], list[dict]]:
        """Compute, weight, and sum all configured losses.

        Args:
            stage: forwarded to each loss instance and used as the prefix of the
                weighted-loss log names. When ``None``, entries are still returned,
                named ``"None_<loss>_loss_weighted"``. Presumably the caller skips
                logging in that case; confirm.
            anneal_weight: extra multiplier applied to each weighted loss, except
                those in ``_UNANNEALED_LOSSES``. ``None`` is treated as 1.0.
            **kwargs: forwarded unchanged to every loss instance, for example
                ``heatmaps_targ``, ``heatmaps_pred``, ``keypoints_targ`` and
                ``keypoints_pred``. A loss that needs another tensor (e.g. an image
                embedding) requires the model's ``training_step`` to pass it here
                under the keyword that the loss's ``__call__`` expects.

        Returns:
            A tuple ``(total_loss, logs)``:

            - ``total_loss`` is the sum of the (annealed) weighted losses. It is the
              Python float 0.0 if no losses are configured.
            - ``logs`` contains each loss's own log entries, each group followed by
              a ``{stage}_{loss_name}_loss_weighted`` entry. That entry holds the
              weighted loss before annealing.
        """
        tot_loss = 0.0
        log_list_all = []
        for loss_name, loss_instance in self.loss_instance_dict.items():
            # each loss returns (unweighted loss, its own log entries)
            curr_loss, log_list = loss_instance(stage=stage, **kwargs)
            current_weighted_loss = loss_instance.weight * curr_loss

            if anneal_weight is None or loss_name in self._UNANNEALED_LOSSES:
                effective_anneal = 1.0
            else:
                effective_anneal = anneal_weight
            # out-of-place add: avoids in-place ops on autograd tensors
            tot_loss = tot_loss + effective_anneal * current_weighted_loss

            # Copy the loss's entries instead of appending to the list it returned.
            # Unweighted losses are logged by the loss instance itself; the weighted
            # value is added here.
            log_list_all.extend(log_list)
            log_list_all.append(
                {
                    "name": f"{stage}_{loss_name}_loss_weighted",
                    "value": current_weighted_loss,
                }
            )
        return tot_loss, log_list_all