LossFactory

class lightning_pose.losses.factory.LossFactory

Bases: LightningModule

Factory object that contains an object for each specified loss.

Methods Summary

__call__([stage, anneal_weight])

Compute the total weighted loss and collect logging entries for all registered losses.

Methods Documentation

__call__(stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: float | Tensor | None = 1.0, **kwargs: Any) → tuple[Float[Tensor, ''], list[dict]]

Compute the total weighted loss and collect logging entries for all registered losses.

Parameters:
  • stage – training stage used for loss logging ('train', 'val', 'test'); pass None to suppress logging.

  • anneal_weight – scalar multiplier applied to all non-heatmap losses; typically the output of an AnnealWeight callback.

  • **kwargs – tensors forwarded to each individual loss (e.g., heatmaps_targ, keypoints_pred).

Returns:

A tuple of

  • scalar total loss tensor.

  • list of logging dicts with 'name' and 'value' keys.

Return type:

tuple[Float[Tensor, ''], list[dict]]
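The call contract described above (sum the weighted losses, scale the non-heatmap ones by anneal_weight, and return logging entries only when a stage is given) can be illustrated with a minimal pure-Python sketch. This is not the lightning_pose implementation: ToyLossFactory, heatmap_mse, temporal_diff, the heatmap_names argument, and the "{stage}_{name}_loss" log naming are all hypothetical, and plain floats stand in for tensors.

```python
from typing import Any, Callable, Optional


def heatmap_mse(heatmaps_pred, heatmaps_targ, **_: Any) -> float:
    """Toy heatmap loss: mean squared error over flat lists of floats."""
    squared = [(p - t) ** 2 for p, t in zip(heatmaps_pred, heatmaps_targ)]
    return sum(squared) / len(squared)


def temporal_diff(keypoints_pred, **_: Any) -> float:
    """Toy non-heatmap loss: summed absolute frame-to-frame difference."""
    return sum(abs(b - a) for a, b in zip(keypoints_pred, keypoints_pred[1:]))


class ToyLossFactory:
    """Hypothetical stand-in that mimics the documented call contract."""

    def __init__(self, losses: dict[str, Callable[..., float]], heatmap_names: set[str]):
        self.losses = losses
        self.heatmap_names = heatmap_names  # assumption: losses exempt from annealing

    def __call__(self, stage: Optional[str] = None, anneal_weight: float = 1.0, **kwargs: Any):
        total = 0.0
        logs: list[dict] = []
        for name, loss_fn in self.losses.items():
            value = loss_fn(**kwargs)  # every loss receives the same kwargs
            weight = 1.0 if name in self.heatmap_names else anneal_weight
            total += weight * value
            if stage is not None:  # stage=None suppresses logging
                logs.append({"name": f"{stage}_{name}_loss", "value": value})
        return total, logs


factory = ToyLossFactory(
    {"heatmap_mse": heatmap_mse, "temporal": temporal_diff},
    heatmap_names={"heatmap_mse"},
)
total, logs = factory(
    stage="train",
    anneal_weight=0.5,
    heatmaps_pred=[1.0, 2.0],
    heatmaps_targ=[1.0, 4.0],
    keypoints_pred=[0.0, 1.0, 3.0],
)
silent_total, silent_logs = factory(
    heatmaps_pred=[1.0, 2.0],
    heatmaps_targ=[1.0, 4.0],
    keypoints_pred=[0.0, 1.0, 3.0],
)
```

In this sketch each loss picks the keyword arguments it needs and ignores the rest, which is one way a single **kwargs bundle can be forwarded to every registered loss.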

__init__(losses_params_dict: dict[str, dict], data_module: BaseDataModule | UnlabeledDataModule | None) → None

Initialize LossFactory and create all specified loss instances.

Parameters:
  • losses_params_dict – mapping from loss name to a dict of keyword arguments that will be passed to the corresponding loss class constructor.

  • data_module – data module passed to each loss; required for data-dependent losses such as PCA.
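As a rough illustration of the constructor contract, the sketch below builds one loss object per entry of a losses_params_dict and passes the same data module to each. Everything here is hypothetical: TOY_REGISTRY, ToyTemporalLoss, ToyPCALoss, build_losses, and the parameter names weight and epsilon are placeholders, not the names used by lightning_pose.

```python
from typing import Any, Optional


class ToyTemporalLoss:
    """Toy loss that does not depend on the data module."""

    def __init__(self, data_module: Optional[Any] = None, weight: float = 1.0, epsilon: float = 0.0):
        self.weight = weight
        self.epsilon = epsilon


class ToyPCALoss:
    """Toy data-dependent loss: it refuses to build without a data module."""

    def __init__(self, data_module: Optional[Any] = None, weight: float = 1.0):
        if data_module is None:
            raise ValueError("ToyPCALoss needs a data module to compute its statistics")
        self.data_module = data_module
        self.weight = weight


# Hypothetical mapping from loss name to loss class.
TOY_REGISTRY = {"temporal": ToyTemporalLoss, "pca": ToyPCALoss}


def build_losses(losses_params_dict: dict[str, dict], data_module: Optional[Any]) -> dict[str, Any]:
    """Instantiate one loss per entry, forwarding its params as keyword arguments."""
    return {
        name: TOY_REGISTRY[name](data_module=data_module, **params)
        for name, params in losses_params_dict.items()
    }


losses = build_losses({"temporal": {"weight": 2.0, "epsilon": 0.5}}, data_module=None)
```

Under these assumptions, requesting the toy "pca" loss with data_module=None raises a ValueError, which mirrors why the data module is required for data-dependent losses.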

__new__(**kwargs)