LossFactory
- class lightning_pose.losses.factory.LossFactory
Bases: LightningModule
Factory object that contains an object for each specified loss.
Methods Summary
__call__([stage, anneal_weight]): Compute the total weighted loss and collect logging entries for all registered losses.
Methods Documentation
- __call__(stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: float | Tensor | None = 1.0, **kwargs: Any) -> tuple[Float[Tensor, ''], list[dict]]
Compute the total weighted loss and collect logging entries for all registered losses.
- Parameters:
stage – training stage used for loss logging ('train', 'val', 'test'); pass None to suppress logging.
anneal_weight – scalar multiplier applied to all non-heatmap losses; typically the output of an AnnealWeight callback.
**kwargs – tensors forwarded to each individual loss (e.g., heatmaps_targ, keypoints_pred).
- Returns:
A tuple of the scalar total loss tensor and a list of logging dicts with 'name' and 'value' keys.
- Return type:
tuple[Float[Tensor, ''], list[dict]]
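The weighting and logging contract described for __call__ can be illustrated with a minimal, library-free sketch. This is not the LossFactory implementation: `toy_total_loss`, the two toy loss functions, the `heatmaps_pred` keyword, the rule of detecting heatmap losses by name prefix, and the log entry names are all assumptions made for illustration, and plain floats stand in for tensors.

```python
from typing import Any, Callable, Optional

# Hypothetical stand-in losses: each maps keyword inputs to a plain float.
def toy_heatmap_loss(**kwargs: Any) -> float:
    pred, targ = kwargs["heatmaps_pred"], kwargs["heatmaps_targ"]
    return sum((p - t) ** 2 for p, t in zip(pred, targ))

def toy_temporal_loss(**kwargs: Any) -> float:
    kp = kwargs["keypoints_pred"]
    return sum(abs(b - a) for a, b in zip(kp, kp[1:]))

def toy_total_loss(
    losses: dict[str, Callable[..., float]],
    stage: Optional[str] = None,
    anneal_weight: float = 1.0,
    **kwargs: Any,
) -> tuple[float, list[dict]]:
    total = 0.0
    logs: list[dict] = []
    for name, loss_fn in losses.items():
        # Assumption: heatmap losses are recognized by name and are never annealed.
        weight = 1.0 if name.startswith("heatmap") else anneal_weight
        weighted = weight * loss_fn(**kwargs)
        total += weighted
        if stage is not None:  # stage=None suppresses logging
            logs.append({"name": f"{stage}_{name}_loss_weighted", "value": weighted})
    if stage is not None:
        logs.append({"name": f"{stage}_total_loss", "value": total})
    return total, logs

toy_losses = {"heatmap_mse": toy_heatmap_loss, "temporal": toy_temporal_loss}
toy_inputs = dict(
    heatmaps_pred=[1.0, 2.0],
    heatmaps_targ=[1.0, 0.0],
    keypoints_pred=[0.0, 1.0, 3.0],
)
total, logs = toy_total_loss(toy_losses, stage="train", anneal_weight=0.5, **toy_inputs)
silent_total, silent_logs = toy_total_loss(toy_losses, stage=None, anneal_weight=0.5, **toy_inputs)
```

In this sketch the heatmap term contributes 4.0 unscaled and the temporal term contributes 3.0 scaled by 0.5, so both calls return the same total; only the call with a stage produces log entries.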
- __init__(losses_params_dict: dict[str, dict], data_module: BaseDataModule | UnlabeledDataModule | None) -> None
Initialize LossFactory and create all specified loss instances.
- Parameters:
losses_params_dict – mapping from loss name to a dict of keyword arguments that will be passed to the corresponding loss class constructor.
data_module – data module passed to each loss; required for data-dependent losses such as PCA.
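The name-to-kwargs mapping can be illustrated with a small, self-contained sketch of the factory pattern. Everything here is hypothetical: `TOY_REGISTRY`, `ToyTemporalLoss`, `ToyPCALoss`, `build_losses`, and the parameter names `epsilon` and `components_to_keep` are invented for the example and are not taken from lightning_pose.

```python
from typing import Any, Optional

class ToyTemporalLoss:
    """Hypothetical loss that does not need any dataset statistics."""
    def __init__(self, data_module: Optional[Any] = None, epsilon: float = 0.0) -> None:
        self.data_module = data_module
        self.epsilon = epsilon

class ToyPCALoss:
    """Hypothetical data-dependent loss: it cannot be built without a data module."""
    def __init__(self, data_module: Optional[Any] = None, components_to_keep: float = 0.99) -> None:
        if data_module is None:
            raise ValueError("a data-dependent loss requires a data module")
        self.data_module = data_module
        self.components_to_keep = components_to_keep

# Hypothetical registry mapping loss names to loss classes.
TOY_REGISTRY: dict[str, type] = {"temporal": ToyTemporalLoss, "pca": ToyPCALoss}

def build_losses(losses_params_dict: dict[str, dict], data_module: Optional[Any]) -> dict[str, Any]:
    # Each inner dict is unpacked as keyword arguments for the matching class;
    # the same data module is handed to every loss.
    return {
        name: TOY_REGISTRY[name](data_module=data_module, **params)
        for name, params in losses_params_dict.items()
    }

built = build_losses({"temporal": {"epsilon": 5.0}}, data_module=None)

try:
    build_losses({"pca": {"components_to_keep": 0.95}}, data_module=None)
    pca_failed_without_data = False
except ValueError:
    pca_failed_without_data = True
```

The second call shows why a data module matters for data-dependent losses: in this toy version the PCA-style loss refuses to construct when `data_module` is None.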
- __new__(**kwargs)