Source code for lightning_pose.train

"""Example model training function."""

import os
import random
import shutil
import sys
import warnings
from pathlib import Path

import lightning.pytorch as pl
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from typeguard import typechecked

from lightning_pose.model import Model
from lightning_pose.utils import pretty_print_cfg, pretty_print_str
from lightning_pose.utils.io import return_absolute_data_paths
from lightning_pose.utils.scripts import (
    calculate_train_batches,
    get_callbacks,
    get_data_module,
    get_dataset,
    get_imgaug_transform,
    get_loss_factories,
    get_model,
)

# to ignore imports for sphinx-autoapidoc
__all__ = ["train"]


[docs] @typechecked def train(cfg: DictConfig) -> Model: """ Trains a model using the configuration `cfg`. Saves model to current working directory (callers should `chdir` to the desired `model_dir` prior to calling). """ model = _train(cfg) # Comment out the above, and uncomment the below to skip # training and go straight to post-training analysis: # import os # model = Model.from_dir(os.getcwd()) _evaluate_on_training_dataset(model) _evaluate_on_ood_dataset(model) _predict_test_videos(model) return model
def _absolute_csv_file(csv_file, data_dir): csv_file = Path(csv_file) if not csv_file.is_absolute(): return Path(data_dir) / csv_file return csv_file def _evaluate_on_training_dataset(model: Model): pretty_print_str("Predicting train/val/test images...") if model.config.is_single_view(): csv_file = _absolute_csv_file( model.config.cfg.data.csv_file, model.config.cfg.data.data_dir ) csv_files = [csv_file] output_filename_stems = ["predictions"] else: csv_files = [] output_filename_stems = [] for csv_file, view_name in zip( model.config.cfg.data.csv_file, model.config.cfg.data.view_names ): csv_files.append( _absolute_csv_file(csv_file, model.config.cfg.data.data_dir) ) output_filename_stems.append(f"predictions_{view_name}") for csv_file, output_filename_stem in zip(csv_files, output_filename_stems): model.predict_on_label_csv_internal( csv_file=csv_file, data_dir=model.config.cfg.data.data_dir, # TODO annotate with train/val/test split metadata. compute_metrics=True, generate_labeled_images=False, output_dir=model.model_dir, output_filename_stem=output_filename_stem, add_train_val_test_set=True, ) def _evaluate_on_ood_dataset(model: Model): if model.config.is_single_view(): csv_file = _absolute_csv_file( model.config.cfg.data.csv_file, model.config.cfg.data.data_dir ) ood_csv_file = csv_file.with_stem(csv_file.stem + "_new") ood_csv_files = [ood_csv_file] output_filename_stems = ["predictions_new"] else: ood_csv_files = [] output_filename_stems = [] for csv_file, view_name in zip( model.config.cfg.data.csv_file, model.config.cfg.data.view_names ): csv_file = _absolute_csv_file(csv_file, model.config.cfg.data.data_dir) ood_csv_file = csv_file.with_stem(csv_file.stem + "_new") ood_csv_files.append(ood_csv_file) output_filename_stems.append(f"predictions_new_{view_name}") if ood_csv_files[0].is_file(): pretty_print_str("Predicting OOD images...") for ood_csv_file, output_filename_stem in zip( ood_csv_files, output_filename_stems ): model.predict_on_label_csv_internal( csv_file=ood_csv_file, data_dir=model.config.cfg.data.data_dir, compute_metrics=True, generate_labeled_images=False, output_dir=model.model_dir, output_filename_stem=output_filename_stem, ) def _predict_test_videos(model: Model): if model.config.cfg.eval.predict_vids_after_training: pretty_print_str(f"Predicting videos in cfg.eval.test_videos_directory...") for video_file in model.config.test_video_files(): pretty_print_str(f"Predicting video: {video_file}...") model.predict_on_video_file( Path(video_file), generate_labeled_video=model.config.cfg.eval.save_vids_after_training, ) def _train(cfg: DictConfig) -> Model: # reset all seeds seed = 0 os.environ["PYTHONHASHSEED"] = str(seed) torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True # record lightning-pose version from lightning_pose import __version__ as lightning_pose_version with open_dict(cfg): cfg.model.lightning_pose_version = lightning_pose_version print("Our Hydra config file:") pretty_print_cfg(cfg) # path handling for toy data data_dir, video_dir = return_absolute_data_paths(data_cfg=cfg.data) # ---------------------------------------------------------------------------------- # Set up data/model objects # ---------------------------------------------------------------------------------- # imgaug transform imgaug_transform = get_imgaug_transform(cfg=cfg) # dataset dataset = get_dataset(cfg=cfg, data_dir=data_dir, imgaug_transform=imgaug_transform) # datamodule; breaks up dataset into train/val/test data_module = get_data_module(cfg=cfg, dataset=dataset, video_dir=video_dir) # build loss factory which orchestrates different losses loss_factories = get_loss_factories(cfg=cfg, data_module=data_module) # model model = get_model(cfg=cfg, data_module=data_module, loss_factories=loss_factories) # ---------------------------------------------------------------------------------- # Save configuration in output directory # ---------------------------------------------------------------------------------- # Done before training; files will exist even if script dies prematurely. hydra_output_directory = os.getcwd() print(f"Hydra output directory: {hydra_output_directory}") # save config file dest_config_file = Path(hydra_output_directory) / "config.yaml" OmegaConf.save(config=cfg, f=dest_config_file, resolve=False) # save labeled data file(s) if isinstance(cfg.data.csv_file, str): # single view csv_files = [cfg.data.csv_file] else: # multi view assert isinstance(cfg.data.csv_file, ListConfig) csv_files = cfg.data.csv_file for csv_file in csv_files: src_csv_file = Path(csv_file) if not src_csv_file.is_absolute(): src_csv_file = Path(data_dir) / src_csv_file dest_csv_file = Path(hydra_output_directory) / src_csv_file.name shutil.copyfile(src_csv_file, dest_csv_file) # ---------------------------------------------------------------------------------- # Set up and run training # ---------------------------------------------------------------------------------- # logger logger = pl.loggers.TensorBoardLogger("tb_logs", name=cfg.model.model_name) # Log hydra config to tensorboard as helpful metadata. for key, value in cfg.items(): logger.experiment.add_text( "hydra_config_%s" % key, "```\n%s```" % OmegaConf.to_yaml(value) ) # early stopping, learning rate monitoring, model checkpointing, backbone unfreezing callbacks = get_callbacks( cfg, early_stopping=cfg.training.get("early_stopping", False), lr_monitor=True, ckpt_every_n_epochs=cfg.training.get("ckpt_every_n_epochs", None), ) # calculate number of batches for both labeled and unlabeled data per epoch limit_train_batches = calculate_train_batches(cfg, dataset) # set up trainer # Old configs may have num_gpus: 0. We will remove support in a future release. if cfg.training.num_gpus == 0: warnings.warn( "Config contains unsupported value num_gpus: 0. " "Update num_gpus to 1 in your config." ) cfg.training.num_gpus = max(cfg.training.num_gpus, 1) trainer = pl.Trainer( accelerator="gpu", devices=cfg.training.num_gpus, max_epochs=cfg.training.max_epochs, min_epochs=cfg.training.min_epochs, check_val_every_n_epoch=min( cfg.training.check_val_every_n_epoch, cfg.training.max_epochs, # for debugging or otherwise training for a short time ), log_every_n_steps=cfg.training.log_every_n_steps, callbacks=callbacks, logger=logger, limit_train_batches=limit_train_batches, accumulate_grad_batches=cfg.training.get("accumulate_grad_batches", 1), profiler=cfg.training.get("profiler", None), sync_batchnorm=True, ) # train model! trainer.fit(model=model, datamodule=data_module) # When devices > 0, lightning creates a process per device. # Kill processes other than the main process, otherwise they all go forward. if not trainer.is_global_zero: sys.exit(0) return Model.from_dir(hydra_output_directory)