get_model_class

lightning_pose.models.get_model_class(model_type: ALLOWED_MODEL_TYPES, semi_supervised: bool) → type[ALLOWED_MODELS]

Return the model class for the given model type and supervision mode.

Parameters:
  • model_type – one of 'regression', 'heatmap', 'heatmap_mhcrnn', 'heatmap_multiview_transformer'.

  • semi_supervised – if True, return the semi-supervised variant of the model class; if False, return the fully supervised variant.

Returns:

The model class itself (not an instance); the caller is responsible for instantiating it.

Raises:

NotImplementedError – if model_type is not recognised.
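
Example:

The contract above (a type string plus a flag selects a class, and unknown types raise NotImplementedError) can be illustrated with a minimal, self-contained sketch. It does not import lightning_pose and is not the library's implementation: the placeholder classes, the `_REGISTRY` table, and the name `get_model_class_sketch` are hypothetical and exist only for illustration, and only two of the four model types are covered.

```python
# Hypothetical stand-ins for model classes; the real classes live in
# lightning_pose and are NOT reproduced here.
class PlaceholderRegression: ...
class PlaceholderSemiSupervisedRegression: ...
class PlaceholderHeatmap: ...
class PlaceholderSemiSupervisedHeatmap: ...


# Assumed lookup table keyed by (model_type, semi_supervised).
_REGISTRY: dict[tuple[str, bool], type] = {
    ("regression", False): PlaceholderRegression,
    ("regression", True): PlaceholderSemiSupervisedRegression,
    ("heatmap", False): PlaceholderHeatmap,
    ("heatmap", True): PlaceholderSemiSupervisedHeatmap,
}


def get_model_class_sketch(model_type: str, semi_supervised: bool) -> type:
    """Return a class (not an instance) for the given type and mode."""
    try:
        return _REGISTRY[(model_type, semi_supervised)]
    except KeyError:
        # Mirrors the documented behaviour for unrecognised model types.
        raise NotImplementedError(
            f"model_type={model_type!r} is not supported"
        ) from None


# The return value is a class; instantiate it separately.
model_cls = get_model_class_sketch("heatmap", semi_supervised=True)
model = model_cls()
print(model_cls.__name__)
```

With the real function, the call would presumably have the same shape, e.g. `get_model_class("heatmap", semi_supervised=True)`, with the returned class then constructed using whatever arguments that model requires.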