# Source code for lightning_pose.api.model

from __future__ import annotations

import copy
from pathlib import Path

import pandas as pd
from omegaconf import DictConfig, OmegaConf

from lightning_pose.api.model_config import ModelConfig
from lightning_pose.data.datatypes import MultiviewPredictionResult, PredictionResult
from lightning_pose.models import ALLOWED_MODELS
from lightning_pose.utils import io as io_utils
from lightning_pose.utils.predictions import generate_labeled_video as generate_labeled_video_fn
from lightning_pose.utils.predictions import (
    load_model_from_checkpoint,
    predict_dataset,
    predict_video,
)
from lightning_pose.utils.scripts import (
    compute_metrics_single,
    get_data_module,
    get_dataset,
    get_imgaug_transform,
)

__all__ = ["Model"]


class Model:
    """Handle to a model stored in a directory on disk.

    The checkpoint is loaded lazily: it is only read the first time one of the
    ``predict_on_*`` methods is called (see `_load`).
    """

    model_dir: Path
    """Directory the model is stored in."""

    config: ModelConfig
    """The model configuration stored as a `ModelConfig` object.

    `ModelConfig` wraps the `omegaconf.DictConfig` and provides util functions over it.
    """

    model: ALLOWED_MODELS | None = None

    # Just a constant we can use as a default value for kwargs,
    # to differentiate between user omitting a kwarg, vs explicitly passing None.
    UNSPECIFIED = "unspecified"

    @staticmethod
    def from_dir(model_dir: str | Path):
        """Create a `Model` instance for a model stored at `model_dir`."""
        return Model.from_dir2(model_dir)

    @staticmethod
    def from_dir2(model_dir: str | Path, hydra_overrides: list[str] | None = None):
        """Internal version of from_dir that supports hydra_overrides.

        Not sure whether to promote this to public API yet.

        Args:
            model_dir (str | Path): Directory containing the model's config.
            hydra_overrides (list[str], optional): When given (even if empty),
                the config named ``config`` is composed through hydra from
                `model_dir` with these overrides applied. When omitted,
                `{model_dir}/config.yaml` is read directly.

        Returns:
            Model: A model bound to the absolute version of `model_dir`.
        """
        model_dir = Path(model_dir).absolute()

        if hydra_overrides is not None:
            # Imported lazily so hydra is only required when overrides are used.
            import hydra

            with hydra.initialize_config_dir(
                version_base="1.1", config_dir=str(model_dir)
            ):
                cfg = hydra.compose(config_name="config", overrides=hydra_overrides)
                config = ModelConfig(cfg)
        else:
            config = ModelConfig.from_yaml_file(model_dir / "config.yaml")

        return Model(model_dir, config)

    def __init__(self, model_dir: str | Path, config: ModelConfig):
        """Bind the model to `model_dir` (stored as an absolute path) and `config`."""
        self.model_dir = Path(model_dir).absolute()
        self.config = config

    @property
    def cfg(self) -> DictConfig:
        """The model configuration as an `omegaconf.DictConfig`."""
        return self.config.cfg

    def _load(self):
        """Load the model from its checkpoint if it has not been loaded yet.

        Raises:
            FileNotFoundError: If no checkpoint file is found under `model_dir`.
        """
        if self.model is not None:
            return

        ckpt_file = io_utils.ckpt_path_from_base_path(
            base_path=str(self.model_dir), model_name=self.cfg.model.model_name
        )
        if ckpt_file is None:
            raise FileNotFoundError(
                "Checkpoint file not found, have you trained for enough epochs?"
            )

        self.model = load_model_from_checkpoint(
            cfg=self.cfg,
            ckpt_file=ckpt_file,
            eval=True,
            skip_data_module=True,
        )

    def image_preds_dir(self) -> Path:
        """Return `{model_dir}/image_preds`."""
        return self.model_dir / "image_preds"

    def video_preds_dir(self) -> Path:
        """Return `{model_dir}/video_preds`."""
        return self.model_dir / "video_preds"

    def labeled_videos_dir(self) -> Path:
        """Return `{model_dir}/video_preds/labeled_videos`."""
        return self.model_dir / "video_preds" / "labeled_videos"

    def cropped_data_dir(self) -> Path:
        """Return `{model_dir}/cropped_images`."""
        return self.model_dir / "cropped_images"

    def cropped_videos_dir(self) -> Path:
        """Return `{model_dir}/cropped_videos`."""
        return self.model_dir / "cropped_videos"

    def cropped_csv_file_path(self, csv_file_path: str | Path) -> Path:
        """Return `{model_dir}/image_preds/{csv_name}/cropped_{csv_name}`.

        Only the file name of `csv_file_path` is used; its parent directories
        are ignored.
        """
        csv_name = Path(csv_file_path).name
        return self.model_dir / "image_preds" / csv_name / ("cropped_" + csv_name)

    def _resolve_video_output_dir(self, output_dir: str | Path | None) -> Path:
        """Turn the `output_dir` kwarg of the video methods into a concrete path.

        Has no filesystem side effects; the caller creates the directory.

        Raises:
            NotImplementedError: If `output_dir` is None.
        """
        if output_dir == self.__class__.UNSPECIFIED:
            return self.video_preds_dir()
        if output_dir is None:
            raise NotImplementedError("Currently we must save predictions")
        return Path(output_dir)

    def _write_labeled_video(self, video_file: Path, preds_df: pd.DataFrame) -> None:
        """Render `preds_df` over `video_file` into `labeled_videos_dir()`.

        The output file is named `{video_file.stem}_labeled.mp4`.
        """
        labeled_dir = self.labeled_videos_dir()
        # Make sure the destination exists before handing the path downstream.
        labeled_dir.mkdir(parents=True, exist_ok=True)
        generate_labeled_video_fn(
            video_file=str(video_file),
            preds_df=preds_df,
            output_mp4_file=str(labeled_dir / f"{video_file.stem}_labeled.mp4"),
            confidence_thresh_for_vid=self.cfg.eval.confidence_thresh_for_vid,
            colormap=self.cfg.eval.get("colormap", "cool"),
        )

    def predict_on_label_csv(
        self,
        csv_file: str | Path,
        data_dir: str | Path | None = None,
        compute_metrics: bool = True,
        add_train_val_test_set: bool = False,
    ) -> PredictionResult:
        """Predicts on a labeled dataset and computes error/loss metrics if applicable.

        The predictions file path handed to `predict_dataset` is
        `{model_dir}/image_preds/{csv_file_name}/predictions.csv`; that
        directory is created if needed.

        Args:
            csv_file (str | Path): Path to the CSV file of images, keypoint locations.
            data_dir (str | Path, optional): Root path for relative paths in the CSV file.
                Defaults to `data.data_dir` from the model config.
            compute_metrics (bool, optional): Whether to compute pixel error and loss
                metrics on predictions.
            add_train_val_test_set (bool): When predicting on training dataset, set to
                true to add the `set` column to the prediction output.

        Returns:
            PredictionResult: A PredictionResult object containing the predictions
            and metrics (`metrics` is None when `compute_metrics` is False).

        Raises:
            FileNotFoundError: If the model checkpoint cannot be found.
        """
        self._load()

        # Convert this to absolute, because if relative, downstream will
        # assume its relative to the data_dir.
        csv_file = Path(csv_file).absolute()

        if data_dir is None:
            data_dir = self.cfg.data.data_dir

        output_dir = self.image_preds_dir() / csv_file.name
        output_dir.mkdir(parents=True, exist_ok=True)

        # Point predict_dataset to the csv_file and data_dir.
        cfg_overrides = {
            "data": {
                "data_dir": str(data_dir),
                "csv_file": str(csv_file),
            }
        }

        # Avoid annotating set=train/val/test for CSV file other than the training CSV file.
        # NOTE(review): these keys are merged at the top level of the config, not under
        # a sub-section -- confirm that is where downstream code reads them.
        if not add_train_val_test_set:
            cfg_overrides.update({"train_prob": 1, "val_prob": 0, "train_frames": 1})

        cfg_pred = OmegaConf.merge(self.cfg, cfg_overrides)

        # HACK: For true multi-view model, trick predict_dataset and compute_metrics
        # into thinking this is a single-view model.
        if self.config.is_multi_view():
            del cfg_pred.data.view_names
            # HACK: If we don't delete mirrored_column_matches, downstream
            # interprets this as a mirrored multiview model, and compute_metrics fails.
            del cfg_pred.data.mirrored_column_matches

        data_module_pred = _build_datamodule_pred(cfg_pred)

        preds_file = str(output_dir / "predictions.csv")
        df = predict_dataset(
            cfg_pred, data_module_pred, model=self.model, preds_file=preds_file
        )

        metrics = None
        if compute_metrics:
            metrics = compute_metrics_single(
                cfg=cfg_pred,
                labels_file=str(csv_file),
                preds_file=preds_file,
                data_module=data_module_pred,
            )

        return PredictionResult(predictions=df, metrics=metrics)

    def predict_on_label_csv_multiview(
        self,
        csv_file_per_view: list[str] | list[Path],
        bbox_file_per_view: list[str] | list[Path] | None = None,
        camera_params_file: str | Path | None = None,
        data_dir: str | Path | None = None,
        compute_metrics: bool = True,
        add_train_val_test_set: bool = False,
    ) -> MultiviewPredictionResult:
        """Version of `predict_on_label_csv` that gives models access to all views of each frame.

        Arguments:
            csv_file_per_view (list[str] | list[Path]): A list of csv files each
                from a different view of the same session. Order must match the
                `view_names` in the config file.
            bbox_file_per_view (list[str] | list[Path], optional): Written to
                `data.bbox_file` in the prediction config. When omitted or empty,
                `data.bbox_file` is set to None.
            camera_params_file (str | Path, optional): When given, written to
                `data.camera_params_file` in the prediction config.

        See `predict_on_label_csv` docstring for other arguments.

        Returns:
            MultiviewPredictionResult: Per-view predictions and, when
            `compute_metrics` is True, per-view metrics keyed by view name.
        """
        assert self.config.is_multi_view()
        self._load()

        view_names = self.cfg.data.view_names
        assert len(csv_file_per_view) == len(
            view_names
        ), f"{len(csv_file_per_view)} != {len(view_names)}"

        # Convert this to absolute, because if relative, downstream will
        # assume its relative to the data_dir.
        csv_paths: list[Path] = [Path(f).absolute() for f in csv_file_per_view]

        if data_dir is None:
            data_dir = self.cfg.data.data_dir

        # Point predict_dataset to the csv_file and data_dir.
        cfg_overrides = {
            "data": {
                "data_dir": str(data_dir),
                "csv_file": [str(p) for p in csv_paths],
            }
        }
        if camera_params_file:
            # Stringified like every other path placed into the config overrides.
            cfg_overrides["data"]["camera_params_file"] = str(camera_params_file)
        if bbox_file_per_view:
            cfg_overrides["data"]["bbox_file"] = [str(p) for p in bbox_file_per_view]
        else:
            cfg_overrides["data"]["bbox_file"] = None

        # Avoid annotating set=train/val/test for CSV file other than the training CSV file.
        # NOTE(review): these keys are merged at the top level of the config, not under
        # a sub-section -- confirm that is where downstream code reads them.
        if not add_train_val_test_set:
            cfg_overrides.update({"train_prob": 1, "val_prob": 0, "train_frames": 1})

        cfg_pred = OmegaConf.merge(self.cfg, cfg_overrides)
        data_module_pred = _build_datamodule_pred(cfg_pred)

        # One predictions file per view, each in a directory named after its CSV file.
        preds_files = []
        for csv_path in csv_paths:
            output_dir = self.image_preds_dir() / csv_path.name
            output_dir.mkdir(parents=True, exist_ok=True)
            preds_files.append(str(output_dir / "predictions.csv"))

        # Outputs dict[str, pd.DataFrame] because inputs indicate multiview.
        view_to_df_dict = predict_dataset(
            cfg_pred, data_module_pred, model=self.model, preds_file=preds_files
        )

        metrics = None
        if compute_metrics:
            # NOTE(review): the single-view variant passes cfg_pred here rather than
            # self.cfg -- confirm the difference is intentional.
            metrics = {
                view_name: compute_metrics_single(
                    cfg=self.cfg,
                    labels_file=str(labels_file),
                    preds_file=view_preds_file,
                    data_module=data_module_pred,
                )
                for view_name, labels_file, view_preds_file in zip(
                    view_names, csv_paths, preds_files
                )
            }

        return MultiviewPredictionResult(predictions=view_to_df_dict, metrics=metrics)

    def predict_on_video_file(
        self,
        video_file: str | Path,
        output_dir: str | Path | None = UNSPECIFIED,
        compute_metrics: bool = True,
        generate_labeled_video: bool = False,
    ) -> PredictionResult:
        """Predicts on a video file and computes unsupervised loss metrics if applicable.

        Args:
            video_file (str | Path): Path to the video file.
            output_dir (str | Path, optional): The directory to save outputs to.
                Defaults to `{model_dir}/video_preds`. Passing None is not
                supported yet and raises `NotImplementedError`.
            compute_metrics (bool, optional): Whether to compute metrics on the
                predictions (no labels file is supplied for videos).
            generate_labeled_video (bool, optional): Whether to save a labeled video
                under `{model_dir}/video_preds/labeled_videos`. Defaults to False.

        Returns:
            PredictionResult: A PredictionResult object containing the predictions
            and metrics (`metrics` is None when `compute_metrics` is False).

        Raises:
            NotImplementedError: If `output_dir` is None.
            FileNotFoundError: If the model checkpoint cannot be found.
        """
        video_file = Path(video_file)
        # Reject unsupported arguments before paying for the checkpoint load.
        output_dir = self._resolve_video_output_dir(output_dir)

        self._load()
        output_dir.mkdir(parents=True, exist_ok=True)

        prediction_csv_file = output_dir / f"{video_file.stem}.csv"

        # predict_video is handed this wrapper object, not the underlying network.
        df: pd.DataFrame = predict_video(
            video_file=str(video_file),
            model=self,
            output_pred_file=str(prediction_csv_file),
        )

        if generate_labeled_video:
            self._write_labeled_video(video_file, df)

        metrics = None
        if compute_metrics:
            # FIXME: Data module is only used for computing PCA metrics.
            data_module = _build_datamodule_pred(self.cfg)
            metrics = compute_metrics_single(
                cfg=self.cfg,
                labels_file=None,
                preds_file=str(prediction_csv_file),
                data_module=data_module,
            )

        return PredictionResult(predictions=df, metrics=metrics)

    def predict_on_video_file_multiview(
        self,
        video_file_per_view: list[str] | list[Path],
        output_dir: str | Path | None = UNSPECIFIED,
        compute_metrics: bool = True,
        generate_labeled_video: bool = False,
    ) -> MultiviewPredictionResult:
        """Version of `predict_on_video_file` that accesses to multiple camera views of each frame.

        Arguments:
            video_file_per_view (list[str] | list[Path]): A list of video files each
                from a different view of the same session. Number of video files
                must match the `view_names` in the config file. Order of the list
                does not matter: video files are intelligently matched to views by
                their filename using `utils.io.collect_video_files_by_view`.
            output_dir (str | Path, optional): The directory to save outputs to.
                Defaults to `{model_dir}/video_preds`. Passing None is not
                supported yet and raises `NotImplementedError`.
            compute_metrics (bool, optional): Whether to compute metrics on the
                predictions (no labels file is supplied for videos).
            generate_labeled_video (bool, optional): Whether to save a labeled video
                per view under `{model_dir}/video_preds/labeled_videos`.
                Defaults to False.

        Returns:
            MultiviewPredictionResult: object containing the predictions and metrics
            for each view, keyed by view name.

        Raises:
            NotImplementedError: If `output_dir` is None.
            FileNotFoundError: If the model checkpoint cannot be found.
        """
        assert self.config.is_multi_view()

        view_names = self.cfg.data.view_names
        assert len(video_file_per_view) == len(
            view_names
        ), f"{len(video_file_per_view)} != {len(view_names)}"

        # Reject unsupported arguments before paying for the checkpoint load.
        output_dir = self._resolve_video_output_dir(output_dir)

        self._load()
        output_dir.mkdir(parents=True, exist_ok=True)

        # Arranges the video files to be in the same order as cfg.data.view_names.
        view_to_video_file: dict[str, Path] = io_utils.collect_video_files_by_view(
            [Path(f) for f in video_file_per_view], view_names
        )
        ordered_videos: list[Path] = [
            view_to_video_file[view_name] for view_name in view_names
        ]

        prediction_csv_file_list = [
            str(output_dir / f"{video_file.stem}.csv") for video_file in ordered_videos
        ]

        # predict_video is handed this wrapper object, not the underlying network.
        df_list: list[pd.DataFrame] = predict_video(
            video_file=[str(video_file) for video_file in ordered_videos],
            model=self,
            output_pred_file=prediction_csv_file_list,
        )

        if generate_labeled_video:
            for video_file, preds_df in zip(ordered_videos, df_list):
                self._write_labeled_video(video_file, preds_df)

        metrics = None
        if compute_metrics:
            # Only build the data module when metrics are actually requested.
            data_module = _build_datamodule_pred(self.cfg)
            metrics = {
                view_name: compute_metrics_single(
                    cfg=self.cfg,
                    labels_file=None,
                    preds_file=preds_file,
                    data_module=data_module,
                )
                for view_name, preds_file in zip(view_names, prediction_csv_file_list)
            }

        df_dict = dict(zip(view_names, df_list))
        return MultiviewPredictionResult(predictions=df_dict, metrics=metrics)
def _build_datamodule_pred(cfg: DictConfig):
    """Build a data module for prediction from a copy of `cfg`.

    The passed config is deep-copied, so it is left untouched; on the copy,
    `training.imgaug` is forced to "default" before the imgaug transform,
    dataset and data module are constructed from it.
    """
    pred_cfg = copy.deepcopy(cfg)
    pred_cfg.training.imgaug = "default"

    transform = get_imgaug_transform(cfg=pred_cfg)
    dataset = get_dataset(
        cfg=pred_cfg,
        data_dir=pred_cfg.data.data_dir,
        imgaug_transform=transform,
    )
    return get_data_module(
        cfg=pred_cfg, dataset=dataset, video_dir=pred_cfg.data.video_dir
    )