Training

If this is your first time training a model, you’ll need to:

  1. Organize your data (see Organizing your data)

  2. Create a valid config file

After that, you are ready to train a model.

Create a valid config file

Copy the default config (config_default.yaml) to a local file and modify the data section to point to your own dataset. For example:

data:
   image_resize_dims:
     height: 256
     width: 256
   data_dir: /home/user1/data/
   video_dir: /home/user1/data/videos
   csv_file: labeled_frames.csv
   downsample_factor: 2
   # total number of keypoints
   num_keypoints: 3
   keypoint_names:
       - paw_left
       - paw_right
       - nose_tip

Sections other than data have reasonable defaults for getting started, but can be modified as well. For the full reference of fields, see Model config.yaml.
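A quick sanity check before training can catch mismatches in the data section, such as num_keypoints disagreeing with the number of entries in keypoint_names. The following is a minimal sketch and not part of lightning-pose: the helper check_data_section is hypothetical, and it assumes the data section has already been loaded into a Python dict (for example with a YAML parser).

```python
def check_data_section(data: dict) -> list[str]:
    """Return human-readable problems found in a `data` config section.

    Hypothetical helper for illustration only; these are not
    lightning-pose's own validation rules.
    """
    problems = []

    # num_keypoints should match the number of listed keypoint names.
    names = data.get("keypoint_names") or []
    num_keypoints = data.get("num_keypoints")
    if num_keypoints != len(names):
        problems.append(
            f"num_keypoints is {num_keypoints} but "
            f"{len(names)} keypoint_names are listed"
        )
    if len(set(names)) != len(names):
        problems.append("keypoint_names contains duplicates")

    # Resize dimensions should be positive integers.
    dims = data.get("image_resize_dims") or {}
    for key in ("height", "width"):
        value = dims.get(key)
        if not isinstance(value, int) or value <= 0:
            problems.append(f"image_resize_dims.{key} must be a positive integer")

    return problems


# The example data section from above, written as a dict.
example = {
    "image_resize_dims": {"height": 256, "width": 256},
    "data_dir": "/home/user1/data/",
    "video_dir": "/home/user1/data/videos",
    "csv_file": "labeled_frames.csv",
    "downsample_factor": 2,
    "num_keypoints": 3,
    "keypoint_names": ["paw_left", "paw_right", "nose_tip"],
}

print(check_data_section(example))  # prints []
```

An empty list means none of these particular checks failed; it does not guarantee that the paths exist or that the config is otherwise valid.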

Train a model

The command litpose train (installed with lightning-pose) is used to train a model from the command line:

# Replace 'config_default.yaml' with the path to your config file.
litpose train config_default.yaml

The model will be saved in ./outputs/{YYYY-MM-DD}/{HH-MM-SS}/, and the folder is created if it does not already exist. To customize the output directory, pass the --output_dir OUTPUT_DIR flag.

# Save to 'outputs/lp_test_1'
litpose train config_default.yaml --output_dir outputs/lp_test_1

Note

If the command litpose is not found, ensure that you’ve activated the conda environment with lightning-pose installed, and that you’re using version >= 1.7.0 (verify this using pip show lightning-pose).

For the full listing of training options, see the CLI reference: Train.

Config overrides

If you want to override some config values before training, you can use the --overrides flag. This uses Hydra under the hood, so refer to the Hydra documentation for the override syntax.

# Train for only 5 epochs
litpose train config_default.yaml --overrides training.min_epochs=5 training.max_epochs=5

# Train a supervised model
litpose train config_default.yaml --output_dir outputs/supervised --overrides \
  model.losses_to_use=null
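When launching several training runs from a script, it can help to assemble the command programmatically. The sketch below only uses the flags shown above (--output_dir and --overrides); the helper build_train_command is hypothetical and not part of lightning-pose.

```python
import shlex


def build_train_command(config_file, output_dir=None, overrides=None):
    """Build the argument list for a `litpose train` invocation.

    Hypothetical helper: `overrides` is a dict mapping dotted config keys
    to values, which are rendered as key=value pairs after --overrides.
    """
    cmd = ["litpose", "train", str(config_file)]
    if output_dir is not None:
        cmd += ["--output_dir", str(output_dir)]
    if overrides:
        cmd.append("--overrides")
        cmd += [f"{key}={value}" for key, value in overrides.items()]
    return cmd


cmd = build_train_command(
    "config_default.yaml",
    output_dir="outputs/lp_test_1",
    overrides={"training.min_epochs": 5, "training.max_epochs": 5},
)
print(shlex.join(cmd))
# To actually launch training, pass the list to subprocess:
#   import subprocess; subprocess.run(cmd, check=True)
```

Passing the command as a list to subprocess.run avoids shell quoting problems with override values.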

Post-training flags

After training, Lightning Pose can automatically run inference on a set of videos and save copies of those videos labeled with its predictions. The config settings that control this behavior are:

  • eval.predict_vids_after_training: if true, automatically run inference after training on all videos located in the directory given by eval.test_videos_directory; results are saved to the model directory

  • eval.save_vids_after_training: if true (and eval.predict_vids_after_training is also true), the keypoints predicted during the inference step are overlaid on the videos, and the labeled videos are saved alongside the inference outputs in the model directory
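The way the two flags interact can be sketched as follows. This is an illustration of the rule described above, not lightning-pose's implementation: the function post_training_actions is hypothetical, and the example directory path is made up.

```python
def post_training_actions(eval_cfg: dict) -> list[str]:
    """List the post-training steps implied by the `eval` config section.

    Hypothetical helper: labeled videos are only produced when inference
    itself is enabled, since they are built from the inference predictions.
    """
    actions = []
    predict = bool(eval_cfg.get("predict_vids_after_training", False))
    save = bool(eval_cfg.get("save_vids_after_training", False))

    if predict:
        video_dir = eval_cfg.get("test_videos_directory")
        actions.append(f"run inference on videos in {video_dir}")
        if save:
            actions.append("save labeled videos to the model directory")
    return actions


# Made-up example values for illustration.
eval_cfg = {
    "predict_vids_after_training": True,
    "save_vids_after_training": True,
    "test_videos_directory": "/home/user1/data/videos",
}
for action in post_training_actions(eval_cfg):
    print(action)
```

Setting only save_vids_after_training yields no actions in this sketch, which mirrors the requirement that both settings be true for labeled videos to be written.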

Training on sample dataset

To quickly try lightning-pose without your own dataset, the lightning-pose git repository provides a small sample dataset. Clone the repository and run the train command pointed at our sample config:

# (Skip this if you've already cloned the repository, e.g. to install from source.)
git clone https://github.com/paninski-lab/lightning-pose

# Run from a directory containing the lightning-pose repo.
litpose train lightning-pose/scripts/configs/config_mirror-mouse-example.yaml

Tensorboard

Training metrics such as losses are logged in model_dir/tb_logs. To view the logged losses via tensorboard, run:

tensorboard --logdir outputs/YYYY-MM-DD/

Replace YYYY-MM-DD with the date on which you trained the model. Then open the link printed in the terminal, which will look something like http://localhost:6006/. If you saved the model to a different directory, pass that directory to --logdir instead.
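If you are unsure which model directories contain logs, a short script can list them before you start tensorboard. The sketch below relies only on the fact that logs live in a tb_logs folder inside each model directory; the helper find_tb_log_dirs is hypothetical, and the demo builds a throwaway directory tree with made-up run names rather than reading real outputs.

```python
import tempfile
from pathlib import Path


def find_tb_log_dirs(root) -> list[Path]:
    """Return every directory named tb_logs below root, sorted by path."""
    return sorted(p for p in Path(root).rglob("tb_logs") if p.is_dir())


# Demo on a temporary tree that mimics two training runs (made-up names).
with tempfile.TemporaryDirectory() as tmp:
    for run in ("2024-01-01/10-00-00", "2024-01-02/09-30-00"):
        (Path(tmp) / "outputs" / run / "tb_logs").mkdir(parents=True)

    found = find_tb_log_dirs(Path(tmp) / "outputs")
    # The parent of each tb_logs folder is the model directory.
    model_dirs = [p.parent.relative_to(tmp).as_posix() for p in found]

print(model_dirs)
```

On a real setup you would call find_tb_log_dirs("outputs") and pass one of the resulting parent directories, or a common ancestor, to --logdir.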

Note

If you don’t see all your models in tensorboard, click the refresh button in the top-right corner of the screen, and the other models should appear.

Metrics are plotted as a function of step/batch. Validation metrics are typically recorded less frequently than train metrics. The logging frequency is controlled by cfg.training.log_every_n_steps (training) and cfg.training.check_val_every_n_epoch (validation).

Available metrics

The following are the important metrics for all model types (supervised, context, semi-supervised, etc.):

  • train_supervised_loss: this is the same as train_heatmap_mse_loss_weighted, which is the mean square error (MSE) between the true and predicted heatmaps on labeled training data

  • train_supervised_rmse: the root mean square error (RMSE) between the true and predicted (x, y) coordinates on labeled training data; scale is in pixels

  • val_supervised_loss: this is the same as val_heatmap_mse_loss_weighted, which is the MSE between the true and predicted heatmaps on labeled validation data

  • val_supervised_rmse: the RMSE between the true and predicted (x, y) coordinates on labeled validation data; scale is in pixels

The following are important metrics for the semi-supervised models:

  • train_pca_multiview_loss_weighted: the train_pca_multiview_loss (in pixels), which measures multiview consistency, multiplied by the loss weight set in the configuration file. This metric is only computed on batches of unlabeled training data.

  • train_pca_singleview_loss_weighted: the train_pca_singleview_loss (in pixels), which measures pose plausibility, multiplied by the loss weight set in the configuration file. This metric is only computed on batches of unlabeled training data.

  • train_temporal_loss_weighted: the train_temporal_loss (in pixels), which measures temporal smoothness, multiplied by the loss weight set in the configuration file. This metric is only computed on batches of unlabeled training data.

  • total_unsupervised_importance: a weight on all weighted unsupervised losses that linearly increases from 0 to 1 over 100 epochs

  • total_loss: weighted supervised loss (train_heatmap_mse_loss_weighted) plus total_unsupervised_importance times the sum of all applicable weighted unsupervised losses
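The relationship between total_unsupervised_importance and total_loss described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the library's code: the function names are hypothetical, the ramp is assumed to start at epoch 0 and to stay at 1 once 100 epochs have passed, and the loss values in the example are made up.

```python
def unsupervised_importance(epoch: int, ramp_epochs: int = 100) -> float:
    """Linear ramp from 0 to 1 over `ramp_epochs`, clamped to [0, 1].

    Assumption: the ramp starts at epoch 0 and holds at 1 afterwards.
    """
    return min(max(epoch / ramp_epochs, 0.0), 1.0)


def total_loss(supervised_weighted: float,
               unsupervised_weighted: list[float],
               epoch: int) -> float:
    """Weighted supervised loss plus the ramped sum of unsupervised losses."""
    importance = unsupervised_importance(epoch)
    return supervised_weighted + importance * sum(unsupervised_weighted)


# Made-up values: one supervised loss and two weighted unsupervised losses.
supervised = 2.0
unsupervised = [1.0, 3.0]

for epoch in (0, 50, 100):
    print(epoch, total_loss(supervised, unsupervised, epoch))
```

Early in training the unsupervised terms contribute little, so total_loss is dominated by the supervised loss; their contribution grows as the ramp approaches 1.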

Model directory structure

Lightning Pose saves trained models and their outputs in a structured way. For a detailed reference on the model directory structure and its contents, see Model Directory Structure.

The following is a brief overview of the directory structure:

/path/to/model/YYYY-MM-DD/HH-MM-SS/
  ├── tb_logs/
  ├── video_preds/
  │   └── labeled_videos/
  ├── config.yaml
  ├── predictions.csv
  ├── predictions_pca_multiview_error.csv
  ├── predictions_pca_singleview_error.csv
  └── predictions_pixel_error.csv

  • tb_logs/: contains training logs for Tensorboard.

  • video_preds/: predictions and metrics from videos.

  • predictions.csv: predictions on labeled data.

  • predictions_pixel_error.csv: Euclidean distance between the predictions and the labeled keypoints.
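To make the pixel error concrete, the Euclidean distance between a predicted and a labeled keypoint can be computed as below. This is a minimal sketch with made-up coordinates: the helper pixel_errors is hypothetical, and it does not read or reproduce the actual layout of predictions_pixel_error.csv.

```python
import math


def pixel_errors(predicted, labeled):
    """Euclidean distance in pixels for each (x, y) keypoint pair.

    A NaN coordinate (e.g. a missing label) yields a NaN error.
    """
    return [
        math.hypot(px - lx, py - ly)
        for (px, py), (lx, ly) in zip(predicted, labeled)
    ]


# Made-up coordinates for two keypoints on one frame.
predicted = [(10.0, 20.0), (33.0, 44.0)]
labeled = [(10.0, 20.0), (30.0, 40.0)]

print(pixel_errors(predicted, labeled))  # prints [0.0, 5.0]
```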