LossFactory

class lightning_pose.losses.factory.LossFactory

Bases: LightningModule

Factory object that holds one loss object for each loss specified in `losses_params_dict`.
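The factory idea (one loss object per configured loss name) can be illustrated with a minimal, framework-free sketch. It assumes each entry of the params dict maps a loss name to keyword arguments for that loss's class. The registry `LOSS_REGISTRY`, the classes `ToyLossFactory`, `ToyL2Loss` and `ToyL1Loss`, and the `weight` parameter are all hypothetical and are not the lightning-pose implementation, which works on tensors.

```python
from dataclasses import dataclass


# Hypothetical toy losses on plain lists of floats; the real
# lightning-pose losses operate on tensors.
@dataclass
class ToyL2Loss:
    weight: float = 1.0

    def __call__(self, preds: list[float], targets: list[float]) -> float:
        # mean squared error over paired values
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


@dataclass
class ToyL1Loss:
    weight: float = 1.0

    def __call__(self, preds: list[float], targets: list[float]) -> float:
        # mean absolute error over paired values
        return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


# Hypothetical registry mapping loss names to loss classes.
LOSS_REGISTRY = {"l2": ToyL2Loss, "l1": ToyL1Loss}


class ToyLossFactory:
    """Holds one loss object per entry in a params dict (illustrative only)."""

    def __init__(self, losses_params_dict: dict[str, dict]):
        # build each loss from its name and keyword arguments;
        # an unknown name raises KeyError
        self.losses = {
            name: LOSS_REGISTRY[name](**params)
            for name, params in losses_params_dict.items()
        }


factory = ToyLossFactory({"l2": {"weight": 0.5}, "l1": {}})
print(sorted(factory.losses))  # ['l1', 'l2']
print(factory.losses["l2"].weight)  # 0.5
```

Keeping construction in a name-to-class registry means new losses can be enabled purely through configuration, without changing the factory itself.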

Methods Summary

__call__([stage, anneal_weight])

Compute the combined loss of all losses in the factory.

Methods Documentation

__call__(stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: float | Tensor = 1.0, **kwargs) → tuple[TensorType[()], list[dict]]

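Call self as a function: evaluate every loss held by the factory for the given `stage` and return the total loss (a scalar tensor) together with a list of dicts of logging information. `anneal_weight` scales the combined loss.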
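As a rough sketch of what a combined-loss call like this might do, the framework-free function below sums several losses, scales the result by an annealing factor, and collects one log entry per loss. The helper `combine_losses`, the per-loss `weights` dict, the toy `mse`/`mae` functions, and the log-entry keys `name` and `value` are hypothetical; here `stage` only prefixes the log names, which is an assumption about how it is used.

```python
from typing import Callable, Literal, Optional


def combine_losses(
    losses: dict[str, Callable[..., float]],
    weights: dict[str, float],
    stage: Optional[Literal["train", "val", "test"]] = None,
    anneal_weight: float = 1.0,
    **kwargs,
) -> tuple[float, list[dict]]:
    """Hypothetical helper: sum weighted losses and collect log entries."""
    total = 0.0
    log_list: list[dict] = []
    prefix = f"{stage}_" if stage is not None else ""
    for name, loss_fn in losses.items():
        # every loss receives the same keyword arguments
        value = loss_fn(**kwargs)
        # per-loss weight (default 1.0), then the global annealing factor
        weighted = anneal_weight * weights.get(name, 1.0) * value
        total += weighted
        log_list.append({"name": f"{prefix}{name}_loss", "value": weighted})
    return total, log_list


def mse(preds, targets):
    # mean squared error over paired values
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def mae(preds, targets):
    # mean absolute error over paired values
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


total, logs = combine_losses(
    {"l2": mse, "l1": mae},
    weights={"l2": 1.0, "l1": 0.5},
    stage="train",
    anneal_weight=0.5,
    preds=[1.0, 2.0],
    targets=[0.0, 2.0],
)
print(total)  # 0.375
print([entry["name"] for entry in logs])  # ['train_l2_loss', 'train_l1_loss']
```

Passing shared inputs through `**kwargs` lets each loss pick the arguments it needs while the caller uses a single entry point, mirroring the `**kwargs` in the signature above.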

__init__(losses_params_dict: dict[str, dict], data_module: BaseDataModule | UnlabeledDataModule) → None
__new__(**kwargs)