LossFactory
- class lightning_pose.losses.factory.LossFactory
Bases: LightningModule
Factory object that contains an object for each specified loss.
Methods Summary
__call__([stage, anneal_weight]) Call self as a function.
Methods Documentation
- __call__(stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: float | Tensor = 1.0, **kwargs) → tuple[Tensor, list[dict]]
Call self as a function.
- __init__(losses_params_dict: dict[str, dict], data_module: BaseDataModule | UnlabeledDataModule) → None
- __new__(**kwargs)
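Example

The signatures above suggest a common pattern: a factory holds one loss object per name, and calling it returns a combined loss together with a list of per-loss log entries. The following is a minimal, self-contained sketch of that pattern, not the lightning_pose implementation. `ToyLossFactory`, `mse`, `l1`, the `weights` argument, and the log-entry keys are all hypothetical names chosen for illustration. Plain floats stand in for tensors, and applying `anneal_weight` to every loss is an assumption of this sketch.

```python
from typing import Callable, Literal, Optional

Stage = Optional[Literal["train", "val", "test"]]


def mse(preds: list[float], targets: list[float], **_) -> float:
    """Mean squared error over paired values (hypothetical helper)."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def l1(preds: list[float], targets: list[float], **_) -> float:
    """Mean absolute error over paired values (hypothetical helper)."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


class ToyLossFactory:
    """Hypothetical stand-in that holds one callable per named loss."""

    def __init__(
        self,
        losses: dict[str, Callable[..., float]],
        weights: dict[str, float],
    ) -> None:
        self.losses = losses
        self.weights = weights

    def __call__(
        self,
        stage: Stage = None,
        anneal_weight: float = 1.0,
        **kwargs,
    ) -> tuple[float, list[dict]]:
        total = 0.0
        logs: list[dict] = []
        prefix = f"{stage}_" if stage else ""
        for name, loss_fn in self.losses.items():
            value = loss_fn(**kwargs)
            # Assumption: the annealing weight scales every loss in this sketch.
            total += anneal_weight * self.weights[name] * value
            logs.append({"name": f"{prefix}{name}_loss", "value": value})
        return total, logs


factory = ToyLossFactory(
    losses={"mse": mse, "l1": l1},
    weights={"mse": 1.0, "l1": 0.5},
)
total, logs = factory(
    stage="train",
    anneal_weight=0.5,
    preds=[1.0, 2.0],
    targets=[0.0, 0.0],
)
print(total)  # 1.625
print([entry["name"] for entry in logs])  # ['train_mse_loss', 'train_l1_loss']
```

Returning the log entries alongside the scalar keeps the combined loss usable for backpropagation while letting the caller decide how to record each component.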