"""High-level loss class that orchestrates the individual losses."""
from typing import Literal
import lightning.pytorch as pl
import torch
from torchtyping import TensorType
from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.losses.losses import get_loss_classes
# restrict the public names to LossFactory so that sphinx-autoapidoc ignores the imports
__all__ = [
"LossFactory",
]
class LossFactory(pl.LightningModule):
    """Factory object that contains an object for each specified loss.

    One loss instance is created per entry of `losses_params_dict`. Calling the factory
    evaluates every instance, sums the weighted (and optionally annealed) values into a
    single total, and gathers the log entries produced along the way.
    """

    # losses whose weighted value is never scaled by `anneal_weight`
    _NON_ANNEALED_LOSSES = frozenset({"heatmap_mse", "heatmap_kl", "heatmap_js"})

    def __init__(
        self,
        losses_params_dict: dict[str, dict],
        data_module: BaseDataModule | UnlabeledDataModule,
    ) -> None:
        """Store the configuration and build one loss instance per configured loss.

        Args:
            losses_params_dict: maps a loss name to the keyword arguments that are
                forwarded to the constructor of the matching loss class
            data_module: handed to every loss class constructor as `data_module`

        Raises:
            KeyError: if a loss name has no entry in the mapping returned by
                `get_loss_classes()`

        """
        super().__init__()
        self.losses_params_dict = losses_params_dict
        self.data_module = data_module
        # initialize loss classes
        self._initialize_loss_instances()

    def _initialize_loss_instances(self) -> None:
        """Instantiate every configured loss and store it in `self.loss_instance_dict`."""
        self.loss_instance_dict = {}
        loss_classes_dict = get_loss_classes()
        for loss, params in self.losses_params_dict.items():
            self.loss_instance_dict[loss] = loss_classes_dict[loss](
                data_module=self.data_module, **params
            )

    def __call__(
        self,
        stage: Literal["train", "val", "test"] | None = None,
        anneal_weight: float | torch.Tensor | None = 1.0,
        **kwargs
    ) -> tuple[TensorType[()], list[dict]]:
        """Compute every configured loss and combine them into one total.

        Args:
            stage: forwarded to each loss instance and used as the prefix of the names
                of the weighted-loss log entries. The entries are built even when
                `stage` is None (the name then starts with "None_"); skipping the
                logging in that case is left to the caller.
            anneal_weight: multiplier applied to each weighted loss, except for the
                heatmap losses ("heatmap_mse", "heatmap_kl", "heatmap_js") which are
                never annealed. None is treated as 1.0.
            **kwargs: forwarded unchanged to every loss instance. Options:
                - heatmaps_targ
                - heatmaps_pred
                - keypoints_targ
                - keypoints_pred
                If a Loss class needs to manipulate other objects (e.g. image
                embedding), the model's `training_step` method must supply that tensor
                to the loss factory using the correct keyword argument (defined by the
                new Loss class's `__call__` method).

        Returns:
            tuple
                - total loss: sum over all losses of anneal factor * weight * loss;
                  this is the plain float 0.0 when no losses are configured
                - list of log dicts: for each loss, the entries returned by the loss
                  instance followed by one `{stage}_{loss_name}_loss_weighted` entry

        """
        tot_loss = 0.0
        log_list_all = []
        for loss_name, loss_instance in self.loss_instance_dict.items():
            # "stage" is used for logging purposes
            curr_loss, log_list = loss_instance(stage=stage, **kwargs)
            current_weighted_loss = loss_instance.weight * curr_loss

            if anneal_weight is None or loss_name in self._NON_ANNEALED_LOSSES:
                anneal_weight_ = 1.0
            else:
                anneal_weight_ = anneal_weight
            tot_loss += anneal_weight_ * current_weighted_loss

            # collect the entries returned by the loss instance without mutating the
            # list object it handed back
            log_list_all.extend(log_list)
            # log weighted losses; the logged value does not include the anneal factor
            log_list_all.append(
                {
                    "name": f"{stage}_{loss_name}_loss_weighted",
                    "value": current_weighted_loss,
                }
            )

        return tot_loss, log_list_all