lightning_pose.models Package

Pose estimation model classes, re-exported at the package level.

Four model types, each available in a supervised and a semi-supervised variant (8 concrete classes total):

  • regression: direct (x, y) coordinate regression. RegressionTracker / SemiSupervisedRegressionTracker → regression_tracker.py

  • heatmap: per-keypoint 2-D Gaussian heatmaps. HeatmapTracker / SemiSupervisedHeatmapTracker → heatmap_tracker.py

  • heatmap_mhcrnn: heatmaps with temporal context via a recurrent head (MHCRNN). HeatmapTrackerMHCRNN / SemiSupervisedHeatmapTrackerMHCRNN → heatmap_tracker_mhcrnn.py

  • heatmap_multiview_transformer: multi-camera heatmaps with cross-view attention. HeatmapTrackerMultiviewTransformer / SemiSupervisedHeatmapTrackerMultiviewTransformer → heatmap_tracker_multiview.py

Supervised / semi-supervised split: every supervised class has a semi-supervised counterpart produced by mixing in SemiSupervisedTrackerMixin. The mixin adds a second loss-factory argument, loss_factory_unsupervised, and extends training_step to also compute unsupervised losses on unlabeled video frames.
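The mixin pattern described above can be sketched in plain Python. Everything here is a hypothetical stand-in (class names, dict batches, losses as callables returning floats); the real lightning_pose trackers carry much more state and their exact signatures may differ. The sketch only shows how a mixin can add one extra argument and wrap training_step through super().

```python
"""Minimal sketch of a supervised / semi-supervised mixin split.

Hypothetical stand-ins only: not the lightning_pose classes or signatures.
"""
from typing import Callable, Dict

LossFn = Callable[[Dict[str, float]], float]


class SupervisedTrackerSketch:
    """Stand-in for a supervised tracker holding one supervised loss factory."""

    def __init__(self, loss_factory: LossFn) -> None:
        self.loss_factory = loss_factory

    def training_step(self, batch: Dict[str, Dict[str, float]]) -> float:
        # Supervised loss computed on the labeled part of the batch only.
        return self.loss_factory(batch["labeled"])


class SemiSupervisedMixinSketch:
    """Adds a second loss factory and extends training_step."""

    def __init__(self, *args, loss_factory_unsupervised: LossFn, **kwargs) -> None:
        # Pass the remaining arguments on to the supervised base class.
        super().__init__(*args, **kwargs)
        self.loss_factory_unsupervised = loss_factory_unsupervised

    def training_step(self, batch: Dict[str, Dict[str, float]]) -> float:
        supervised = super().training_step(batch)
        # Extra term computed on the unlabeled (video) part of the batch.
        unsupervised = self.loss_factory_unsupervised(batch["unlabeled"])
        return supervised + unsupervised


class SemiSupervisedTrackerSketch(SemiSupervisedMixinSketch, SupervisedTrackerSketch):
    """Mixin listed first so its methods wrap the supervised ones via the MRO."""
```

Listing the mixin first in the bases is what lets its training_step reach the supervised version through super(); with the order reversed, the supervised methods would win and the mixin would be bypassed.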

Other files in this package:

  • base.py: abstract bases and shared logic: BaseFeatureExtractor, BaseSupervisedTracker, SemiSupervisedTrackerMixin.

  • factory.py: get_model() (full construction from config) and get_model_class() (pure (model_type, semi_supervised) → class dispatch); the ALLOWED_MODEL_TYPES Literal is defined here.

  • backbones/: backbone wrappers and build_backbone(); see backbones/__init__.py for the type hierarchy and how to add a new backbone.

  • heads/: output head classes (HeatmapHead, HeatmapMHCRNNHead, LinearRegressionHead).

Functions

check_if_semi_supervised([losses_to_use])

Determine from the losses config whether the model is semi-supervised.
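A minimal sketch of the kind of check this function performs, under the assumption (not confirmed by this page) that a missing or empty losses_to_use list means purely supervised training and any listed loss means semi-supervised. The function name is hypothetical, and the loss names in the usage line are illustrative placeholders.

```python
from typing import Optional, Sequence


def check_if_semi_supervised_sketch(losses_to_use: Optional[Sequence[str]] = None) -> bool:
    """Return True if any unsupervised loss is requested.

    Assumption: None or an empty sequence means supervised-only training.
    """
    if losses_to_use is None:
        return False
    # Any entry in the list switches the model to its semi-supervised variant.
    return len(losses_to_use) > 0
```

For example, check_if_semi_supervised_sketch(["temporal"]) returns True, which would then select the semi-supervised class for the configured model type.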

get_model(cfg, data_module, loss_factories)

Build a pose estimation model from a Hydra config.

get_model_class(model_type, semi_supervised)

Return the model class for the given model type and supervision mode.

Classes

HeatmapTracker

Base model that produces heatmaps of keypoints from images.

SemiSupervisedHeatmapTracker

Model that produces heatmaps of keypoints from labeled and unlabeled images.

HeatmapTrackerMHCRNN

Multi-headed Convolutional RNN network that handles context frames.

SemiSupervisedHeatmapTrackerMHCRNN

Model that produces heatmaps of keypoints from labeled and unlabeled images.

HeatmapTrackerMultiviewTransformer

Transformer network that handles multi-view datasets.

SemiSupervisedHeatmapTrackerMultiviewTransformer

Semi-supervised HeatmapTrackerMultiviewTransformer that supports unsupervised losses.

RegressionTracker

Base model that produces (x, y) predictions of keypoints from images.

SemiSupervisedRegressionTracker

Model that produces vectors of keypoints from labeled and unlabeled images.

lightning_pose.models.base Module

Base class for a backbone that acts as a feature extractor.

lightning_pose.models.factory Module

Factory functions for building pose estimation models from a Hydra config.

Public entry points:

  • get_model_class(): pure dispatch; returns the model class for a given (model_type, semi_supervised) pair without instantiating anything.

  • get_model(): full construction; resolves optimizer/scheduler defaults, instantiates the appropriate model class, and optionally loads weights from a checkpoint.

All model class imports are deferred inside the function bodies to avoid circular imports (this module is loaded early in the call stack, before the model classes are fully defined).
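The dispatch-plus-deferred-import idea can be sketched as follows. This is a swapped-in variant, not the library's code: this page says get_model_class() uses explicit branches, whereas the sketch uses a lookup table and importlib.import_module. The module paths and class names are taken from the package overview above; the helper names resolve_model_path and load_model_class are hypothetical. Resolving the path needs no imports, and the model module is only imported when load_model_class is called.

```python
import importlib
from typing import Dict, Tuple

# Hypothetical registry mirroring the file/class layout listed in the package docs.
_MODEL_REGISTRY: Dict[str, Tuple[str, str]] = {
    "regression": ("regression_tracker", "RegressionTracker"),
    "heatmap": ("heatmap_tracker", "HeatmapTracker"),
    "heatmap_mhcrnn": ("heatmap_tracker_mhcrnn", "HeatmapTrackerMHCRNN"),
    "heatmap_multiview_transformer": (
        "heatmap_tracker_multiview",
        "HeatmapTrackerMultiviewTransformer",
    ),
}


def resolve_model_path(model_type: str, semi_supervised: bool) -> Tuple[str, str]:
    """Map (model_type, semi_supervised) to (module path, class name) without importing."""
    try:
        module_name, class_name = _MODEL_REGISTRY[model_type]
    except KeyError:
        raise ValueError(f"unsupported model_type: {model_type!r}") from None
    if semi_supervised:
        # Semi-supervised variants share the supervised name with a prefix.
        class_name = "SemiSupervised" + class_name
    return f"lightning_pose.models.{module_name}", class_name


def load_model_class(model_type: str, semi_supervised: bool) -> type:
    """Import the model module only when the class is actually requested."""
    module_path, class_name = resolve_model_path(model_type, semi_supervised)
    module = importlib.import_module(module_path)  # deferred import
    return getattr(module, class_name)
```

Because nothing under lightning_pose.models is imported at module load time, a factory module written this way can be imported early without triggering the circular-import problem described above.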

Supported model types: regression, heatmap, heatmap_mhcrnn, heatmap_multiview_transformer.

Adding a new model type: add its string to ALLOWED_MODEL_TYPES, add a branch in get_model_class() (two lines, one per supervision mode), add an elif block in get_model() for its constructor kwargs, and create the model file(s) under lightning_pose/models/.
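The first step of that checklist, extending ALLOWED_MODEL_TYPES, can be sketched under the assumption that it is a typing.Literal as stated above. The entry "my_new_type" is a hypothetical placeholder, and the runtime check via typing.get_args is an illustrative addition; this page does not say get_model() validates model types this way. The remaining steps (branches in get_model_class() and get_model(), plus the model files) are not shown.

```python
from typing import Literal, get_args

# "my_new_type" is a hypothetical placeholder, not a real lightning_pose model type.
ALLOWED_MODEL_TYPES = Literal[
    "regression",
    "heatmap",
    "heatmap_mhcrnn",
    "heatmap_multiview_transformer",
    "my_new_type",
]


def validate_model_type(model_type: str) -> str:
    """Check a config value against the Literal at runtime.

    Static type checkers already use the Literal at analysis time; get_args
    exposes the same allowed strings as a tuple for runtime validation.
    """
    allowed = get_args(ALLOWED_MODEL_TYPES)
    if model_type not in allowed:
        raise ValueError(f"model_type must be one of {allowed}, got {model_type!r}")
    return model_type
```

Keeping the Literal as the single list of allowed strings means a type checker and the runtime check cannot drift apart.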

Subpackages