# Source code for lightning_pose.train

"""Example model training function."""

import contextlib
import json
import math
import os
import random
import re
import shutil
import sys
from collections.abc import Generator
from datetime import datetime
from pathlib import Path

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.loggers import TensorBoardLogger
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict

import lightning_pose
from lightning_pose.api import Model, ModelConfig
from lightning_pose.callbacks import get_callbacks
from lightning_pose.data import (
    get_data_module,
    get_dataset,
    get_imgaug_transform,
)
from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.losses import get_loss_factories
from lightning_pose.models import get_model
from lightning_pose.utils import pretty_print_cfg, pretty_print_str
from lightning_pose.utils.io import (
    return_absolute_data_paths,
)

# to ignore imports for sphinx-autoapidoc
__all__ = ["train"]


# TODO: Replace with contextlib.chdir in python 3.11.
@contextlib.contextmanager
def chdir(dir: str | Path) -> Generator[None, None, None]:
    """Context manager that temporarily changes the working directory.

    Args:
        dir: directory to change into for the duration of the context.

    Yields:
        None; the current working directory is restored on exit.
    """
    pwd = os.getcwd()
    os.chdir(dir)
    try:
        yield
    finally:
        os.chdir(pwd)


def calculate_steps_per_epoch(data_module: BaseDataModule) -> int:
    """Return how many optimizer steps make up one training epoch.

    The base value is ``ceil(len(train_dataset) / train_batch_size)``. For an
    ``UnlabeledDataModule`` (semi-supervised training) the result is floored at 10 steps.
    This keeps enough unlabeled batches per epoch even when there are very few labeled
    frames.

    Args:
        data_module: data module providing ``train_dataset`` and ``train_batch_size``;
            ``train_dataset`` must already be set up (not None).

    Returns:
        Number of steps per epoch, as an int.
    """
    assert data_module.train_dataset is not None
    n_train_frames = len(data_module.train_dataset)
    n_steps = math.ceil(n_train_frames / data_module.train_batch_size)

    # Fully supervised: no floor needed.
    if not isinstance(data_module, UnlabeledDataModule):
        return n_steps
    # Semi-supervised: the floor is explained at 'max_size_cycle' in UnlabeledDataModule.
    return max(10, n_steps)


def train(
    cfg: DictConfig | ListConfig,
    model_dir: str | Path | None = None,
    skip_evaluation: bool = False,
) -> Model:
    """Train a model using the configuration ``cfg`` and save it to ``model_dir``.

    Training runs with the working directory temporarily set to ``model_dir``. Unless
    ``skip_evaluation`` is set, the trained model is then:

    - evaluated on the labeled csv file(s);
    - evaluated on their "_new"-suffixed OOD counterparts, if those exist;
    - used to predict the configured test videos.

    Finally, ``model_dir / "train_status.json"`` has its ``"status"`` field set to
    ``"COMPLETED"``.

    Args:
        cfg: config containing all training parameters.
        model_dir: output directory for the model; created if missing. Defaults to the
            current working directory. A relative path is interpreted relative to the
            working directory at call time.
        skip_evaluation: if True, skip post-training evaluation and video prediction.

    Returns:
        The trained ``Model``.
    """
    # Default to cwd for backwards compatibility. Future: make model_dir required.
    # Make the path absolute *before* chdir-ing into it. Otherwise the relative status-file
    # path passed to _train (and on to the training callbacks, which run inside the chdir)
    # would resolve against model_dir itself, e.g. "out/run/out/run/train_status.json".
    model_dir = Path(model_dir or os.getcwd()).absolute()
    model_dir.mkdir(parents=True, exist_ok=True)
    status_file_path = model_dir / "train_status.json"

    with chdir(model_dir):
        model = _train(cfg, status_file=status_file_path)
        # Comment out the above, and uncomment the below to skip
        # training and go straight to post-training analysis:
        # model = Model.from_dir(os.getcwd())

    if not skip_evaluation:
        _evaluate_on_training_dataset(model)
        _evaluate_on_training_dataset(model, ood_mode=True)
        _predict_test_videos(model)

    # Update status file to COMPLETED, preserving any other keys already in it.
    try:
        with open(status_file_path) as f:
            status_file_contents = json.load(f)
    except FileNotFoundError:
        status_file_contents = {}
    status_file_contents["status"] = "COMPLETED"
    # Write to a sibling temp file, then os.replace it over the real file, so readers
    # never see a half-written status file.
    tmp_status_file_path = status_file_path.with_suffix(".json.tmp")
    with open(tmp_status_file_path, "w") as f:
        json.dump(status_file_contents, f)
    os.replace(tmp_status_file_path, status_file_path)

    return model
def _absolute_csv_file(csv_file: str | Path, data_dir: str | Path) -> Path:
    """Resolve a (possibly relative) CSV file path against ``data_dir``.

    Args:
        csv_file: path to the CSV file; may be relative or absolute.
        data_dir: base directory used to resolve relative paths.

    Returns:
        ``csv_file`` unchanged (as a ``pathlib.Path``) if it is already absolute.
        Otherwise ``Path(data_dir) / csv_file``, which is absolute only if
        ``data_dir`` is.
    """
    csv_file = Path(csv_file)
    if not csv_file.is_absolute():
        return Path(data_dir) / csv_file
    return csv_file


def _with_new_suffix(path: Path) -> Path:
    """Return ``path`` with "_new" appended to its stem (``a/b.csv`` -> ``a/b_new.csv``)."""
    return path.with_stem(path.stem + "_new")


def _evaluate_on_training_dataset(model: Model, ood_mode: bool = False) -> None:
    """Predict (with metrics) on the labeled-frame csv file(s) listed in the model config.

    After prediction, every ``predictions*.csv`` under
    ``model.image_preds_dir() / <csv file name>`` is copied into ``model.model_dir``.
    The copies are named ``predictions[_<view>][<metric suffix>][_new].csv`` for
    backward compatibility.

    Args:
        model: trained model to evaluate.
        ood_mode: look for "_new"-suffixed versions of the training csv file(s) and of the
            camera params file. If the first "_new" csv file does not exist, return
            without predicting.
    """
    data_cfg = model.config.cfg.data

    if model.config.is_single_view():
        csv_file = _absolute_csv_file(data_cfg.csv_file, data_cfg.data_dir)
        csv_files = [_with_new_suffix(csv_file) if ood_mode else csv_file]
    else:
        csv_files = []
        for csv_file in data_cfg.csv_file:
            csv_file = _absolute_csv_file(csv_file, data_cfg.data_dir)
            csv_files.append(_with_new_suffix(csv_file) if ood_mode else csv_file)

    if data_cfg.get("camera_params_file"):
        camera_params_file = _absolute_csv_file(data_cfg.camera_params_file, data_cfg.data_dir)
        if ood_mode:
            camera_params_file = _with_new_suffix(camera_params_file)
    else:
        camera_params_file = None

    # NOTE: setting bbox_files = None here is a hacky way to get the model predictions
    # to be in the cropped image space; otherwise the bbox info would lead to
    # predictions in the original image space. This can be achieved post-hoc by using
    # the CLI remap command.
    bbox_files = None
    # This is how the code would look without the hack
    # if model.config.cfg.data.get("bbox_file"):
    #     bbox_files = []
    #     for bbox_file in model.config.cfg.data.bbox_file:
    #         bbox_file = _absolute_csv_file(bbox_file, model.config.cfg.data.data_dir)
    #         if ood_mode:
    #             bbox_file = bbox_file.with_stem(bbox_file.stem + "_new")
    #         bbox_files.append(bbox_file)
    # else:
    #     bbox_files = None

    # ood mode: skip prediction when _new files don't exist.
    if ood_mode and not csv_files[0].exists():
        return

    # Print a custom message when in OOD mode.
    if ood_mode:
        pretty_print_str("Predicting OOD images...")
    else:
        pretty_print_str("Predicting train/val/test images...")

    # Run prediction and metric computation.
    if model.config.is_multi_view():
        model.predict_on_label_csv_multiview(
            csv_file_per_view=csv_files,
            bbox_file_per_view=bbox_files,
            camera_params_file=camera_params_file,
            data_dir=data_cfg.data_dir,
            compute_metrics=True,
            add_train_val_test_set=(not ood_mode),
        )
    else:
        model.predict_on_label_csv(
            csv_file=csv_files[0],
            data_dir=data_cfg.data_dir,
            compute_metrics=True,
            add_train_val_test_set=(not ood_mode),
        )

    # Copy output files to model_dir for backward-compatibility.
    # New users should look up these files in image_preds.
    multiple_views = len(csv_files) > 1
    for i, csv_file in enumerate(csv_files):
        view_name = data_cfg.view_names[i] if multiple_views else None
        for p_file in (model.image_preds_dir() / csv_file.name).glob("predictions*.csv"):
            m = re.match(r"predictions(.*)\.csv", p_file.name)
            metric_suffix = m[1] if m else ""
            out_name = "predictions"
            if multiple_views:
                out_name += "_" + view_name
            if metric_suffix:
                out_name += metric_suffix
            if ood_mode:
                out_name += "_new"
            out_name += ".csv"
            shutil.copy(p_file, model.model_dir / out_name)


def _predict_test_videos(model: Model) -> None:
    """Run video prediction on test videos specified in the config, if enabled.

    Does nothing unless ``cfg.eval.predict_vids_after_training`` is truthy. Labeled
    videos are generated according to ``cfg.eval.save_vids_after_training``.

    Args:
        model: trained model used for prediction.
    """
    eval_cfg = model.config.cfg.eval
    if not eval_cfg.predict_vids_after_training:
        return

    pretty_print_str("Predicting videos in cfg.eval.test_videos_directory...")
    if model.config.is_multi_view():
        # One call per multiview session; each item holds one video file per view.
        for video_file_per_view in model.config.test_video_files_multiview():
            model.predict_on_video_file_multiview(
                video_file_per_view=video_file_per_view,
                compute_metrics=True,
                generate_labeled_video=eval_cfg.save_vids_after_training,
            )
    else:
        for video_file in model.config.test_video_files_singleview():
            pretty_print_str(f"Predicting video: {video_file}...")
            model.predict_on_video_file(
                Path(video_file),
                generate_labeled_video=eval_cfg.save_vids_after_training,
            )


def _train(cfg: DictConfig | ListConfig, status_file: Path | None = None) -> Model:
    """Build data/model objects, train, and return the trained model.

    Outputs are written relative to the current working directory: ``config.yaml``,
    copies of the labeled csv file(s), and TensorBoard logs under ``tb_logs``. The caller
    is expected to ``chdir`` into the model directory first. ``cfg`` is modified in
    place:

    - ``creation_datetime`` and ``model.lightning_pose_version`` are set;
    - step-based scheduler/patch-mask fields are derived from epoch-based ones;
    - ``training.num_gpus`` is clamped to at least 1.

    In multi-process runs, processes other than global rank zero call ``sys.exit(0)``
    after fitting.

    Args:
        cfg: hydra config containing all training parameters.
        status_file: optional path to a JSON file where training progress will be written
            (passed to ``get_callbacks``).

    Returns:
        The trained ``Model`` instance, loaded from the current working directory.
    """
    # reset all seeds
    seed = 0
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # record creation time and lightning-pose version
    with open_dict(cfg):
        cfg.creation_datetime = datetime.now().isoformat()
        cfg.model.lightning_pose_version = lightning_pose.version

    print("Config file:")
    pretty_print_cfg(cfg)

    ModelConfig(cfg).validate()

    # path handling for toy data
    data_dir, video_dir = return_absolute_data_paths(data_cfg=cfg.data)

    # ----------------------------------------------------------------------------------
    # Set up data/model objects
    # ----------------------------------------------------------------------------------

    # imgaug transform
    imgaug_transform = get_imgaug_transform(cfg=cfg)

    # dataset
    dataset = get_dataset(cfg=cfg, data_dir=data_dir, imgaug_transform=imgaug_transform)

    # datamodule; breaks up dataset into train/val/test
    data_module = get_data_module(cfg=cfg, dataset=dataset, video_dir=video_dir)

    # build loss factory which orchestrates different losses
    loss_factories = get_loss_factories(cfg=cfg, data_module=data_module)

    steps_per_epoch = calculate_steps_per_epoch(data_module)

    # convert milestone_steps to milestones if applicable (before `get_model`).
    if (
        "multisteplr" in cfg.training.lr_scheduler_params
        and "milestone_steps" in cfg.training.lr_scheduler_params.multisteplr
    ):
        milestone_steps = cfg.training.lr_scheduler_params.multisteplr.milestone_steps
        milestones = [math.ceil(s / steps_per_epoch) for s in milestone_steps]
        # open_dict: "milestones" may not exist yet in a struct-mode config (same
        # treatment as the patch_mask fields below).
        with open_dict(cfg):
            cfg.training.lr_scheduler_params.multisteplr.milestones = milestones

    # convert patch masking epochs to steps if applicable (before `get_callbacks`)
    if "patch_mask" in cfg.training and "init_epoch" in cfg.training.patch_mask:
        init_step = math.ceil(cfg.training.patch_mask.init_epoch * steps_per_epoch)
        final_step = math.ceil(cfg.training.patch_mask.final_epoch * steps_per_epoch)
        with open_dict(cfg):
            cfg.training.patch_mask.init_step = init_step
            cfg.training.patch_mask.final_step = final_step

    # model
    model = get_model(cfg=cfg, data_module=data_module, loss_factories=loss_factories)

    # ----------------------------------------------------------------------------------
    # Save configuration in output directory
    # ----------------------------------------------------------------------------------
    # Done before training; files will exist even if script dies prematurely.
    hydra_output_directory = os.getcwd()
    print(f"Hydra output directory: {hydra_output_directory}")

    # save config file
    dest_config_file = Path(hydra_output_directory) / "config.yaml"
    OmegaConf.save(config=cfg, f=dest_config_file, resolve=False)

    # save labeled data file(s)
    if isinstance(cfg.data.csv_file, str):
        # single view
        csv_files = [cfg.data.csv_file]
    else:
        # multi view
        assert isinstance(cfg.data.csv_file, ListConfig)
        csv_files = cfg.data.csv_file
    for csv_file in csv_files:
        src_csv_file = _absolute_csv_file(csv_file, data_dir)
        dest_csv_file = Path(hydra_output_directory) / src_csv_file.name
        # If the csv already lives in the output directory there is nothing to copy;
        # shutil.copyfile would otherwise raise SameFileError and abort training.
        with contextlib.suppress(shutil.SameFileError):
            shutil.copyfile(src_csv_file, dest_csv_file)

    # ----------------------------------------------------------------------------------
    # Set up and run training
    # ----------------------------------------------------------------------------------

    # logger
    logger = TensorBoardLogger("tb_logs", name=cfg.model.model_name)
    # Log hydra config to tensorboard as helpful metadata.
    for key, value in cfg.items():
        # OmegaConf.to_yaml only accepts configs (it raises ValueError for scalars/None),
        # so render anything else, including strings, with str().
        text = OmegaConf.to_yaml(value) if OmegaConf.is_config(value) else str(value)
        logger.experiment.add_text(f"hydra_config_{key}", f"```\n{text}```")

    # early stopping, learning rate monitoring, model checkpointing, backbone unfreezing
    callbacks = get_callbacks(
        cfg,
        early_stopping=cfg.training.get("early_stopping", False),
        lr_monitor=True,
        ckpt_every_n_epochs=cfg.training.get("ckpt_every_n_epochs", None),
        status_file=status_file,
    )

    # set up trainer
    cfg.training.num_gpus = max(cfg.training.num_gpus, 1)

    # initialize to Trainer defaults. Note max_steps defaults to -1.
    min_steps, max_steps, min_epochs, max_epochs = (None, -1, None, None)
    if "min_steps" in cfg.training:
        min_steps = cfg.training.min_steps
        max_steps = cfg.training.max_steps
    else:
        min_epochs = cfg.training.min_epochs
        max_epochs = cfg.training.max_epochs

    # Unlike min_epoch/min_step, both of these are valid to specify.
    check_val_every_n_epoch = cfg.training.get("check_val_every_n_epoch", 1)
    val_check_interval = cfg.training.get("val_check_interval")

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=cfg.training.num_gpus,
        max_epochs=max_epochs,
        min_epochs=min_epochs,
        max_steps=max_steps,
        min_steps=min_steps,
        check_val_every_n_epoch=check_val_every_n_epoch,
        val_check_interval=val_check_interval,
        log_every_n_steps=cfg.training.log_every_n_steps,
        callbacks=callbacks,
        logger=logger,
        # To understand why we set this, see 'max_size_cycle' in UnlabeledDataModule.
        limit_train_batches=cfg.training.get("limit_train_batches") or steps_per_epoch,
        accumulate_grad_batches=cfg.training.get("accumulate_grad_batches", 1),
        profiler=cfg.training.get("profiler", None),
        sync_batchnorm=True,
    )

    # train model!
    trainer.fit(model=model, datamodule=data_module)

    # With multiple devices, lightning may run one process per device. Stop the
    # non-zero-rank processes here so only the main process goes on to evaluation.
    if not trainer.is_global_zero:
        sys.exit(0)

    return Model.from_dir(hydra_output_directory)