get_model

lightning_pose.models.get_model(cfg: DictConfig | ListConfig, data_module: BaseDataModule | UnlabeledDataModule | None, loss_factories: dict[str, LossFactory] | dict[str, None]) → ALLOWED_MODELS

Build a pose estimation model from a Hydra config.

Resolves optimizer and lr-scheduler defaults, then dispatches on cfg.model.model_type and whether unsupervised losses are present to instantiate the appropriate model class. Optionally loads weights from cfg.model.checkpoint after construction (supports both .ckpt files and directories containing one).
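The checkpoint handling described above accepts either a .ckpt file or a directory containing one. A minimal sketch of that resolution step, using only the standard library, might look like the following. The helper name `resolve_checkpoint` is hypothetical, and the tie-breaking rule (recursive search, first match in sorted order) is an assumption; the library's actual lookup may differ.

```python
from pathlib import Path


def resolve_checkpoint(path: str) -> Path:
    """Return a .ckpt file for `path` (hypothetical helper, not the library's API)."""
    p = Path(path)
    # Case 1: the path already points at a checkpoint file.
    if p.is_file() and p.suffix == ".ckpt":
        return p
    # Case 2: the path is a directory that contains a checkpoint somewhere below it.
    if p.is_dir():
        # Assumption: search recursively and take the first match in sorted order.
        matches = sorted(p.rglob("*.ckpt"))
        if matches:
            return matches[0]
    raise FileNotFoundError(f"no .ckpt file found at {path}")
```

Raising when nothing is found is a design choice of this sketch; it keeps a mistyped path from silently producing a model with untrained weights.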

Parameters:
  • cfg – Hydra config. Relevant fields:

    • cfg.model.model_type: one of 'regression', 'heatmap', 'heatmap_mhcrnn', 'heatmap_multiview_transformer'.

    • cfg.model.backbone: backbone identifier (see ALLOWED_BACKBONES).

    • cfg.model.losses_to_use: list of unsupervised loss names; empty/None selects the fully supervised branch.

    • cfg.model.checkpoint: optional path to a .ckpt file or directory from which to load weights after construction.

    • cfg.data.image_resize_dims: ViT backbones require height == width.

  • data_module – data module used to infer num_targets for heatmap models; may be None when building a model without a dataset (e.g. inference only).

  • loss_factories – dict with keys 'supervised' and 'unsupervised', each mapping to a LossFactory instance (or None for stub construction in tests).
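To make the parameter layout concrete, the sketch below mirrors the relevant config fields with a plain dict and shows the stub form of loss_factories. The specific values (backbone name, image size) are illustrative placeholders, and `is_fully_supervised` is a hypothetical helper that only restates the "empty/None selects the fully supervised branch" rule; it is not part of the library.

```python
# Plain-dict stand-in for the Hydra config; all values are illustrative placeholders.
cfg = {
    "model": {
        "model_type": "heatmap",
        "backbone": "resnet50",  # placeholder; must be a member of ALLOWED_BACKBONES
        "losses_to_use": [],     # empty or None -> fully supervised branch
        "checkpoint": None,      # or a path to a .ckpt file / directory
    },
    "data": {"image_resize_dims": {"height": 256, "width": 256}},
}


def is_fully_supervised(cfg: dict) -> bool:
    """Hypothetical helper: True when no unsupervised losses are requested."""
    return not cfg["model"].get("losses_to_use")


# Stub loss factories, the form the docstring allows for test construction.
loss_factories = {"supervised": None, "unsupervised": None}
```

In real use the dict would be wrapped as a DictConfig (for example with OmegaConf.create) before being passed as the cfg argument, alongside a data module and the loss factories.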

Returns:

instantiated model ready for training or inference.

Raises:
  • RuntimeError – if a ViT backbone is selected with non-square image dimensions.

  • NotImplementedError – if cfg.model.model_type is not a recognised value.
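The two error conditions above can be checked before any model is built. The following is a sketch under stated assumptions, not the library's implementation: `check_config` is a hypothetical name, ViT backbones are assumed to be recognizable by a "vit" prefix in the identifier, and the order of the two checks is arbitrary.

```python
# Model types listed in the cfg.model.model_type description.
ALLOWED_MODEL_TYPES = {
    "regression",
    "heatmap",
    "heatmap_mhcrnn",
    "heatmap_multiview_transformer",
}


def check_config(model_type: str, backbone: str, height: int, width: int) -> None:
    """Hypothetical pre-flight check mirroring the documented exceptions."""
    # Assumption: a ViT backbone is identified by a "vit" prefix in its name.
    if backbone.startswith("vit") and height != width:
        raise RuntimeError(
            f"ViT backbones require square images, got {height}x{width}"
        )
    if model_type not in ALLOWED_MODEL_TYPES:
        raise NotImplementedError(f"unrecognised model_type: {model_type!r}")
```

Running such a check early reports a bad config before any weights are allocated or loaded.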