LossFactory

class lightning_pose.losses.factory.LossFactory(losses_params_dict: Dict[str, dict], data_module: BaseDataModule | UnlabeledDataModule)

Bases: LightningModule

Factory object that contains an object for each specified loss.

Methods Summary

__call__([stage, anneal_weight])

Call self as a function.

Methods Documentation

__call__(stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: float | Tensor = 1.0, **kwargs) → Tuple[Tensor[Tensor], List[dict]]

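Call self as a function. Per the signature, the call takes an optional stage ('train', 'val', 'test', or None), an anneal_weight (a float or Tensor, default 1.0), and additional keyword arguments, and returns a tuple of a tensor and a list of dicts.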
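The following is a minimal, self-contained sketch of the general pattern this page describes: a factory that holds one object per configured loss and is called like a function to get a total plus a list of log dicts. It is not the lightning_pose implementation. The class ToyLossFactory, the "weight" parameter, the pred/target keyword arguments, the log dict keys, and the choice to scale every loss by anneal_weight and to skip logging when stage is None are all assumptions made for illustration, and plain floats stand in for tensors.

```python
from typing import Callable, Dict, List, Optional, Tuple


class ToyLossFactory:
    """Hypothetical stand-in for a loss factory (not the lightning_pose class)."""

    def __init__(self, losses_params_dict: Dict[str, dict]) -> None:
        # Build one loss callable per configured name.
        # The "weight" key is an assumed parameter for this sketch.
        self.loss_instance_dict: Dict[str, Callable[..., float]] = {}
        for name, params in losses_params_dict.items():
            self.loss_instance_dict[name] = self._make_loss(params.get("weight", 1.0))

    @staticmethod
    def _make_loss(weight: float) -> Callable[..., float]:
        # Toy loss: weighted mean squared error over two equal-length sequences.
        def loss(pred: List[float], target: List[float]) -> float:
            return weight * sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

        return loss

    def __call__(
        self,
        stage: Optional[str] = None,
        anneal_weight: float = 1.0,
        **kwargs,
    ) -> Tuple[float, List[dict]]:
        total = 0.0
        logs: List[dict] = []
        for name, loss_fn in self.loss_instance_dict.items():
            # Assumption: every loss is scaled by anneal_weight.
            value = anneal_weight * loss_fn(**kwargs)
            total += value
            # Assumption: nothing is logged when no stage is given.
            if stage is not None:
                logs.append({"name": f"{stage}_{name}_loss", "value": value})
        return total, logs


factory = ToyLossFactory({"mse": {"weight": 1.0}, "mse_half": {"weight": 0.5}})
total, logs = factory(stage="train", anneal_weight=0.5, pred=[1.0, 2.0], target=[0.0, 0.0])
```

Calling the toy factory with stage=None returns the same total with an empty log list, which mirrors the (total, list of dicts) shape of the return annotation above.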