get_model_class
- lightning_pose.models.get_model_class(model_type: ALLOWED_MODEL_TYPES, semi_supervised: bool) → type[ALLOWED_MODELS]
Return the model class for the given model type and supervision mode.
- Parameters:
model_type – one of 'regression', 'heatmap', 'heatmap_mhcrnn', or 'heatmap_multiview_transformer'.
semi_supervised – True to return the semi-supervised variant.
- Returns:
model class (not an instance).
- Raises:
NotImplementedError – if model_type is not recognised.
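The behaviour documented above amounts to a lookup keyed on (model_type, semi_supervised) that returns a class rather than an instance. The following is a minimal sketch of that dispatch pattern, not the lightning_pose implementation. All class names and the function name get_model_class_sketch are hypothetical stand-ins. For illustration, the sketch also assumes that every model type has both a supervised and a semi-supervised variant, which this page does not confirm.

```python
"""Sketch of a (model_type, semi_supervised) -> class dispatch.

All classes here are hypothetical stand-ins, NOT lightning_pose classes.
"""


# Hypothetical placeholder classes, one per (type, supervision) pair.
class RegressionModel: ...
class SemiSupervisedRegressionModel: ...
class HeatmapModel: ...
class SemiSupervisedHeatmapModel: ...
class HeatmapMHCRNNModel: ...
class SemiSupervisedHeatmapMHCRNNModel: ...
class MultiviewTransformerModel: ...
class SemiSupervisedMultiviewTransformerModel: ...


# Lookup table keyed by (model_type, semi_supervised).
_MODEL_TABLE: dict[tuple[str, bool], type] = {
    ("regression", False): RegressionModel,
    ("regression", True): SemiSupervisedRegressionModel,
    ("heatmap", False): HeatmapModel,
    ("heatmap", True): SemiSupervisedHeatmapModel,
    ("heatmap_mhcrnn", False): HeatmapMHCRNNModel,
    ("heatmap_mhcrnn", True): SemiSupervisedHeatmapMHCRNNModel,
    ("heatmap_multiview_transformer", False): MultiviewTransformerModel,
    ("heatmap_multiview_transformer", True): SemiSupervisedMultiviewTransformerModel,
}


def get_model_class_sketch(model_type: str, semi_supervised: bool) -> type:
    """Return the class (not an instance) for the given type and mode."""
    try:
        return _MODEL_TABLE[(model_type, bool(semi_supervised))]
    except KeyError:
        # Mirror the documented behaviour for unrecognised model types.
        raise NotImplementedError(
            f"{model_type!r} is not a recognised model type"
        ) from None


if __name__ == "__main__":
    cls = get_model_class_sketch("heatmap", semi_supervised=True)
    print(cls.__name__)
```

The call shape for the real function is the same, for example get_model_class("heatmap", semi_supervised=True). The caller then instantiates the returned class with that model's constructor arguments, which this page does not cover. Returning the class instead of an instance keeps the selection logic separate from construction, so callers can pass model-specific arguments themselves.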