Training
If this is your first time training a model, you'll need to:
Organize your data (see Organizing your data)
Create a valid config file
After that, you are ready to train a model.
Create a valid config file
Copy the default config (config_default.yaml)
to a local file and modify the data section to point to your own dataset. For example:
data:
  image_resize_dims:
    height: 256
    width: 256
  data_dir: /home/user1/data/
  video_dir: /home/user1/data/videos
  csv_file: labeled_frames.csv
  downsample_factor: 2
  # total number of keypoints
  num_keypoints: 3
  keypoint_names:
    - paw_left
    - paw_right
    - nose_tip
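Before training, it can help to sanity-check the data section. The sketch below is a hypothetical helper (check_data_section is not part of lightning-pose) that works on a plain dict mirroring the YAML above; in practice you would load the file with a YAML parser such as PyYAML's yaml.safe_load and pass it the "data" entry. The set of keys it checks is an assumption drawn from the example, not the library's own validation.

```python
from typing import Any, Dict, List

# Hypothetical helper, not part of lightning-pose. These keys are taken from the
# example data section above; the library may require or accept others.
REQUIRED_KEYS = ("data_dir", "video_dir", "csv_file", "num_keypoints", "keypoint_names")


def check_data_section(data: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in a config's data section (empty if none)."""
    problems: List[str] = []
    for key in REQUIRED_KEYS:
        if key not in data:
            problems.append(f"missing data.{key}")

    names = data.get("keypoint_names") or []
    num_keypoints = data.get("num_keypoints")
    # num_keypoints should agree with the number of listed keypoint names
    if num_keypoints is not None and num_keypoints != len(names):
        problems.append(f"num_keypoints={num_keypoints} but {len(names)} keypoint_names listed")
    if len(set(names)) != len(names):
        problems.append("keypoint_names contains duplicates")

    dims = data.get("image_resize_dims") or {}
    for side in ("height", "width"):
        value = dims.get(side)
        # Resize dimensions should be positive integers (bool is excluded explicitly)
        if value is not None and (isinstance(value, bool) or not isinstance(value, int) or value <= 0):
            problems.append(f"image_resize_dims.{side} must be a positive integer")
    return problems


example = {
    "image_resize_dims": {"height": 256, "width": 256},
    "data_dir": "/home/user1/data/",
    "video_dir": "/home/user1/data/videos",
    "csv_file": "labeled_frames.csv",
    "downsample_factor": 2,
    "num_keypoints": 3,
    "keypoint_names": ["paw_left", "paw_right", "nose_tip"],
}
print(check_data_section(example))  # []
```

Returning a list of messages rather than raising on the first problem lets you fix every issue in one pass.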
Sections other than data have reasonable defaults for getting started,
but can be modified as well. For the full reference of fields, see Model config.yaml.
Train a model
The command litpose train (installed with lightning-pose) is used to train a model from the command line:
# Replace 'config_default.yaml' with the path to your config file.
litpose train config_default.yaml
The model will be saved in ./outputs/{YYYY-MM-DD}/{HH-MM-SS}/; the folder is created if it does not already exist.
To customize the output directory, use the --output_dir OUTPUT_DIR flag:
# Save to 'outputs/lp_test_1'
litpose train config_default.yaml --output_dir outputs/lp_test_1
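When you rely on the default timestamped folders, a small helper can locate the newest run, for example to pass to tensorboard --logdir. This is a hypothetical sketch (latest_run_dir is not a lightning-pose function) that assumes the outputs/YYYY-MM-DD/HH-MM-SS/ layout shown in the model directory structure section of this page; folders created with --output_dir are skipped because their names don't match that pattern.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming scheme: outputs/YYYY-MM-DD/HH-MM-SS/
DATE_RE = re.compile(r"\d{4}-\d{2}-\d{2}")
TIME_RE = re.compile(r"\d{2}-\d{2}-\d{2}")


def latest_run_dir(outputs: Path = Path("outputs")) -> Optional[Path]:
    """Return the most recent timestamped run directory under `outputs`, or None."""
    if not outputs.is_dir():
        return None
    runs = [
        time_dir
        for date_dir in outputs.iterdir()
        if date_dir.is_dir() and DATE_RE.fullmatch(date_dir.name)
        for time_dir in date_dir.iterdir()
        if time_dir.is_dir() and TIME_RE.fullmatch(time_dir.name)
    ]
    # Zero-padded dates and times sort chronologically as plain strings
    return max(runs, key=lambda p: (p.parent.name, p.name), default=None)


if __name__ == "__main__":
    print(latest_run_dir())
```

Sorting on the zero-padded names avoids relying on file modification times, which can change when folders are copied.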
Note
If the command litpose is not found, ensure that you've activated the conda
environment with lightning-pose installed, and that you're using version >= 1.7.0
(verify this using pip show lightning-pose).
For the full listing of training options, see the CLI reference: Train.
Config overrides
If you want to override some config values before training, you can use the --overrides flag.
This uses Hydra under the hood, so refer to the Hydra syntax for config overrides.
# Train for only 5 epochs
litpose train config_default.yaml --overrides training.min_epochs=5 training.max_epochs=5
# Train a supervised model
litpose train config_default.yaml --output_dir outputs/supervised --overrides \
model.losses_to_use=null
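To build intuition for what a dotted override such as training.max_epochs=5 does, the sketch below applies key=value strings to a nested dict. It is a deliberately simplified, hypothetical stand-in (apply_overrides is not a Hydra or lightning-pose function): Hydra's real grammar also handles lists, quoting, and the + and ~ prefixes for adding and deleting keys.

```python
import copy
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    """Convert an override value to a Python value (simplified subset of Hydra's rules)."""
    lowered = text.lower()
    if lowered == "null":
        return None
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Return a copy of cfg with dotted key=value overrides applied."""
    result = copy.deepcopy(cfg)
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            node = node[part]  # raises KeyError if a section is missing
        if leaf not in node:
            # Like Hydra, refuse to silently add keys that are not already in the config
            raise KeyError(f"{key} is not in the config")
        node[leaf] = parse_value(raw)
    return result


# Placeholder values for illustration, not lightning-pose defaults
cfg = {
    "training": {"min_epochs": 100, "max_epochs": 200},
    "model": {"losses_to_use": ["pca_multiview"]},
}
print(apply_overrides(cfg, ["training.min_epochs=5", "training.max_epochs=5"])["training"])
# {'min_epochs': 5, 'max_epochs': 5}
```

The original config is deep-copied, so overrides affect only this run, mirroring how command-line overrides leave the YAML file on disk unchanged.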
Post-training flags
After training, Lightning Pose can automatically run inference on a directory of videos and save copies of those videos labeled with its predictions. The config settings that control this behavior are:
eval.predict_vids_after_training: if true, automatically run inference after training on all videos located in the directory given by eval.test_videos_directory; results are saved to the model directory
eval.save_vids_after_training: if true (as well as eval.predict_vids_after_training), the keypoints predicted during the inference step will be overlaid on the videos and saved with the inference outputs to the model directory
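These settings can also be switched on for a single run through --overrides instead of editing the config file. The sketch below only assembles the argument list for such a run; build_train_command is a hypothetical helper, while the flags and config keys are the ones described on this page. Paths containing spaces or special characters may need Hydra's quoting rules.

```python
import shlex
from typing import List


def build_train_command(config_path: str, output_dir: str, test_videos_dir: str) -> List[str]:
    """Assemble a `litpose train` call that predicts on and labels videos after training."""
    return [
        "litpose", "train", config_path,
        "--output_dir", output_dir,
        "--overrides",
        "eval.predict_vids_after_training=true",
        # Only takes effect together with predict_vids_after_training
        "eval.save_vids_after_training=true",
        f"eval.test_videos_directory={test_videos_dir}",
    ]


cmd = build_train_command("config_default.yaml", "outputs/with_videos", "/home/user1/data/videos")
print(shlex.join(cmd))
```

To actually start training, pass cmd to subprocess.run(cmd, check=True) from an environment where litpose is installed.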
Training on sample dataset
To quickly try lightning-pose without your own dataset, the lightning-pose git repository provides a small sample dataset. Clone the repository and run the train command pointed at our sample config:
# (Skip this if you've already cloned the repository, e.g. to install from source.)
git clone https://github.com/paninski-lab/lightning-pose
# Run from a directory containing the lightning-pose repo.
litpose train lightning-pose/scripts/configs/config_mirror-mouse-example.yaml
Tensorboard
Training metrics such as losses are logged in model_dir/tb_logs.
To view the logged losses in Tensorboard, run:
tensorboard --logdir outputs/YYYY-MM-DD/
where YYYY-MM-DD is the date on which you ran the model.
Click the link printed in the terminal, which will look something like
http://localhost:6006/.
If you saved the model to a different directory (for example, with --output_dir), pass that directory to
--logdir instead.
Note
If you don't see all your models in Tensorboard, hit the refresh button in the top right corner of the screen, and the other models should appear.
Metrics are plotted as a function of step/batch. Validation metrics are typically recorded less
frequently than train metrics.
The frequency of these checks is controlled by cfg.training.log_every_n_steps (training)
and cfg.training.check_val_every_n_epoch (validation).
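As a rough worked example of why validation curves have fewer points: with N epochs of S steps each, logging every n training steps gives about N*S/n training points, while validating every k epochs gives about N/k validation points. The helper below is a hypothetical back-of-the-envelope estimate; PyTorch Lightning's exact bookkeeping (sanity-check runs, partial epochs) can shift the counts slightly.

```python
from typing import Tuple


def estimated_log_points(num_epochs: int, steps_per_epoch: int,
                         log_every_n_steps: int, check_val_every_n_epoch: int) -> Tuple[int, int]:
    """Approximate number of (train, val) points plotted in Tensorboard."""
    total_steps = num_epochs * steps_per_epoch
    train_points = total_steps // log_every_n_steps  # one point every n training steps
    val_points = num_epochs // check_val_every_n_epoch  # one point every k epochs
    return train_points, val_points


# e.g. 300 epochs of 20 batches, logging every 10 steps and validating every 5 epochs
print(estimated_log_points(300, 20, 10, 5))  # (600, 60)
```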
Available metrics
The following are the important metrics for all model types (supervised, context, semi-supervised, etc.):
train_supervised_loss: this is the same as train_heatmap_mse_loss_weighted, which is the mean square error (MSE) between the true and predicted heatmaps on labeled training data
train_supervised_rmse: the root mean square error (RMSE) between the true and predicted (x, y) coordinates on labeled training data; scale is in pixels
val_supervised_loss: this is the same as val_heatmap_mse_loss_weighted, which is the MSE between the true and predicted heatmaps on labeled validation data
val_supervised_rmse: the RMSE between the true and predicted (x, y) coordinates on labeled validation data; scale is in pixels
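The two supervised quantities can be written out directly. This is a minimal pure-Python sketch under stated assumptions: RMSE is taken as the square root of the mean squared Euclidean distance over keypoints, and the heatmap MSE as the mean squared difference over all heatmap pixels. Lightning Pose's own implementation works on batched tensors and may average differently (for example, averaging per-keypoint distances), so treat these functions as illustrations of the definitions rather than the library's code.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def keypoint_rmse(true_xy: Sequence[Point], pred_xy: Sequence[Point]) -> float:
    """Square root of the mean squared Euclidean distance between keypoints (pixels)."""
    if not true_xy or len(true_xy) != len(pred_xy):
        raise ValueError("need two non-empty keypoint lists of equal length")
    squared = [(tx - px) ** 2 + (ty - py) ** 2
               for (tx, ty), (px, py) in zip(true_xy, pred_xy)]
    return math.sqrt(sum(squared) / len(squared))


def heatmap_mse(true_hm: List[List[float]], pred_hm: List[List[float]]) -> float:
    """Mean squared difference over all pixels of two same-shaped heatmaps."""
    if len(true_hm) != len(pred_hm) or any(len(a) != len(b) for a, b in zip(true_hm, pred_hm)):
        raise ValueError("heatmaps must have the same shape")
    diffs = [(t - p) ** 2 for t_row, p_row in zip(true_hm, pred_hm) for t, p in zip(t_row, p_row)]
    if not diffs:
        raise ValueError("heatmaps are empty")
    return sum(diffs) / len(diffs)


# Each prediction is off by (3, 4) pixels, i.e. a distance of 5 pixels
print(keypoint_rmse([(10, 10), (20, 20)], [(13, 14), (23, 24)]))  # 5.0
print(heatmap_mse([[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.5], [0.0, 0.0]]))  # 0.0625
```

Because the RMSE is in pixels, it is directly interpretable, whereas the heatmap MSE depends on the heatmap resolution and scale.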
The following are important metrics for the semi-supervised models:
train_pca_multiview_loss_weighted: the train_pca_multiview_loss (in pixels), which measures multiview consistency, multiplied by the loss weight set in the configuration file. This metric is only computed on batches of unlabeled training data.
train_pca_singleview_loss_weighted: the train_pca_singleview_loss (in pixels), which measures pose plausibility, multiplied by the loss weight set in the configuration file. This metric is only computed on batches of unlabeled training data.
train_temporal_loss_weighted: the train_temporal_loss (in pixels), which measures temporal smoothness, multiplied by the loss weight set in the configuration file. This metric is only computed on batches of unlabeled training data.
total_unsupervised_importance: a weight on all weighted unsupervised losses that increases linearly from 0 to 1 over 100 epochs
total_loss: the weighted supervised loss (train_heatmap_mse_loss_weighted) plus total_unsupervised_importance times the sum of all applicable weighted unsupervised losses
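The way total_loss combines these terms can be sketched as follows. This assumes the ramp is a simple linear function of the (possibly fractional) epoch index, clipped to [0, 1]; whether Lightning Pose updates it per epoch or per step is not specified here. The function names, loss names, and numbers below are hypothetical.

```python
from typing import Dict


def unsupervised_importance(epoch: float, ramp_epochs: int = 100) -> float:
    """Linear ramp from 0 to 1 over `ramp_epochs` epochs, then held at 1."""
    return min(max(epoch / ramp_epochs, 0.0), 1.0)


def total_loss(supervised_weighted: float, unsup_losses: Dict[str, float],
               unsup_weights: Dict[str, float], epoch: float) -> float:
    """Weighted supervised loss plus the ramped sum of weighted unsupervised losses."""
    # e.g. train_temporal_loss_weighted = train_temporal_loss * its config weight
    weighted = {name: unsup_losses[name] * unsup_weights[name] for name in unsup_losses}
    return supervised_weighted + unsupervised_importance(epoch) * sum(weighted.values())


losses = {"pca_multiview": 4.0, "temporal": 2.0}   # made-up loss values
weights = {"pca_multiview": 0.5, "temporal": 1.0}  # made-up loss weights
print(total_loss(2.0, losses, weights, epoch=50))   # 4.0
```

Ramping the unsupervised terms in gradually lets the network first fit the labeled data before the unsupervised losses carry full weight.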
Model directory structure
Lightning Pose saves trained models and their outputs in a structured way. For a detailed reference on the model directory structure and its contents, see Model Directory Structure.
The following is a brief overview of the directory structure:
/path/to/model/YYYY-MM-DD/HH-MM-SS/
├── tb_logs/
├── video_preds/
│   └── labeled_videos/
├── config.yaml
├── predictions.csv
├── predictions_pca_multiview_error.csv
├── predictions_pca_singleview_error.csv
└── predictions_pixel_error.csv
tb_logs/: contains training logs for Tensorboard.
video_preds/: predictions and metrics from videos.
predictions.csv: predictions on labeled data.
predictions_pixel_error.csv: Euclidean distance between the predictions and the labeled keypoints.
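To summarize predictions_pixel_error.csv per keypoint, a sketch like the one below can be used. It assumes, without confirmation from this page, that the file has a single header row whose first column identifies the labeled frame, followed by one column of pixel errors per keypoint and possibly a non-numeric column such as "set". mean_pixel_error_per_keypoint is a hypothetical helper, and the sample data is made up.

```python
import csv
import io
import math
from typing import Dict, Iterable, TextIO


def mean_pixel_error_per_keypoint(f: TextIO, skip_columns: Iterable[str] = ("set",)) -> Dict[str, float]:
    """Average each keypoint's pixel error, ignoring empty and NaN cells."""
    reader = csv.reader(f)
    header = next(reader)
    skip = set(skip_columns)
    # Assumption: column 0 identifies the frame; remaining non-skipped columns are keypoints
    columns = [(i, name) for i, name in enumerate(header) if i > 0 and name not in skip]
    sums = {name: 0.0 for _, name in columns}
    counts = {name: 0 for _, name in columns}
    for row in reader:
        for i, name in columns:
            cell = row[i].strip() if i < len(row) else ""
            if not cell:
                continue
            value = float(cell)
            if math.isnan(value):
                continue
            sums[name] += value
            counts[name] += 1
    return {name: sums[name] / counts[name] if counts[name] else math.nan for _, name in columns}


sample = io.StringIO(
    ",paw_left,paw_right,nose_tip,set\n"
    "img0.png,2.0,4.0,1.0,train\n"
    "img1.png,4.0,,3.0,validation\n"
)
print(mean_pixel_error_per_keypoint(sample))
# {'paw_left': 3.0, 'paw_right': 4.0, 'nose_tip': 2.0}
```

For a file on disk, open it with open(path, newline="") and pass the file handle instead of the StringIO.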