Source code for lightning_pose.api.model

"""High-level Model class for loading trained checkpoints and running inference."""

from __future__ import annotations

import copy
from pathlib import Path
from typing import Any, cast

import cv2
import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf

from lightning_pose.api.model_config import ModelConfig
from lightning_pose.data import (
    _IMAGENET_MEAN,
    _IMAGENET_STD,
    get_data_module,
    get_dataset,
    get_imgaug_transform,
)
from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.datatypes import MultiviewPredictionResult, PredictionResult
from lightning_pose.data.utils import convert_bbox_coords
from lightning_pose.metrics import compute_metrics_single
from lightning_pose.models import ALLOWED_MODEL_TYPES, ALLOWED_MODELS
from lightning_pose.utils import io as io_utils
from lightning_pose.utils.predictions import generate_labeled_video as generate_labeled_video_fn
from lightning_pose.utils.predictions import (
    predict_dataset,
    predict_video,
)

__all__ = ["Model", "get_model_class", "load_model_from_checkpoint"]


def get_model_class(map_type: ALLOWED_MODEL_TYPES, semi_supervised: bool) -> type[ALLOWED_MODELS]:
    """Return the model class for the given model type and supervision mode.

    Args:
        map_type: one of ``"regression"``, ``"heatmap"``, ``"heatmap_mhcrnn"``,
            ``"heatmap_multiview_transformer"``.
        semi_supervised: True to return the semi-supervised variant.

    Returns:
        model class (not an instance).

    Raises:
        NotImplementedError: if ``map_type`` is not recognised.

    """
    if not semi_supervised:
        if map_type == 'regression':
            from lightning_pose.models import RegressionTracker as ModelClass
        elif map_type == 'heatmap':
            from lightning_pose.models import HeatmapTracker as ModelClass
        elif map_type == 'heatmap_mhcrnn':
            from lightning_pose.models import HeatmapTrackerMHCRNN as ModelClass
        elif map_type == 'heatmap_multiview_transformer':
            from lightning_pose.models import HeatmapTrackerMultiviewTransformer as ModelClass
        else:
            raise NotImplementedError(
                f'{map_type} is an invalid model_type for a fully supervised model'
            )
    else:
        if map_type == 'regression':
            from lightning_pose.models import SemiSupervisedRegressionTracker as ModelClass
        elif map_type == 'heatmap':
            from lightning_pose.models import SemiSupervisedHeatmapTracker as ModelClass
        elif map_type == 'heatmap_mhcrnn':
            from lightning_pose.models import SemiSupervisedHeatmapTrackerMHCRNN as ModelClass
        elif map_type == 'heatmap_multiview_transformer':
            from lightning_pose.models import (
                SemiSupervisedHeatmapTrackerMultiviewTransformer as ModelClass,
            )
        else:
            raise NotImplementedError(
                f'{map_type} is an invalid model_type for a semi-supervised model'
            )
    return ModelClass


def load_model_from_checkpoint(
    cfg: DictConfig | ListConfig,
    ckpt_file: str | None,
    eval: bool = False,
    data_module: BaseDataModule | UnlabeledDataModule | None = None,
    skip_data_module: bool = False,
) -> ALLOWED_MODELS:
    """Load a Lightning Pose model from a checkpoint file.

    Args:
        cfg: model config
        ckpt_file: absolute path to model checkpoint
        eval: True for eval mode, False for train mode
        data_module: used to initialise unsupervised losses
        skip_data_module: if ``data_module`` is not None this is ignored.
            If False and ``data_module=None``, a data module is created from the config file and
            unsupervised losses are accessible in the model.
            If True and ``data_module=None``, the unsupervised losses are not accessible in the
            model; recommended for running inference on new videos.

    Returns:
        model as a Lightning Module

    Raises:
        ValueError: if ``ckpt_file`` is None

    """
    if ckpt_file is None:
        raise ValueError('ckpt_file must be provided to load a model from checkpoint')
    from lightning_pose.data import (
        get_data_module,
        get_dataset,
        get_imgaug_transform,
    )
    from lightning_pose.losses import get_loss_factories
    from lightning_pose.models import check_if_semi_supervised
    from lightning_pose.utils.io import return_absolute_data_paths

    delete_extras = False
    if not data_module and not skip_data_module:
        delete_extras = True
        data_dir, video_dir = return_absolute_data_paths(data_cfg=cfg.data)
        imgaug_transform = get_imgaug_transform(cfg=cfg)
        dataset = get_dataset(cfg=cfg, data_dir=data_dir, imgaug_transform=imgaug_transform)
        data_module = get_data_module(cfg=cfg, dataset=dataset, video_dir=video_dir)
    if not data_module:
        loss_factories = {'supervised': None, 'unsupervised': None}
    else:
        loss_factories = get_loss_factories(cfg=cfg, data_module=data_module)

    semi_supervised = check_if_semi_supervised(cfg.model.losses_to_use)
    ModelClass = get_model_class(
        map_type=cfg.model.model_type,
        semi_supervised=semi_supervised,
    )

    try:
        checkpoint = torch.load(ckpt_file)
    except Exception as e:
        print(f'Warning: Failed to load checkpoint with default settings: {e}')
        print('Attempting to load with weights_only=False...')
        checkpoint = torch.load(ckpt_file, weights_only=False)
    state_dict = checkpoint.get('state_dict', checkpoint)

    # fix state dict key mismatch for upsampling layers in old checkpoints
    keys_remapped = False
    for key in list(state_dict.keys()):
        if key.startswith('upsampling_layers.'):
            state_dict['head.' + key] = state_dict.pop(key)
            keys_remapped = True

    if keys_remapped:
        checkpoint['state_dict'] = state_dict
        import tempfile
        with tempfile.NamedTemporaryFile(suffix='.ckpt', delete=False) as tmp_file:
            torch.save(checkpoint, tmp_file.name)
            fixed_ckpt_file = tmp_file.name
    else:
        fixed_ckpt_file = ckpt_file

    if semi_supervised:
        model = ModelClass.load_from_checkpoint(
            fixed_ckpt_file,
            loss_factory=loss_factories['supervised'],
            loss_factory_unsupervised=loss_factories['unsupervised'],
            strict=False,
        )
    else:
        model = ModelClass.load_from_checkpoint(
            fixed_ckpt_file,
            loss_factory=loss_factories['supervised'],
            strict=False,
        )

    if keys_remapped:
        import os
        os.unlink(fixed_ckpt_file)

    if eval:
        model.eval()

    if delete_extras:
        del imgaug_transform
        del dataset
        del data_module
    del loss_factories
    torch.cuda.empty_cache()

    return model


[docs] class Model: """High-level interface for inference with a trained lightning-pose model. Load a saved model with `Model.from_dir`, then call prediction methods directly. Model weights are loaded lazily on the first prediction call. Attributes: model_dir: absolute path to the directory the model is stored in. config: the model configuration as a `ModelConfig` object. model: the underlying PyTorch model; None until the first prediction call. Examples: >>> from lightning_pose.api import Model >>> model = Model.from_dir("outputs/2024-01-01/12-00-00") Single-frame inference (no file I/O): >>> import numpy as np >>> frame = np.zeros((256, 256, 3), dtype=np.uint8) >>> result = model.predict_frame(frame) >>> result["keypoints"].shape # (num_keypoints, 2) >>> result["confidence"].shape # (num_keypoints,) Predict on a video file: >>> pred_result = model.predict_on_video_file("path/to/video.mp4") >>> pred_result.predictions # pd.DataFrame with MultiIndex columns >>> pred_result.metrics # ComputeMetricsSingleResult or None Predict on a labeled CSV (also computes pixel error): >>> pred_result = model.predict_on_label_csv("path/to/CollectedData.csv") """ model_dir: Path """Directory the model is stored in.""" config: ModelConfig """The model configuration stored as a `ModelConfig` object. `ModelConfig` wraps the `omegaconf.DictConfig` and provides util functions over it. """ model: ALLOWED_MODELS | None = None # Just a constant we can use as a default value for kwargs, # to differentiate between user omitting a kwarg, vs explicitly passing None. UNSPECIFIED = "unspecified"
[docs] @staticmethod def from_dir(model_dir: str | Path) -> Model: """Create a `Model` instance for a model stored at `model_dir`. Args: model_dir: path to a model output directory containing ``config.yaml`` and a ``.ckpt`` checkpoint file. Returns: Model ready for inference. Weights are loaded lazily on the first prediction call. Examples: >>> from lightning_pose.api import Model >>> model = Model.from_dir("outputs/2024-01-01/12-00-00") >>> model.config.is_multi_view() False """ return Model.from_dir2(model_dir)
@staticmethod def from_dir2(model_dir: str | Path, hydra_overrides: list[str] | None = None) -> Model: """Internal version of from_dir that supports hydra_overrides. Not sure whether to promote this to public API yet.""" model_dir = Path(model_dir).absolute() if hydra_overrides is not None: import hydra with hydra.initialize_config_dir( version_base="1.1", config_dir=str(model_dir) ): cfg = hydra.compose(config_name="config", overrides=hydra_overrides) config = ModelConfig(cfg) else: config = ModelConfig.from_yaml_file(model_dir / "config.yaml") return Model(model_dir, config) def __init__(self, model_dir: str | Path, config: ModelConfig) -> None: """Initialize a Model from a directory and a pre-loaded config. Prefer `Model.from_dir` for typical usage. Use this constructor when you have already constructed a `ModelConfig` (e.g. after applying Hydra overrides). Args: model_dir: path to the model output directory. config: the model configuration. """ self.model_dir = Path(model_dir).absolute() self.config = config @property def cfg(self) -> DictConfig | ListConfig: """The model configuration as an `omegaconf.DictConfig`.""" return self.config.cfg def _load(self) -> None: """Load model weights from the checkpoint file on first call; no-op thereafter. Raises: FileNotFoundError: if no checkpoint file is found in `model_dir`. """ if self.model is None: ckpt_file = io_utils.ckpt_path_from_base_path( base_path=str(self.model_dir), model_name=self.cfg.model.model_name ) if ckpt_file is None: raise FileNotFoundError( "Checkpoint file not found, have you trained for enough epochs?" ) self.model = load_model_from_checkpoint( cfg=self.cfg, ckpt_file=ckpt_file, eval=True, skip_data_module=True, )
[docs] def image_preds_dir(self) -> Path: """Return the directory where image/CSV predictions are saved.""" return self.model_dir / "image_preds"
[docs] def video_preds_dir(self) -> Path: """Return the directory where video predictions are saved.""" return self.model_dir / "video_preds"
[docs] def labeled_videos_dir(self) -> Path: """Return the directory where prediction-annotated videos are saved.""" return self.model_dir / "video_preds" / "labeled_videos"
[docs] def cropped_data_dir(self) -> Path: """Return the directory where cropzoom-cropped images are saved.""" return self.model_dir / "cropped_images"
[docs] def cropped_videos_dir(self) -> Path: """Return the directory where cropzoom-cropped videos are saved.""" return self.model_dir / "cropped_videos"
[docs] def cropped_csv_file_path(self, csv_file_path: str | Path) -> Path: """Return the path where a cropzoom-adjusted CSV file will be saved. Args: csv_file_path: path to the original labeled CSV file. Returns: path of the form ``{model_dir}/image_preds/{csv_name}/cropped_{csv_name}``. """ csv_file_path = Path(csv_file_path) return ( self.model_dir / "image_preds" / csv_file_path.name / ("cropped_" + csv_file_path.name) )
[docs] def predict_frame( self, frame_rgb: np.ndarray, bbox: tuple[int, int, int, int] | None = None, ) -> dict[str, np.ndarray]: """Single-frame inference. No file I/O, no DALI. Preprocessing uses cv2 (not DALI). Results will differ numerically from ``predict_on_video_file`` due to interpolation and normalization differences. Do not mix results from the two paths in quantitative analysis. For MHCRNN (context) models, pass a ``(T, H, W, 3)`` array where T is the temporal context length (typically 5). Passing a single frame to a context model raises ``ValueError`` — use ``predict_on_video_file`` for proper temporal inference. The first call triggers model loading and CUDA initialization, which may take several seconds. Subsequent calls are fast (~5-50ms depending on backbone). For latency-sensitive loops, call once on a dummy frame before entering the loop. Args: frame_rgb: ``(H, W, 3)`` uint8 RGB array for standard models, or ``(T, H, W, 3)`` uint8 RGB array for context (MHCRNN) models. bbox: Optional ``(x, y, w, h)`` crop region. Note: this is ``(x, y, width, height)``, NOT ``(x1, y1, x2, y2)``. If provided, crops first, then remaps keypoints back to original coordinates. Returns: {"keypoints": (num_kp, 2) float32 array (x, y) in original frame coords, "confidence": (num_kp,) float32 in [0, 1] -- likelihood/confidence per keypoint. For regression models, confidence is always 1.0.} Raises: ValueError: If frame_rgb has wrong shape/dtype, bbox has non-positive dimensions, bbox produces an empty crop, or a context model receives single-frame input. Examples: >>> import numpy as np >>> frame = np.zeros((256, 256, 3), dtype=np.uint8) >>> result = model.predict_frame(frame) >>> result["keypoints"].shape # (num_keypoints, 2) >>> result["confidence"].shape # (num_keypoints,) With a bounding-box crop (x, y, width, height): >>> result = model.predict_frame(frame, bbox=(100, 50, 128, 128)) """ self._load() if self.model is None: raise RuntimeError('model failed to load; self.model is None after _load()') # --- Input validation --- if frame_rgb.dtype != np.uint8: raise ValueError( f"frame_rgb must be uint8, got {frame_rgb.dtype}. " "Convert with frame.astype(np.uint8) if values are in [0, 255]." ) is_context_input = frame_rgb.ndim == 4 if is_context_input: if frame_rgb.shape[3] != 3: raise ValueError( f"frame_rgb must be (T, H, W, 3), got shape {frame_rgb.shape}" ) elif frame_rgb.ndim == 3: if frame_rgb.shape[2] != 3: raise ValueError( f"frame_rgb must be (H, W, 3), got shape {frame_rgb.shape}" ) else: raise ValueError( f"frame_rgb must be (H, W, 3) or (T, H, W, 3), " f"got {frame_rgb.ndim}D array with shape {frame_rgb.shape}" ) if frame_rgb.size == 0: raise ValueError("frame_rgb is empty") is_context_model = self.model.do_context if is_context_model and not is_context_input: raise ValueError( "Context model requires frame_rgb of shape (T, H, W, 3) " "where T is the temporal context length (typically 5). " "Use predict_on_video_file for single-frame input." ) # --- Crop --- if bbox is not None: bx, by, bw, bh = bbox if bx < 0 or by < 0: raise ValueError( f"bbox origin must be non-negative, got x={bx}, y={by}" ) if bw <= 0 or bh <= 0: raise ValueError( f"bbox width and height must be positive, got w={bw}, h={bh}" ) if is_context_input: crop = frame_rgb[:, by:by + bh, bx:bx + bw] else: crop = frame_rgb[by:by + bh, bx:bx + bw] if crop.size == 0: raise ValueError( f"bbox (x={bx}, y={by}, w={bw}, h={bh}) produces an empty " f"crop on frame of shape {frame_rgb.shape}" ) # Use actual crop dims for remap -- numpy clips silently when # bbox extends beyond frame boundaries. if is_context_input: actual_h, actual_w = crop.shape[1], crop.shape[2] else: actual_h, actual_w = crop.shape[0], crop.shape[1] else: crop = frame_rgb # --- Preprocess --- resize_h = self.cfg.data.image_resize_dims.height resize_w = self.cfg.data.image_resize_dims.width mean = np.array(_IMAGENET_MEAN, dtype=np.float32) std = np.array(_IMAGENET_STD, dtype=np.float32) def _preprocess_single(img: np.ndarray) -> np.ndarray: """Resize, normalize, and transpose a single HWC uint8 frame to CHW float32.""" resized = cv2.resize( img, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR, ) t = resized.astype(np.float32) / 255.0 t = (t - mean) / std return np.transpose(t, (2, 0, 1)) # (3, H, W) if is_context_input: frames = [_preprocess_single(crop[i]) for i in range(crop.shape[0])] tensor = np.stack(frames) # (T, 3, H, W) tensor_t = torch.from_numpy(tensor).unsqueeze(0) # (1, T, 3, H, W) else: tensor = _preprocess_single(crop) tensor_t = torch.from_numpy(tensor).unsqueeze(0) # (1, 3, H, W) device = self.model.device tensor_t = tensor_t.to(device) # --- Build batch dict --- # Bbox in LP format: [x, y, height, width] if bbox is not None: bbox_lp = torch.tensor( [[bx, by, actual_h, actual_w]], dtype=torch.float32, device=device, ) else: if is_context_input: fh, fw = frame_rgb.shape[1], frame_rgb.shape[2] else: fh, fw = frame_rgb.shape[0], frame_rgb.shape[1] bbox_lp = torch.tensor( [[0, 0, fh, fw]], dtype=torch.float32, device=device, ) num_kp = self.model.num_keypoints batch_dict = { "images": tensor_t, "keypoints": torch.zeros(1, num_kp * 2, dtype=torch.float32, device=device), "bbox": bbox_lp, "idxs": torch.zeros(1, dtype=torch.long, device=device), "heatmaps": torch.zeros(1, num_kp, 1, 1, dtype=torch.float32, device=device), } # --- Inference via get_loss_inputs_labeled --- self.model.eval() with torch.inference_mode(): result = self.model.get_loss_inputs_labeled(batch_dict) # type: ignore[arg-type] # --- Extract predictions --- kp_pred = result["keypoints_pred"] has_confidence = "confidences" in result if is_context_model: # Context model's get_loss_inputs_labeled concatenates [sf; mf] along batch dim n = kp_pred.shape[0] // 2 kp_sf = kp_pred[:n].reshape(n, -1, 2) kp_mf = kp_pred[n:].reshape(n, -1, 2) conf_sf = result["confidences"][:n] conf_mf = result["confidences"][n:] # Merge: pick higher-confidence prediction per keypoint mf_better = conf_mf > conf_sf kp_sf[mf_better] = kp_mf[mf_better] conf_merged = conf_sf.clone() conf_merged[mf_better] = conf_mf[mf_better] kp = kp_sf[0].cpu().numpy().astype(np.float32) conf = conf_merged[0].cpu().numpy().astype(np.float32) elif has_confidence: # Heatmap model — keypoints already in original frame coords # (get_loss_inputs_labeled calls convert_bbox_coords internally) kp = kp_pred[0].cpu().numpy().reshape(-1, 2).astype(np.float32) conf = result["confidences"][0].cpu().numpy().astype(np.float32) else: # Regression model — get_loss_inputs_labeled does not call # convert_bbox_coords, so we apply the remap ourselves. kp_pred = convert_bbox_coords(batch_dict, kp_pred, in_place=False) # type: ignore[arg-type] kp = kp_pred[0].cpu().numpy().reshape(-1, 2).astype(np.float32) conf = np.ones(num_kp, dtype=np.float32) return {"keypoints": kp, "confidence": conf}
[docs] def predict_on_label_csv( self, csv_file: str | Path, data_dir: str | Path | None = None, compute_metrics: bool = True, add_train_val_test_set: bool = False, ) -> PredictionResult: """Predicts on a labeled dataset and computes error/loss metrics if applicable. Args: csv_file (str | Path): Path to the CSV file of images, keypoint locations. data_dir (str | Path, optional): Root path for relative paths in the CSV file. Defaults to the data_dir originally used when training. compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on predictions. generate_labeled_images (bool, optional): Whether to save labeled images. Defaults to False. output_dir (str | Path, optional): The directory to save outputs to. Defaults to `{model_dir}/image_preds/{csv_file_name}`. If set to None, outputs are not saved. add_train_val_test_set (bool): When predicting on training dataset, set to true to add the `set` column to the prediction output. Returns: PredictionResult: A PredictionResult object containing the predictions and metrics. Examples: >>> result = model.predict_on_label_csv("path/to/CollectedData.csv") >>> result.predictions # pd.DataFrame with MultiIndex columns >>> result.metrics.pixel_error # mean pixel error per keypoint Skip metric computation for faster inference: >>> result = model.predict_on_label_csv( ... "path/to/CollectedData.csv", ... compute_metrics=False, ... ) """ self._load() # Convert this to absolute, because if relative, downstream will # assume its relative to the data_dir. csv_file = Path(csv_file).absolute() if data_dir is None: data_dir = self.config.cfg.data.data_dir output_dir = self.image_preds_dir() / csv_file.name output_dir.mkdir(parents=True, exist_ok=True) # Point predict_dataset to the csv_file and data_dir. # HACK: For true multi-view model, trick predict_dataset and compute_metrics # into thinking this is a single-view model. cfg_overrides: dict[str, Any] = { "data": { "data_dir": str(data_dir), "csv_file": str(csv_file), } } # Avoid annotating set=train/val/test for CSV file other than the training CSV file. if not add_train_val_test_set: cfg_overrides.update({"train_prob": 1, "val_prob": 0, "train_frames": 1}) cfg_pred = OmegaConf.merge(self.cfg, cfg_overrides) # HACK: For true multi-view model, trick predict_dataset and compute_metrics # into thinking this is a single-view model. if self.config.is_multi_view(): del cfg_pred.data.view_names # HACK: If we don't delete mirrored_column_matches, downstream # interprets this as a mirrored multiview model, and compute_metrics fails. del cfg_pred.data.mirrored_column_matches data_module_pred = _build_datamodule_pred(cfg_pred) preds_file_path = output_dir / "predictions.csv" preds_file = str(preds_file_path) df = predict_dataset( model=self, data_module=data_module_pred, preds_file=preds_file, cfg=cfg_pred, ) if compute_metrics: metrics = compute_metrics_single( cfg=cfg_pred, labels_file=str(csv_file), preds_file=preds_file, data_module=data_module_pred, ) else: metrics = None if not isinstance(df, pd.DataFrame): raise RuntimeError('expected a single-view DataFrame from predict_dataset') return PredictionResult(predictions=df, metrics=metrics)
[docs] def predict_on_label_csv_multiview( self, csv_file_per_view: list[str] | list[Path], bbox_file_per_view: list[str] | list[Path] | None = None, camera_params_file: str | Path | None = None, data_dir: str | Path | None = None, compute_metrics: bool = True, add_train_val_test_set: bool = False, ) -> MultiviewPredictionResult: """Version of `predict_on_label_csv` that gives models access to all views of each frame. Arguments: csv_file_per_view (list[str] | list[Path]): A list of csv files each from a different view of the same session. Order must match the `view_names` in the config file. See `predict_on_label_csv` docstring for other arguments.""" if not self.config.is_multi_view(): raise ValueError('predict_on_label_csv_multiview requires a multi-view model') self._load() view_names = self.config.cfg.data.view_names if len(csv_file_per_view) != len(view_names): raise ValueError( f'expected {len(view_names)} csv files (one per view), ' f'got {len(csv_file_per_view)}' ) # Convert this to absolute, because if relative, downstream will # assume its relative to the data_dir. csv_file_per_view = [Path(f).absolute() for f in csv_file_per_view] if data_dir is None: data_dir = self.config.cfg.data.data_dir # Point predict_dataset to the csv_file and data_dir. cfg_overrides: dict[str, Any] = { "data": { "data_dir": str(data_dir), "csv_file": [str(p) for p in csv_file_per_view], } } if camera_params_file: cfg_overrides["data"]["camera_params_file"] = camera_params_file if bbox_file_per_view: cfg_overrides["data"]["bbox_file"] = [str(p) for p in bbox_file_per_view] else: cfg_overrides["data"]["bbox_file"] = None # Avoid annotating set=train/val/test for CSV file other than the training CSV file. if not add_train_val_test_set: cfg_overrides.update({"train_prob": 1, "val_prob": 0, "train_frames": 1}) cfg_pred = OmegaConf.merge(self.cfg, cfg_overrides) data_module_pred = _build_datamodule_pred(cfg_pred) preds_files = [] for i, _view_name in enumerate(view_names): output_dir = self.image_preds_dir() / csv_file_per_view[i].name output_dir.mkdir(parents=True, exist_ok=True) preds_files.append(str(output_dir / "predictions.csv")) # Outputs dict[str, pd.DataFrame] because inputs indicate multiview. view_to_df_dict = predict_dataset( model=self, data_module=data_module_pred, preds_file=preds_files, cfg=cfg_pred, ) if compute_metrics: metrics = {} for view_name, labels_file, _preds_file in zip( view_names, csv_file_per_view, preds_files, strict=True ): metrics[view_name] = compute_metrics_single( cfg=self.cfg, labels_file=str(labels_file), preds_file=_preds_file, data_module=data_module_pred, ) else: metrics = None return MultiviewPredictionResult( predictions=cast(dict[str, pd.DataFrame], view_to_df_dict), metrics=metrics, )
[docs] def predict_on_video_file( self, video_file: str | Path, output_dir: str | Path | None = UNSPECIFIED, compute_metrics: bool = True, generate_labeled_video: bool = False, progress_file: Path | None = None, ) -> PredictionResult: """Predicts on a video file and computes unsupervised loss metrics if applicable. Args: video_file (str | Path): Path to the video file. output_dir (str | Path, optional): The directory to save outputs to. Defaults to `{model_dir}/image_preds/{csv_file_name}`. If set to None, outputs are not saved. compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on predictions. generate_labeled_video (bool, optional): Whether to save a labeled video. Defaults to False. progress_file (Path, optional): Path to a file to save progress information for the App. Defaults to None. Returns: PredictionResult: A PredictionResult object containing the predictions and metrics. Examples: >>> result = model.predict_on_video_file("path/to/video.mp4") >>> result.predictions # pd.DataFrame, one row per frame Save a keypoint-annotated video alongside the predictions CSV: >>> result = model.predict_on_video_file( ... "path/to/video.mp4", ... generate_labeled_video=True, ... ) """ self._load() video_file = Path(video_file) if output_dir == self.__class__.UNSPECIFIED: output_dir = self.video_preds_dir() elif output_dir is None: raise NotImplementedError("Currently we must save predictions") output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) prediction_csv_file = output_dir / f"{video_file.stem}.csv" df = predict_video( video_file=str(video_file), model=self, output_pred_file=str(prediction_csv_file), progress_file=progress_file, ) if generate_labeled_video: labeled_mp4_file = str(self.labeled_videos_dir() / f"{video_file.stem}_labeled.mp4") generate_labeled_video_fn( video_file=str(video_file), preds_df=df, output_mp4_file=labeled_mp4_file, confidence_thresh_for_vid=self.cfg.eval.confidence_thresh_for_vid, colormap=self.cfg.eval.get("colormap", "cool"), ) if compute_metrics: # FIXME: Data module is only used for computing PCA metrics. data_module = _build_datamodule_pred(self.cfg) metrics = compute_metrics_single( cfg=self.cfg, labels_file=None, preds_file=str(prediction_csv_file), data_module=data_module, ) else: metrics = None return PredictionResult(predictions=df, metrics=metrics)
[docs] def predict_on_video_file_multiview( self, video_file_per_view: list[str] | list[Path], output_dir: str | Path | None = UNSPECIFIED, compute_metrics: bool = True, generate_labeled_video: bool = False, progress_file: Path | None = None, ) -> MultiviewPredictionResult: """Version of `predict_on_video_file` that accesses to multiple camera views of each frame. Arguments: video_file_per_view (list[str] | list[Path]): A list of video files each from a different view of the same session. Number of video files must match the `view_names` in the config file. Order of the list does not matter: video files are intelligently matched to views by their filename using `utils.io.collect_video_files_by_view`. output_dir (str | Path, optional): The directory to save outputs to. Defaults to `{model_dir}/image_preds/{csv_file_name}`. If set to None, outputs are not saved. compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on predictions. generate_labeled_video (bool, optional): Whether to save a labeled video. Defaults to False. progress_file (Path, optional): Path to a file to save progress information for the App. Returns: MultiviewPredictionResult: object containing the predictions and metrics for each view. """ if not self.config.is_multi_view(): raise ValueError('predict_on_video_file_multiview requires a multi-view model') self._load() view_names = self.config.cfg.data.view_names if len(video_file_per_view) != len(view_names): raise ValueError( f'expected {len(view_names)} video files (one per view), ' f'got {len(video_file_per_view)}' ) video_file_per_view = [Path(f) for f in video_file_per_view] if output_dir == self.__class__.UNSPECIFIED: output_dir = self.video_preds_dir() elif output_dir is None: raise NotImplementedError("Currently we must save predictions") output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Arranges video_file_per_view to be in the same order as cfg.data.view_names. _view_to_video_file: dict[str, Path] = io_utils.collect_video_files_by_view( video_file_per_view, view_names ) video_file_per_view = [ _view_to_video_file[view_name] for view_name in view_names ] prediction_csv_file_list = [ str(output_dir / f"{video_file.stem}.csv") for video_file in video_file_per_view ] df_list = predict_video( video_file=list(map(str, video_file_per_view)), model=self, output_pred_file=prediction_csv_file_list, progress_file=progress_file, ) if generate_labeled_video: for video_file, preds_df in zip(video_file_per_view, df_list, strict=True): labeled_mp4_file = str( self.labeled_videos_dir() / f"{video_file.stem}_labeled.mp4" ) generate_labeled_video_fn( video_file=str(video_file), preds_df=preds_df, output_mp4_file=labeled_mp4_file, confidence_thresh_for_vid=self.cfg.eval.confidence_thresh_for_vid, colormap=self.cfg.eval.get("colormap", "cool"), ) data_module = _build_datamodule_pred(self.cfg) if compute_metrics: metrics = {} for view_name, preds_file in zip(view_names, prediction_csv_file_list, strict=True): metrics[view_name] = compute_metrics_single( cfg=self.cfg, labels_file=None, preds_file=preds_file, data_module=data_module, ) else: metrics = None df_dict = {view_name: df for view_name, df in zip(view_names, df_list, strict=True)} return MultiviewPredictionResult(predictions=df_dict, metrics=metrics)
def _build_datamodule_pred(cfg: DictConfig | ListConfig) -> BaseDataModule | UnlabeledDataModule: """Build a data module configured for prediction (no augmentation). Args: cfg: model config; augmentation is overridden to ``"default"`` (resize only). Returns: data module ready for use with `predict_dataset`. """ cfg_pred = copy.deepcopy(cfg) cfg_pred.training.imgaug = "default" imgaug_transform_pred = get_imgaug_transform(cfg=cfg_pred) dataset_pred = get_dataset( cfg=cfg_pred, data_dir=cfg_pred.data.data_dir, imgaug_transform=imgaug_transform_pred, ) data_module_pred = get_data_module( cfg=cfg_pred, dataset=dataset_pred, video_dir=cfg_pred.data.video_dir ) return data_module_pred