lightning_pose.models
lightning_pose.models Package
Pose estimation model classes, re-exported at the package level.
Four model types, each available in a supervised and a semi-supervised variant (8 concrete classes total):
- regression: direct (x, y) coordinate regression. RegressionTracker / SemiSupervisedRegressionTracker, in regression_tracker.py.
- heatmap: per-keypoint 2-D Gaussian heatmaps. HeatmapTracker / SemiSupervisedHeatmapTracker, in heatmap_tracker.py.
- heatmap_mhcrnn: heatmaps with temporal context via a recurrent head (MHCRNN). HeatmapTrackerMHCRNN / SemiSupervisedHeatmapTrackerMHCRNN, in heatmap_tracker_mhcrnn.py.
- heatmap_multiview_transformer: multi-camera heatmaps with cross-view attention. HeatmapTrackerMultiviewTransformer / SemiSupervisedHeatmapTrackerMultiviewTransformer, in heatmap_tracker_multiview.py.
Supervised / semi-supervised split: every supervised class has a semi-supervised
counterpart built by mixing in SemiSupervisedTrackerMixin. The mixin adds a second
loss-factory argument, loss_factory_unsupervised, and extends training_step so that it
also computes unsupervised losses on unlabeled video frames.
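The mixin pattern described above can be sketched in plain Python. This is a minimal, hypothetical illustration and not the package's code. The class names ending in Sketch, the dict-based batch with "labeled" and "unlabeled" keys, and the float-valued losses are all assumptions. The real trackers are PyTorch Lightning modules whose losses are tensors.

```python
from typing import Any, Callable, Dict, List

# Hypothetical loss "factory": a callable mapping a list of frames to a scalar loss.
LossFn = Callable[[List[float]], float]


class SupervisedTrackerSketch:
    """Hypothetical stand-in for a supervised tracker."""

    def __init__(self, loss_factory: LossFn) -> None:
        self.loss_factory = loss_factory

    def training_step(self, batch: Dict[str, Any]) -> float:
        # Supervised loss computed on labeled frames only.
        return self.loss_factory(batch["labeled"])


class SemiSupervisedMixinSketch:
    """Hypothetical mixin: adds a second loss factory and extends training_step."""

    def __init__(self, *args: Any, loss_factory_unsupervised: LossFn, **kwargs: Any) -> None:
        # Let the supervised class initialize itself first.
        super().__init__(*args, **kwargs)
        self.loss_factory_unsupervised = loss_factory_unsupervised

    def training_step(self, batch: Dict[str, Any]) -> float:
        # Reuse the supervised loss, then add the loss on unlabeled video frames.
        supervised = super().training_step(batch)
        unsupervised = self.loss_factory_unsupervised(batch["unlabeled"])
        return supervised + unsupervised


# The mixin comes first so its methods precede the supervised ones in the MRO.
class SemiSupervisedTrackerSketch(SemiSupervisedMixinSketch, SupervisedTrackerSketch):
    pass


model = SemiSupervisedTrackerSketch(
    loss_factory=lambda frames: float(sum(frames)),
    loss_factory_unsupervised=lambda frames: 10.0 * len(frames),
)
print(model.training_step({"labeled": [1.0, 2.0], "unlabeled": [0.0, 0.0, 0.0]}))
# → 33.0
```

Listing the mixin first lets its training_step wrap the supervised one through super(). The supervised code path therefore lives in a single class.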
Other files in this package:
- base.py: abstract bases and shared logic (BaseFeatureExtractor, BaseSupervisedTracker, SemiSupervisedTrackerMixin).
- factory.py: get_model() (full construction from config) and get_model_class() (pure (model_type, semi_supervised) → class dispatch); the ALLOWED_MODEL_TYPES Literal is defined here.
- backbones/: backbone wrappers and build_backbone(); see backbones/__init__.py for the type hierarchy and how to add a new backbone.
- heads/: output head classes (HeatmapHead, HeatmapMHCRNNHead, LinearRegressionHead).
Functions
- Determine from the losses config whether the model is semi-supervised.
- get_model(): Build a pose estimation model from a Hydra config.
- get_model_class(): Return the model class for the given model type and supervision mode.
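The semi-supervised check listed above can be sketched as a small predicate. This is an illustration that rests on several assumptions. It supposes the losses config exposes a list of unsupervised loss names, called losses_to_use here. It also supposes that a missing, empty, or blank list means supervised-only training. The function name, the argument name, and the loss name "temporal" are all hypothetical.

```python
from typing import Optional, Sequence


def is_semi_supervised_sketch(losses_to_use: Optional[Sequence[str]]) -> bool:
    """Hypothetical check: True if at least one unsupervised loss is named."""
    if losses_to_use is None:
        return False
    # Blank strings count as "no loss", e.g. a config that lists only "".
    return any(name.strip() for name in losses_to_use)


print(is_semi_supervised_sketch(None))          # → False
print(is_semi_supervised_sketch([""]))          # → False
print(is_semi_supervised_sketch(["temporal"]))  # → True
```

The boolean result is the second half of the (model_type, semi_supervised) pair that get_model_class() dispatches on.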
Classes
- HeatmapTracker: Base model that produces heatmaps of keypoints from images.
- SemiSupervisedHeatmapTracker: Model that produces heatmaps of keypoints from labeled and unlabeled images.
- HeatmapTrackerMHCRNN: Multi-headed Convolutional RNN network that handles context frames.
- SemiSupervisedHeatmapTrackerMHCRNN: Model that produces heatmaps of keypoints from labeled and unlabeled images.
- HeatmapTrackerMultiviewTransformer: Transformer network that handles multi-view datasets.
- SemiSupervisedHeatmapTrackerMultiviewTransformer: Semi-supervised HeatmapTrackerMultiviewTransformer that supports unsupervised losses.
- RegressionTracker: Base model that produces (x, y) predictions of keypoints from images.
- SemiSupervisedRegressionTracker: Model that produces vectors of keypoints from labeled and unlabeled images.
lightning_pose.models.base Module
Base class for a backbone that acts as a feature extractor.
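The role of a feature-extractor base class can be sketched with Python's abc module. This is a hypothetical illustration. The method names get_representations() and forward() are assumptions for this sketch, as are the toy "images" (lists of pixel values) and the MeanPoolExtractor subclass. None of them are the actual signatures of BaseFeatureExtractor.

```python
from abc import ABC, abstractmethod
from typing import List

Image = List[float]  # toy stand-in for an image tensor


class FeatureExtractorSketch(ABC):
    """Hypothetical base: subclasses supply the backbone, the base supplies shared logic."""

    @abstractmethod
    def get_representations(self, images: List[Image]) -> List[List[float]]:
        """Map each image to a feature vector (backbone-specific)."""

    def forward(self, images: List[Image]) -> List[List[float]]:
        # Shared entry point: validate input, then delegate to the backbone.
        if not images:
            raise ValueError("expected at least one image")
        return self.get_representations(images)


class MeanPoolExtractor(FeatureExtractorSketch):
    """Toy backbone: one feature per image, the mean pixel value."""

    def get_representations(self, images: List[Image]) -> List[List[float]]:
        return [[sum(img) / len(img)] for img in images]


print(MeanPoolExtractor().forward([[0.0, 1.0], [2.0, 4.0]]))
# → [[0.5], [3.0]]
```

Because get_representations is abstract, calling FeatureExtractorSketch() raises TypeError. Only concrete subclasses can be instantiated.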
lightning_pose.models.factory Module
Factory functions for building pose estimation models from a Hydra config.
Public entry points:
- get_model_class(): pure dispatch; returns the model class for a given (model_type, semi_supervised) pair without instantiating anything.
- get_model(): full construction; resolves optimizer/scheduler defaults, instantiates the appropriate model class, and optionally loads weights from a checkpoint.
All model class imports are deferred inside the function bodies to avoid circular imports (this module is loaded early in the call stack, before the model classes are fully defined).
Supported model types: regression, heatmap, heatmap_mhcrnn,
heatmap_multiview_transformer.
Adding a new model type: add its string to ALLOWED_MODEL_TYPES, add a
branch in get_model_class() (two lines, one per supervision mode), add an
elif block in get_model() for its constructor kwargs, and create the model
file(s) under lightning_pose/models/.
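When you add a model type, a quick check can confirm that every entry in ALLOWED_MODEL_TYPES resolves in both supervision modes. The sketch below is hypothetical. find_missing_branches(), the placeholder type my_new_model, and the stand-in dispatcher get_class() are all illustrations. The dispatcher is assumed to raise an exception for any combination it does not handle.

```python
from typing import Callable, List, Literal, Tuple, get_args


def find_missing_branches(
    allowed_types: object, get_class: Callable[[str, bool], type]
) -> List[Tuple[str, bool]]:
    """Return every (model_type, semi_supervised) pair the dispatcher cannot resolve."""
    missing = []
    for model_type in get_args(allowed_types):
        for semi_supervised in (False, True):
            try:
                get_class(model_type, semi_supervised)
            except (KeyError, NotImplementedError, ValueError):
                missing.append((model_type, semi_supervised))
    return missing


# Hypothetical new type added to the Literal ...
ALLOWED = Literal["heatmap", "my_new_model"]


class HeatmapTracker: ...
class SemiSupervisedHeatmapTracker: ...
class MyNewModel: ...  # hypothetical supervised class


def get_class(model_type: str, semi_supervised: bool) -> type:
    # ... but only the supervised branch for my_new_model was added.
    if model_type == "heatmap":
        return SemiSupervisedHeatmapTracker if semi_supervised else HeatmapTracker
    if model_type == "my_new_model" and not semi_supervised:
        return MyNewModel
    raise NotImplementedError(f"no class for ({model_type!r}, {semi_supervised})")


print(find_missing_branches(ALLOWED, get_class))
# → [('my_new_model', True)]
```

Here the supervised branch for my_new_model exists but its semi-supervised counterpart does not, and the check reports exactly that pair.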