get_model
- lightning_pose.models.get_model(cfg: DictConfig | ListConfig, data_module: BaseDataModule | UnlabeledDataModule | None, loss_factories: dict[str, LossFactory] | dict[str, None]) → ALLOWED_MODELS
Build a pose estimation model from a Hydra config.
Resolves optimizer and lr-scheduler defaults, then dispatches on
cfg.model.model_type and on whether unsupervised losses are present to instantiate the appropriate model class. If cfg.model.checkpoint is set, loads weights from it after construction (both .ckpt files and directories containing one are supported).
- Parameters:
cfg – Hydra config. Relevant fields:
  - cfg.model.model_type: one of 'regression', 'heatmap', 'heatmap_mhcrnn', 'heatmap_multiview_transformer'.
  - cfg.model.backbone: backbone identifier (see ALLOWED_BACKBONES).
  - cfg.model.losses_to_use: list of unsupervised loss names; an empty list or None selects the fully supervised branch.
  - cfg.model.checkpoint: optional path to a .ckpt file or to a directory from which to load weights after construction.
  - cfg.data.image_resize_dims: ViT backbones require height == width.
data_module – data module used to infer num_targets for heatmap models; may be None when building a model without a dataset (e.g. inference only).
loss_factories – dict with keys 'supervised' and 'unsupervised', each mapping to a LossFactory instance (or None for stub construction in tests).
- Returns:
instantiated model ready for training or inference.
- Raises:
RuntimeError – if a ViT backbone is selected with non-square image dimensions.
NotImplementedError – if cfg.model.model_type is not a recognised value.
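The dispatch and validation rules documented above can be sketched without lightning_pose installed. The function below is a hypothetical, stdlib-only mirror, not the library's implementation: select_model_branch and its return labels are illustrative names that are not part of the lightning_pose API, and it returns a label instead of building a model. Three details are assumptions: a ViT backbone is recognised by "vit" appearing in its identifier, image_resize_dims is treated as a mapping with height and width keys, and the ViT check runs before the model_type check.

```python
# Hypothetical, stdlib-only mirror of the checks documented for
# lightning_pose.models.get_model. It returns a descriptive label instead of
# building a model; the labels are illustrative, not lightning_pose class names.

KNOWN_MODEL_TYPES = (
    "regression",
    "heatmap",
    "heatmap_mhcrnn",
    "heatmap_multiview_transformer",
)


def select_model_branch(model_type, backbone, losses_to_use, image_resize_dims):
    """Return (model_type, supervision) after applying the documented checks."""
    # ViT backbones need square inputs; get_model documents a RuntimeError otherwise.
    # Assumption: a ViT backbone is recognised by "vit" in its identifier.
    if "vit" in backbone.lower():
        height = image_resize_dims["height"]
        width = image_resize_dims["width"]
        if height != width:
            raise RuntimeError(
                f"ViT backbone {backbone!r} requires height == width, got {height}x{width}"
            )
    # Unrecognised model types are rejected, as documented under Raises.
    if model_type not in KNOWN_MODEL_TYPES:
        raise NotImplementedError(f"unrecognised model_type {model_type!r}")
    # An empty list or None for losses_to_use selects the fully supervised branch.
    supervision = "semi-supervised" if losses_to_use else "supervised"
    return model_type, supervision
```

For example, select_model_branch("heatmap", "resnet50", None, {"height": 256, "width": 192}) takes the supervised branch, while the same call with a ViT backbone raises RuntimeError because the dimensions are not square.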
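cfg.model.checkpoint may point either at a .ckpt file or at a directory containing one. This page does not say how a file is chosen inside a directory, so the helper below is a hypothetical sketch: resolve_checkpoint is an illustrative name, not a lightning_pose function, and searching the directory recursively with pathlib.Path.rglob and taking the first match in sorted order are assumptions.

```python
from pathlib import Path


# Hypothetical helper, not part of lightning_pose: turns the value of
# cfg.model.checkpoint into a single .ckpt file path.
def resolve_checkpoint(checkpoint):
    """Return the .ckpt file to load, given a .ckpt file or a directory containing one."""
    path = Path(checkpoint)
    if path.is_file():
        # A file path is used as-is.
        return path
    if path.is_dir():
        # Assumption: search the directory recursively and take the first
        # match in sorted order; the real selection rule may differ.
        matches = sorted(path.rglob("*.ckpt"))
        if not matches:
            raise FileNotFoundError(f"no .ckpt file found under {path}")
        return matches[0]
    raise FileNotFoundError(f"checkpoint path does not exist: {path}")
```

Sorting only makes the choice deterministic when a directory holds several checkpoints; it does not identify the best one.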