Add a model

To detail the steps necessary to fully incorporate a new model into the Lightning Pose package, we will first describe some of the underlying infrastructure, and then use the standard heatmap model as an example.

The BaseFeatureExtractor class

The BaseFeatureExtractor class contains a backbone neural network that, as its name suggests, extracts features from images that can be used for downstream keypoint prediction. There is one key method in this class, get_representations, which takes in a batch of images and outputs a batch of representations. Notably, this method can handle a variety of batch types:

  • individual images for supervised training of a baseline model

  • individual images and their context frames for supervised training of a context model

  • sequence of frames for video inference with a baseline model

  • sequence of frames for video inference with a context model

See BaseFeatureExtractor.get_representations for more documentation on the input/output behavior of this function.
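
To make the input/output contract concrete, here is a minimal sketch of calling the backbone on a batch of images. The import path, constructor arguments, and tensor shapes are illustrative assumptions rather than the exact API.

import torch
from lightning_pose.models.base import BaseFeatureExtractor  # assumed import path

# a batch of 8 RGB frames at 256x256 resolution (shapes are illustrative)
images = torch.randn(8, 3, 256, 256)

backbone = BaseFeatureExtractor(backbone="resnet50")  # hypothetical constructor arguments
features = backbone.get_representations(images)
# features is a batch of feature maps (e.g. shape (8, C, H', W')) that a
# downstream head can turn into heatmaps or (x, y) coordinates
print(features.shape)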

Users can currently choose from different backbone architectures (EfficientNets, ResNets, Transformers) and different initializations for the ResNet50 (classification with ImageNet, pose estimation with AP10k, etc.). If you would like to add a new backbone architecture (say, a new transformer) or a new initialization (say, AP36k) to BaseFeatureExtractor, these modifications will be immediately accessible to all downstream models (semi-supervised pose estimators, context models, etc.).
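
As a loose illustration (not the actual Lightning Pose code), adding a new architecture or initialization amounts to extending the name-to-network mapping inside BaseFeatureExtractor, roughly along these lines; the helper function and weight names below are hypothetical.

import torchvision.models as tvm

def build_backbone(name: str):  # hypothetical helper, not the real Lightning Pose function
    if name == "resnet50":
        # existing behavior: ImageNet-classification initialization
        return tvm.resnet50(weights=tvm.ResNet50_Weights.IMAGENET1K_V2)
    elif name == "my_new_transformer":
        # new architecture: construct and return your backbone here,
        # optionally loading your own pretrained weights
        raise NotImplementedError
    else:
        raise ValueError(f"unknown backbone: {name}")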

The BaseSupervisedTracker class

The BaseSupervisedTracker class inherits from BaseFeatureExtractor, and contains a mix of already-implemented methods that assist with model training and several abstract methods that must be implemented by child classes (i.e. your new model). For example, the training_step method (used by Lightning in its Trainer class) simply calls the evaluate_labeled method, which performs three steps:

  1. process the input batch and return pose features/heatmaps/coordinates (needs to be implemented by downstream models)

  2. send these model outputs to the loss factory to compute and log losses (already implemented)

  3. compute and log root mean square error for monitoring during training (already implemented)

When implementing a new model, you must implement a method called get_loss_inputs_labeled that will perform step 1. Note that the BaseSupervisedTracker does not implement a “head” that transforms features into pose predictions; that is what individual models will implement.
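
Schematically, the already-implemented pieces and the abstract method fit together roughly as follows; the method signatures and the loss_factory attribute name are simplified assumptions.

from lightning_pose.models.base import BaseFeatureExtractor  # assumed import path

# simplified, hypothetical view of the supervised training flow
class BaseSupervisedTracker(BaseFeatureExtractor):

    def get_loss_inputs_labeled(self, batch):
        # step 1: abstract -- implemented by your model (e.g. HeatmapTracker)
        raise NotImplementedError

    def evaluate_labeled(self, batch, stage):
        data = self.get_loss_inputs_labeled(batch)     # step 1
        loss = self.loss_factory(stage=stage, **data)  # step 2: compute and log losses
        # step 3: compute and log root mean square error (already implemented)
        return loss

    def training_step(self, batch, batch_idx):
        return {"loss": self.evaluate_labeled(batch, stage="train")}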

The SemiSupervisedTrackerMixin

Note

You can skip this section on a first pass.

If you wish to implement a model that is compatible with any of the unsupervised losses, you will need to use the SemiSupervisedTrackerMixin. This “mixin” is not a complete class on its own but rather implements a training_step that takes unlabeled data into account. This training step consists of two parts:

  1. evaluate the network on a labeled batch

  2. evaluate the network on an unlabeled batch

For step 1, we fall back to the evaluate_labeled method described above. For step 2, there is a new method in this mixin called evaluate_unlabeled that goes through the same process of pushing data through the model and computing and logging the losses.
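
A simplified, hypothetical sketch of how the mixin’s training_step combines the two evaluations; the batch structure and attribute names are assumptions.

class SemiSupervisedTrackerMixin:  # simplified, hypothetical sketch

    def evaluate_unlabeled(self, batch, stage):
        # push unlabeled frames through the model, then compute and log unsupervised losses
        data = self.get_loss_inputs_unlabeled(batch)
        return self.loss_factory_unsupervised(stage=stage, **data)

    def training_step(self, batch, batch_idx):
        # the semi-supervised dataloader is assumed to yield labeled and unlabeled parts
        loss_labeled = self.evaluate_labeled(batch["labeled"], stage="train")        # part 1
        loss_unlabeled = self.evaluate_unlabeled(batch["unlabeled"], stage="train")  # part 2
        return {"loss": loss_labeled + loss_unlabeled}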

Implementing a new model

With this background in place, let’s look at how to implement fully supervised and semi-supervised heatmap trackers. The easiest way to parse this information is to look at the HeatmapTracker class and follow along with the text below to see how these ideas are implemented in practice.

A second example can be found in the HeatmapTrackerMHCRNN class, which implements the fully- and semi-supervised versions of the temporal context network.

Finally, a third example can be found in the RegressionTracker class, which implements fully- and semi-supervised trackers that omit heatmaps and directly predict (x, y) coordinates.

Fully supervised model

The HeatmapTracker class inherits from BaseSupervisedTracker, giving it access to the base feature extractor and the training step (as well as the validation and test step methods). Our job is to implement a “head” network that takes the features as input and outputs pose predictions (in this case through the use of heatmaps). The __init__ method should call the __init__ method of the parent class (BaseSupervisedTracker), which will construct the backbone feature extractor. You will need to implement one or more methods that construct the head, and initialize the head in this __init__ method. For example, HeatmapTracker contains methods that construct upsampling layers which transform the features into 2D heatmaps, one per keypoint. It also contains a method that takes the soft argmax of each heatmap to produce an (x, y) coordinate for each keypoint.
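
The following hypothetical sketch shows the shape such a constructor and head might take; the layer choices, channel counts, and argument names are illustrative assumptions, not the actual HeatmapTracker implementation.

from torch import nn
from lightning_pose.models.base import BaseSupervisedTracker  # assumed import path

class MyHeatmapModel(BaseSupervisedTracker):

    def __init__(self, num_keypoints, **kwargs):
        super().__init__(**kwargs)  # constructs the backbone feature extractor
        self.num_keypoints = num_keypoints
        # head: upsample backbone features into one 2D heatmap per keypoint
        # (2048 input channels assumes a ResNet-50 style backbone)
        self.head = nn.Sequential(
            nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, num_keypoints, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, images):
        features = self.get_representations(images)
        return self.head(features)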

Note

If you want to use unsupervised losses in your model, the transformation from features to (x, y) coordinates MUST be differentiable!
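
To make the differentiability requirement concrete, here is one standard way to implement a differentiable (soft) argmax over a heatmap; this is a generic sketch, not Lightning Pose’s exact implementation.

import torch

def soft_argmax(heatmaps):
    """Differentiable 2D argmax: heatmaps of shape (batch, keypoints, H, W)
    are converted to expected (x, y) coordinates of shape (batch, keypoints, 2)."""
    batch, num_kps, h, w = heatmaps.shape
    # softmax over the spatial dimensions turns each heatmap into a probability map
    probs = torch.softmax(heatmaps.reshape(batch, num_kps, -1), dim=-1)
    probs = probs.reshape(batch, num_kps, h, w)
    # expected coordinates under that distribution (differentiable, unlike a hard argmax)
    ys = torch.arange(h, dtype=heatmaps.dtype, device=heatmaps.device)
    xs = torch.arange(w, dtype=heatmaps.dtype, device=heatmaps.device)
    exp_y = (probs.sum(dim=3) * ys).sum(dim=2)
    exp_x = (probs.sum(dim=2) * xs).sum(dim=2)
    return torch.stack([exp_x, exp_y], dim=-1)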

Another method that you must implement is get_loss_inputs_labeled (which will be called by the training_step method of the parent class). For the HeatmapTracker class this method comprises two parts:

  1. process batch through both feature extractor and head to get heatmaps

  2. process heatmaps to get (x, y) coordinates

Importantly, this method must return a dict with a set of standard key names, which will be used by downstream losses. Your model does not need to return all the keys listed below, but must return the keys used by the losses you choose; a sketch of such a return dict follows the list.

  • heatmaps_targ: target (ground truth) heatmaps for each frame/keypoint in the batch

  • heatmaps_pred: predicted heatmaps

  • keypoints_targ: target (ground truth) coordinates

  • keypoints_pred: predicted coordinates

  • confidences: uncertainty estimate associated with each keypoint
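
Putting the two steps and the return keys together, a hypothetical get_loss_inputs_labeled for a heatmap-style model (using the soft_argmax sketch above) might look like this; the batch key names are assumptions.

def get_loss_inputs_labeled(self, batch):
    # step 1: push images through the backbone + head to get predicted heatmaps
    predicted_heatmaps = self.forward(batch["images"])
    # step 2: differentiable soft argmax to get predicted (x, y) coordinates
    predicted_keypoints = soft_argmax(predicted_heatmaps)
    return {
        "heatmaps_targ": batch["heatmaps"],    # ground truth heatmaps from the dataset
        "heatmaps_pred": predicted_heatmaps,
        "keypoints_targ": batch["keypoints"],  # ground truth coordinates
        "keypoints_pred": predicted_keypoints,
        # "confidences": ...,  # e.g. the max value of each predicted heatmap, if needed
    }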

The final method that you must implement is the predict_step; this tells the model how to transform a batch of frames into (x, y) coordinates (and optionally confidences). This method may look exactly like get_loss_inputs_labeled with a slightly different return format; alternatively, this method can implement non-differentiable operations to choose the final coordinates, such as a hard argmax instead of the soft argmax required for training.
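
For contrast with the training-time soft argmax, here is a hypothetical predict_step using a hard argmax; the batch format and return convention are assumptions.

import torch

def predict_step(self, batch, batch_idx):
    heatmaps = self.forward(batch["images"])  # same forward pass as during training
    batch_size, num_kps, h, w = heatmaps.shape
    flat = heatmaps.reshape(batch_size, num_kps, -1)
    confidences, idx = flat.max(dim=-1)  # hard argmax is fine here: no gradients needed
    x = (idx % w).float()
    y = torch.div(idx, w, rounding_mode="floor").float()
    keypoints = torch.stack([x, y], dim=-1)  # (batch, keypoints, 2) in (x, y) order
    return keypoints, confidences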

Semi-supervised model

The SemiSupervisedHeatmapTracker class inherits from both

  • HeatmapTracker, which gives it access to the feature extractor, heatmap prediction head, and labeled data loss computation

  • SemiSupervisedTrackerMixin, which gives it access to unlabeled data loss computation.

This class (as well as any other class you want to build that uses unsupervised losses) must implement two methods.

The first is the __init__ method, which should call the __init__ method of the parent class(es) and also build the unsupervised losses themselves. This is fully taken care of by the LossFactory object, which must be an input to the semi-supervised model’s constructor (more info on this below).
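
A hypothetical constructor sketch; the argument name loss_factory_unsupervised and the attribute it is stored under are assumptions, and MyHeatmapModel refers to the hypothetical class sketched earlier.

from lightning_pose.models.base import SemiSupervisedTrackerMixin  # assumed import path

class MySemiSupervisedModel(SemiSupervisedTrackerMixin, MyHeatmapModel):

    def __init__(self, loss_factory_unsupervised, **kwargs):
        super().__init__(**kwargs)  # builds the backbone, head, and labeled-loss machinery
        # the unsupervised losses are fully encapsulated by the LossFactory object
        self.loss_factory_unsupervised = loss_factory_unsupervised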

The second method is called get_loss_inputs_unlabeled and, like its sibling function get_loss_inputs_labeled in the fully supervised model, is responsible for two steps:

  1. process batch through both feature extractor and head to get heatmaps

  2. process heatmaps to get (x, y) coordinates

As above, this method will return a dict with a set of standard key names, which will not include target heatmaps or keypoints since those are formed from labeled data and we are dealing exclusively with unlabeled data in this method.
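
A hypothetical sketch of get_loss_inputs_unlabeled; note the absence of *_targ entries, since unlabeled frames have no ground truth (batch key names are assumptions).

def get_loss_inputs_unlabeled(self, batch):
    # step 1: push unlabeled frames through backbone + head
    predicted_heatmaps = self.forward(batch["frames"])
    # step 2: differentiable soft argmax, so unsupervised losses can backpropagate
    predicted_keypoints = soft_argmax(predicted_heatmaps)
    return {
        "heatmaps_pred": predicted_heatmaps,
        "keypoints_pred": predicted_keypoints,
    }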

You do not need to implement a predict_step method since the “semi-supervised” aspect of this model only affects training and not inference; therefore during inference the predict_step of the fully supervised model will be used.

Integrating a new model into the Lightning Pose pipeline

Once you’ve implemented your model, the next step is to integrate it into the larger repo so that it can take advantage of the available training and evaluation infrastructure. We will describe this process from the outside in.

Step 1: update the config file

The default configuration file at lightning_pose/scripts/configs/config_default.yaml enumerates all possible hyperparameters needed for building and training a model. If your new model requires additional hyperparameters that you wish to control externally, include these in the config file. Inside the pipeline, when initializing the model, you will have access to every key-value pair in this file.

The field model.model_type is used to specify your model; the currently supported values are “regression”, “heatmap”, and “heatmap_mhcrnn”. Add your new model name to this list.
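
Inside the pipeline, those key-value pairs arrive on the Hydra/OmegaConf config object; a brief hypothetical sketch, where cfg is the DictConfig passed through the pipeline and my_new_hyperparameter is an example key you might add.

model_type = cfg.model.model_type           # e.g. "my_new_model" once added to the supported list
my_param = cfg.model.my_new_hyperparameter  # hypothetical key you added to config_default.yaml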

The basic training script can be found at scripts/train_hydra.py. You do not need to update anything in this script to accommodate your new model, but this script uses several helper functions that we will update next.

Step 2: update get_dataset

The first helper function you need to update is get_dataset(), which creates a torch Dataset object associated with your model. For example, the regression-based models do not need a dataset that returns heatmaps, whereas the heatmap-based models do. In this function you will see the if/else statement that creates a dataset based on the model type; include your model in this if/else statement.
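
The branch to add might look roughly like this; the dataset class names and function signature are simplified assumptions, so follow the existing branches in the real function.

def get_dataset(cfg, data_dir, imgaug_transform):  # argument names are assumptions
    if cfg.model.model_type == "regression":
        return BaseTrackingDataset(...)   # existing branch: no heatmaps needed
    elif cfg.model.model_type in ("heatmap", "heatmap_mhcrnn"):
        return HeatmapDataset(...)        # existing branch: returns heatmaps as targets
    elif cfg.model.model_type == "my_new_model":
        return MyNewDataset(...)          # your new branch
    else:
        raise NotImplementedError(f"{cfg.model.model_type} is not a valid model type")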

Step 3: update get_loss_factories

If your model requires heatmaps for training, add it to the first if/else statement in the get_loss_factories() function so that the heatmap losses are properly logged (you will see the “heatmap” and “heatmap_mhcrnn” models already represented there). Note that if your model uses heatmaps you will also be able to select from several heatmap losses in the config file using the model.heatmap_loss_type key.

Step 4: update get_model

This next helper function - get_model() - is what translates the key-value pairs in the config file into the actual model object. You will see examples of all other models in this function; include your model accordingly.
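
Schematically, your branch in get_model() might look like this; the constructor arguments and the structure of loss_factories are assumptions.

def get_model(cfg, data_module, loss_factories):  # argument names are assumptions
    if cfg.model.model_type == "heatmap":
        model = HeatmapTracker(...)  # existing branch
    elif cfg.model.model_type == "my_new_model":
        model = MyNewModel(
            num_keypoints=cfg.data.num_keypoints,       # any config value is available here
            loss_factory=loss_factories["supervised"],  # hypothetical key/argument names
        )
    else:
        raise NotImplementedError
    return model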

Step 5: update get_model_class

Finally, there is a helper function, get_model_class(), which is used to seamlessly load model parameters from checkpoint files. Again, there are various if/else statements where your model should be incorporated.

Step 6: optional and miscellaneous additions

If you find yourself needing to write a new DALI dataloader to support your model training, you might also need to update the PredictionHandler class.

Step 7: ADD UNIT TESTS!

Not only is this good coding practice, it makes debugging your model easier! Make a new file in the tests/models directory that follows the same pattern as the other files there. We provide many convenience functions that allow you to set up unit tests for fully supervised models, context models, and semi-supervised models (and combinations thereof). Let’s take the fully supervised heatmap model as an example; once you write the test you can run it from the command line like so:

pytest tests/models/test_heatmap_tracker.py::test_supervised_heatmap

This test will build your model using the helper functions above (like get_model) and train it for several epochs using the toy dataset that comes packaged with this repo.
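
If you prefer to start from something minimal before wiring up the provided convenience functions, a bare-bones test might look like the following; the model class, import, and shapes are hypothetical, and the real tests in tests/models additionally exercise training on the toy dataset.

import torch
# from my_models import MyNewModel  # hypothetical import for your new model

def test_my_new_model_forward():
    model = MyNewModel(num_keypoints=17)  # hypothetical constructor
    images = torch.randn(2, 3, 256, 256)  # tiny fake batch
    heatmaps = model(images)
    assert heatmaps.shape[:2] == (2, 17)  # one heatmap per keypoint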

And that’s it!