Training
If this is your first time training a model, you’ll need to:
Organize your data (see Organizing your data)
Create a valid config file
After that, you are ready to train a model.
Create a valid config file
Copy the default config (config_default.yaml)
to a local file and modify the data section to point to your own dataset. For example:
data:
  image_resize_dims:
    height: 256
    width: 256
  data_dir: /home/user1/data/
  video_dir: /home/user1/data/videos
  csv_file: labeled_frames.csv
  downsample_factor: 2
  # total number of keypoints
  num_keypoints: 3
  keypoint_names:
    - paw_left
    - paw_right
    - nose_tip
Sections other than data have reasonable defaults for getting started,
but can be modified as well. For the full reference of fields, see Model config.yaml.
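Before training, it can help to sanity-check the data section you edited. The sketch below is a hypothetical helper, not part of lightning-pose: the function name check_data_section and the specific checks are assumptions for illustration. It only verifies that num_keypoints matches the number of keypoint_names and that the resize dimensions are positive integers.

```python
from __future__ import annotations


def check_data_section(data: dict) -> list[str]:
    """Return a list of problems found in a `data` config section (empty if none).

    Hypothetical helper for illustration; lightning-pose may validate differently.
    """
    problems = []

    # num_keypoints should agree with the number of listed keypoint names.
    names = data.get("keypoint_names", [])
    if data.get("num_keypoints") != len(names):
        problems.append(
            f"num_keypoints={data.get('num_keypoints')} but "
            f"{len(names)} keypoint_names are listed"
        )
    if len(set(names)) != len(names):
        problems.append("keypoint_names contains duplicates")

    # Resize dimensions should be positive integers.
    dims = data.get("image_resize_dims", {})
    for key in ("height", "width"):
        value = dims.get(key)
        if not isinstance(value, int) or value <= 0:
            problems.append(f"image_resize_dims.{key} must be a positive integer")

    return problems


# The example data section from above, written as a plain dict.
example = {
    "image_resize_dims": {"height": 256, "width": 256},
    "data_dir": "/home/user1/data/",
    "video_dir": "/home/user1/data/videos",
    "csv_file": "labeled_frames.csv",
    "downsample_factor": 2,
    "num_keypoints": 3,
    "keypoint_names": ["paw_left", "paw_right", "nose_tip"],
}

print(check_data_section(example))  # []
```

In practice you would load the dict from your YAML file with a YAML parser of your choice and pass its data entry to the helper.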
Train a model
The command litpose train (installed with lightning-pose) is used to train a model from the command line:
# Replace 'config_default.yaml' with the path to your config file.
litpose train config_default.yaml
The model will be saved in ./outputs/{YYYY-MM-DD}/{HH-MM-SS}/, creating the folder if it does not already exist.
To customize the output directory, use the --output_dir OUTPUT_DIR flag of the command.
# Save to 'outputs/lp_test_1'
litpose train config_default.yaml --output_dir outputs/lp_test_1
Note
If the command litpose is not found, ensure that you’ve activated the conda
environment with lightning-pose installed, and that you’re using version >= 1.7.0
(verify this using pip show lightning-pose).
For the full listing of training options, see the CLI reference: Train.
Config overrides
To override config values at training time without editing the config file, use the --overrides flag.
This uses Hydra under the hood, so refer to the Hydra syntax for config overrides.
# Train for only 5 epochs
litpose train config_default.yaml --overrides training.min_epochs=5 training.max_epochs=5
# Train a supervised model
litpose train config_default.yaml --output_dir outputs/supervised --overrides \
model.losses_to_use=null
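If you launch training from a script, for example to sweep over several settings, the command line can be assembled programmatically. The following is a minimal sketch: build_train_command is a hypothetical helper (not part of lightning-pose) that uses only the flags described above, --output_dir and --overrides.

```python
import shlex


def build_train_command(config_path, output_dir=None, overrides=None):
    """Assemble a `litpose train` argument list (hypothetical helper)."""
    cmd = ["litpose", "train", config_path]
    if output_dir is not None:
        cmd += ["--output_dir", output_dir]
    if overrides:
        # All override strings follow a single --overrides flag.
        cmd += ["--overrides", *overrides]
    return cmd


cmd = build_train_command(
    "config_default.yaml",
    output_dir="outputs/lp_test_1",
    overrides=["training.min_epochs=5", "training.max_epochs=5"],
)

# Print the command as it would be typed in a shell.
print(shlex.join(cmd))
```

The returned list can be passed to subprocess.run(cmd, check=True) from within the conda environment where lightning-pose is installed.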
Post-training flags
After training, Lightning Pose can automatically run inference on a set of videos and save copies of those videos labeled with its predictions. The config settings that control this behavior are:
eval.predict_vids_after_training: if true, automatically run inference after training on all videos located in the directory given by eval.test_videos_directory; results are saved to the model directory
eval.save_vids_after_training: if true (and eval.predict_vids_after_training is also true), the keypoints predicted during the inference step are overlaid on the videos, and the labeled videos are saved with the inference outputs in the model directory
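These settings can be edited in the config file or passed as overrides. As a rough sketch, the hypothetical helper below (to_overrides is not a lightning-pose function) turns a flat dictionary of settings into key=value strings, writing booleans and None in the lower-case form used by Hydra. The video directory shown is only an example path.

```python
def to_overrides(settings):
    """Format a flat {dotted.key: value} dict as key=value override strings."""

    def fmt(value):
        # Booleans and None are written in lower case (true/false/null).
        if isinstance(value, bool):
            return "true" if value else "false"
        if value is None:
            return "null"
        return str(value)

    return [f"{key}={fmt(value)}" for key, value in settings.items()]


post_training = {
    "eval.predict_vids_after_training": True,
    # Only has an effect when predict_vids_after_training is also true.
    "eval.save_vids_after_training": True,
    # Example path; point this at your own directory of videos.
    "eval.test_videos_directory": "/home/user1/data/videos",
}

for override in to_overrides(post_training):
    print(override)
```

Values containing spaces or special characters may need quoting; consult the Hydra override syntax in that case.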
Training on sample dataset
To quickly try lightning-pose without your own dataset, the lightning-pose git repository provides a small sample dataset. Clone the repository and run the train command pointed at our sample config:
# (Skip this if you've already cloned, i.e. to install from source.)
git clone https://github.com/paninski-lab/lightning-pose
# Run from a directory containing the lightning-pose repo.
litpose train lightning-pose/scripts/configs/config_mirror-mouse-example.yaml
Tensorboard
Training metrics such as losses are logged in model_dir/tb_logs.
To view the logged losses via tensorboard, run:
tensorboard --logdir outputs/YYYY-MM-DD/
replacing YYYY-MM-DD with the date on which you trained the model.
Click on the link printed in the terminal, which will look something like
http://localhost:6006/.
If you saved the model to a different directory, pass that directory to
--logdir instead.
Note
If you don’t see all your models in tensorboard, click the refresh button in the top-right corner of the screen, and the other models should appear.
Metrics are plotted as a function of step/batch. Validation metrics are typically recorded less
frequently than train metrics.
The logging frequency is controlled by cfg.training.log_every_n_steps (training)
and cfg.training.check_val_every_n_epoch (validation).
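To see why validation curves usually contain far fewer points than training curves, you can estimate how many points each setting produces. This is a back-of-the-envelope sketch, assuming training metrics are logged once every log_every_n_steps steps and validation runs once every check_val_every_n_epoch epochs; exact boundary behavior depends on the trainer. The function name and the numbers in the example are hypothetical.

```python
def expected_log_points(num_epochs, steps_per_epoch,
                        log_every_n_steps, check_val_every_n_epoch):
    """Estimate how many train and validation points appear in the plots."""
    total_steps = num_epochs * steps_per_epoch
    train_points = total_steps // log_every_n_steps
    val_points = num_epochs // check_val_every_n_epoch
    return train_points, val_points


# Hypothetical run: 300 epochs, 10 batches per epoch,
# train metrics every 10 steps, validation every 5 epochs.
print(expected_log_points(300, 10, 10, 5))  # (300, 60)
```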
Available metrics
The following are the important metrics for all model types (supervised, context, semi-supervised, etc.):
train_supervised_loss: this is the same as train_heatmap_mse_loss_weighted, which is the mean square error (MSE) between the true and predicted heatmaps on labeled training data
train_supervised_rmse: the root mean square error (RMSE) between the true and predicted (x, y) coordinates on labeled training data; scale is in pixels
val_supervised_loss: this is the same as val_heatmap_mse_loss_weighted, which is the MSE between the true and predicted heatmaps on labeled validation data
val_supervised_rmse: the RMSE between the true and predicted (x, y) coordinates on labeled validation data; scale is in pixels
The following are important metrics for the semi-supervised models:
train_pca_multiview_loss_weighted: the train_pca_multiview_loss (in pixels), which measures multiview consistency, multiplied by the loss weight set in the configuration file. This metric is only computed on batches of unlabeled training data.
train_pca_singleview_loss_weighted: the train_pca_singleview_loss (in pixels), which measures pose plausibility, multiplied by the loss weight set in the configuration file. This metric is only computed on batches of unlabeled training data.
train_temporal_loss_weighted: the train_temporal_loss (in pixels), which measures temporal smoothness, multiplied by the loss weight set in the configuration file. This metric is only computed on batches of unlabeled training data.
total_unsupervised_importance: a weight on all weighted unsupervised losses that linearly increases from 0 to 1 over 100 epochs
total_loss: weighted supervised loss (train_heatmap_mse_loss_weighted) plus total_unsupervised_importance times the sum of all applicable weighted unsupervised losses
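The relationship between total_loss and total_unsupervised_importance can be written out as a short sketch. It assumes the ramp starts at epoch 0 and stays at 1 after 100 epochs; the function names are hypothetical, and the schedule actually used in training may differ in its details.

```python
def unsupervised_importance(epoch, ramp_epochs=100):
    """Linear ramp from 0 to 1 over `ramp_epochs`, clamped to [0, 1]."""
    return min(max(epoch / ramp_epochs, 0.0), 1.0)


def total_loss(supervised_weighted, unsupervised_weighted, epoch):
    """Weighted supervised loss plus ramped sum of weighted unsupervised losses."""
    return supervised_weighted + unsupervised_importance(epoch) * sum(
        unsupervised_weighted
    )


# Halfway through the ramp, the unsupervised terms count for half.
print(unsupervised_importance(50))                 # 0.5
print(total_loss(2.0, [0.5, 0.25], epoch=50))      # 2.375
```

Early in training the total is dominated by the supervised heatmap loss; the unsupervised terms reach full strength only once the ramp completes.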
Model directory structure
/path/to/model/YYYY-MM-DD/HH-MM-SS/
├── tb_logs/
├── video_preds/
│   └── labeled_videos/
├── config.yaml
├── predictions.csv
├── predictions_pca_multiview_error.csv
├── predictions_pca_singleview_error.csv
└── predictions_pixel_error.csv
tb_logs/: model weights and the training metrics logged for tensorboard
video_preds/: predictions and metrics from videos. The config field eval.test_videos_directory points to a directory of videos; if eval.predict_vids_after_training is set to true, all videos in that directory are run through the model when training completes, and the results are stored here.
video_preds/labeled_videos/: labeled mp4s. The config field eval.test_videos_directory points to a directory of videos; if eval.save_vids_after_training is set to true, all videos in that directory are run through the model when training completes, and the labeled videos are stored here.
predictions.csv: predictions on labeled data. The right-most column records the train/val/test split that each example belongs to.
predictions_pixel_error.csv: Euclidean distance between the predictions in predictions.csv and the labeled keypoints (in <YOUR_LABELED_FRAMES>.csv), per keypoint and frame.
We also compute all unsupervised losses, where applicable, and store them (per keypoint and frame) in the following csvs:
predictions_pca_multiview_error.csv: PCA multiview reprojection error between predictions and labeled keypoints
predictions_pca_singleview_error.csv: PCA singleview reprojection error between predictions and labeled keypoints
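The pixel error stored in predictions_pixel_error.csv is described above as the Euclidean distance between each predicted and labeled keypoint. A minimal sketch of that calculation follows. It operates on plain nested lists rather than the CSV files themselves, and the function name and the coordinates are made up for illustration.

```python
import math


def pixel_error(predictions, labels):
    """Euclidean distance per frame and keypoint.

    Both inputs are nested lists shaped frames x keypoints x (x, y).
    """
    return [
        [math.dist(pred, label) for pred, label in zip(pred_frame, label_frame)]
        for pred_frame, label_frame in zip(predictions, labels)
    ]


# One frame, two keypoints (illustrative coordinates in pixels).
preds = [[(3.0, 4.0), (10.0, 10.0)]]
labels = [[(0.0, 0.0), (10.0, 12.0)]]

print(pixel_error(preds, labels))  # [[5.0, 2.0]]
```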