from __future__ import annotations
import copy
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig, OmegaConf
from lightning_pose.api.model_config import ModelConfig
from lightning_pose.data import _IMAGENET_MEAN, _IMAGENET_STD
from lightning_pose.data.datatypes import MultiviewPredictionResult, PredictionResult
from lightning_pose.data.utils import convert_bbox_coords
from lightning_pose.models import ALLOWED_MODELS
from lightning_pose.utils import io as io_utils
from lightning_pose.utils.predictions import generate_labeled_video as generate_labeled_video_fn
from lightning_pose.utils.predictions import (
load_model_from_checkpoint,
predict_dataset,
predict_video,
)
from lightning_pose.utils.scripts import (
compute_metrics_single,
get_data_module,
get_dataset,
get_imgaug_transform,
)
__all__ = ["Model"]
# Public entry point: load a trained model directory and run inference with it.
class Model:
    """Handle to a trained model stored on disk, with inference entry points.

    A `Model` pairs a model directory with its parsed configuration. The network
    itself is loaded lazily (see `_load`) the first time a prediction method
    needs it. Use `Model.from_dir` to create an instance from a model directory.
    """

    model_dir: Path
    """Directory the model is stored in."""

    config: ModelConfig
    """The model configuration stored as a `ModelConfig` object.

    `ModelConfig` wraps the `omegaconf.DictConfig` and provides util functions
    over it.
    """

    # The loaded network; stays None until `_load` has run successfully.
    model: ALLOWED_MODELS | None = None

    # Just a constant we can use as a default value for kwargs,
    # to differentiate between user omitting a kwarg, vs explicitly passing None.
    UNSPECIFIED = "unspecified"

    @staticmethod
    def from_dir(model_dir: str | Path) -> Model:
        """Create a `Model` instance for a model stored at `model_dir`."""
        return Model.from_dir2(model_dir)

    @staticmethod
    def from_dir2(model_dir: str | Path, hydra_overrides: list[str] | None = None) -> Model:
        """Internal version of from_dir that supports hydra_overrides. Not sure whether to
        promote this to public API yet.

        Args:
            model_dir: Directory containing the model's ``config.yaml``.
            hydra_overrides: Optional list of hydra override strings. When given, the
                config is composed through hydra from ``model_dir``; otherwise
                ``config.yaml`` is read directly.

        Returns:
            A `Model` bound to the absolute form of ``model_dir``.
        """
        model_dir = Path(model_dir).absolute()
        if hydra_overrides is None:
            config = ModelConfig.from_yaml_file(model_dir / "config.yaml")
        else:
            # Imported lazily so hydra is only needed when overrides are requested.
            import hydra

            with hydra.initialize_config_dir(version_base="1.1", config_dir=str(model_dir)):
                cfg = hydra.compose(config_name="config", overrides=hydra_overrides)
                config = ModelConfig(cfg)
        return Model(model_dir, config)

    def __init__(self, model_dir: str | Path, config: ModelConfig):
        """Bind the model to ``model_dir`` (stored as an absolute path) and ``config``."""
        self.model_dir = Path(model_dir).absolute()
        self.config = config

    @property
    def cfg(self) -> DictConfig:
        """The model configuration as an `omegaconf.DictConfig`."""
        return self.config.cfg

    def _load(self):
        """Load the network from its checkpoint into `self.model`, once.

        Does nothing when `self.model` is already set.

        Raises:
            FileNotFoundError: If no checkpoint file is found for this model directory.
        """
        if self.model is not None:
            return
        ckpt_file = io_utils.ckpt_path_from_base_path(
            base_path=str(self.model_dir), model_name=self.cfg.model.model_name
        )
        if ckpt_file is None:
            raise FileNotFoundError(
                "Checkpoint file not found, have you trained for enough epochs?"
            )
        self.model = load_model_from_checkpoint(
            cfg=self.cfg,
            ckpt_file=ckpt_file,
            eval=True,
            skip_data_module=True,
        )

    def image_preds_dir(self) -> Path:
        """Return ``{model_dir}/image_preds``."""
        return self.model_dir / "image_preds"

    def video_preds_dir(self) -> Path:
        """Return ``{model_dir}/video_preds``."""
        return self.model_dir / "video_preds"

    def labeled_videos_dir(self) -> Path:
        """Return ``{model_dir}/video_preds/labeled_videos``."""
        return self.model_dir / "video_preds" / "labeled_videos"

    def cropped_data_dir(self) -> Path:
        """Return ``{model_dir}/cropped_images``."""
        return self.model_dir / "cropped_images"

    def cropped_videos_dir(self) -> Path:
        """Return ``{model_dir}/cropped_videos``."""
        return self.model_dir / "cropped_videos"

    def cropped_csv_file_path(self, csv_file_path: str | Path) -> Path:
        """Return ``{model_dir}/image_preds/{name}/cropped_{name}`` for the given CSV.

        Only the file name of ``csv_file_path`` is used; its parent directories are ignored.
        """
        csv_file_path = Path(csv_file_path)
        return (
            self.model_dir
            / "image_preds"
            / csv_file_path.name
            / ("cropped_" + csv_file_path.name)
        )

    @staticmethod
    def _validate_frame(frame_rgb: np.ndarray) -> bool:
        """Validate a `predict_frame` input; return True when it is a (T, H, W, 3) stack."""
        if frame_rgb.dtype != np.uint8:
            raise ValueError(
                f"frame_rgb must be uint8, got {frame_rgb.dtype}. "
                "Convert with frame.astype(np.uint8) if values are in [0, 255]."
            )
        if frame_rgb.ndim == 4:
            if frame_rgb.shape[3] != 3:
                raise ValueError(
                    f"frame_rgb must be (T, H, W, 3), got shape {frame_rgb.shape}"
                )
            is_context_input = True
        elif frame_rgb.ndim == 3:
            if frame_rgb.shape[2] != 3:
                raise ValueError(
                    f"frame_rgb must be (H, W, 3), got shape {frame_rgb.shape}"
                )
            is_context_input = False
        else:
            raise ValueError(
                f"frame_rgb must be (H, W, 3) or (T, H, W, 3), "
                f"got {frame_rgb.ndim}D array with shape {frame_rgb.shape}"
            )
        if frame_rgb.size == 0:
            raise ValueError("frame_rgb is empty")
        return is_context_input

    @staticmethod
    def _crop_to_bbox(frame_rgb: np.ndarray, bbox: tuple[int, int, int, int]) -> np.ndarray:
        """Crop a validated (H, W, 3) or (T, H, W, 3) array to an ``(x, y, w, h)`` bbox."""
        bx, by, bw, bh = bbox
        if bx < 0 or by < 0:
            raise ValueError(
                f"bbox origin must be non-negative, got x={bx}, y={by}"
            )
        if bw <= 0 or bh <= 0:
            raise ValueError(
                f"bbox width and height must be positive, got w={bw}, h={bh}"
            )
        # The ellipsis skips the optional leading T axis, so a single expression
        # slices the H and W axes of both supported layouts.
        crop = frame_rgb[..., by:by + bh, bx:bx + bw, :]
        if crop.size == 0:
            raise ValueError(
                f"bbox (x={bx}, y={by}, w={bw}, h={bh}) produces an empty "
                f"crop on frame of shape {frame_rgb.shape}"
            )
        return crop

    def predict_frame(
        self,
        frame_rgb: np.ndarray,
        bbox: tuple[int, int, int, int] | None = None,
    ) -> dict[str, np.ndarray]:
        """Single-frame inference. No file I/O, no DALI.

        Preprocessing uses cv2 (not DALI). Results will differ numerically
        from ``predict_on_video_file`` due to interpolation and normalization
        differences. Do not mix results from the two paths in quantitative
        analysis.

        For MHCRNN (context) models, pass a ``(T, H, W, 3)`` array where T
        is the temporal context length (typically 5). Passing a single frame
        to a context model raises ``ValueError`` — use
        ``predict_on_video_file`` for proper temporal inference.

        The first call triggers model loading and CUDA initialization, which
        may take several seconds. Subsequent calls are fast (~5-50ms depending
        on backbone). For latency-sensitive loops, call once on a dummy frame
        before entering the loop.

        Args:
            frame_rgb: ``(H, W, 3)`` uint8 RGB array for standard models, or
                ``(T, H, W, 3)`` uint8 RGB array for context (MHCRNN) models.
            bbox: Optional ``(x, y, w, h)`` crop region. Note: this is
                ``(x, y, width, height)``, NOT ``(x1, y1, x2, y2)``.
                If provided, crops first, then remaps keypoints back to
                original coordinates.

        Returns:
            {"keypoints": (num_kp, 2) float32 array (x, y) in original frame coords,
             "confidence": (num_kp,) float32 in [0, 1] -- likelihood/confidence
             per keypoint. For regression models, confidence is always 1.0.}

        Raises:
            ValueError: If frame_rgb has wrong shape/dtype, bbox has a negative origin
                or non-positive dimensions, bbox produces an empty crop, or a context
                model receives single-frame input. Frame and bbox problems are
                reported before the model is loaded.
            FileNotFoundError: If the model checkpoint cannot be found on first use.
        """
        # --- Input validation and crop ---
        # Both are cheap and model-independent, so run them before `_load()`:
        # malformed input then fails fast instead of after checkpoint loading.
        is_context_input = self._validate_frame(frame_rgb)
        if bbox is None:
            crop = frame_rgb
            origin_x, origin_y = 0, 0
        else:
            crop = self._crop_to_bbox(frame_rgb, bbox)
            origin_x, origin_y = bbox[0], bbox[1]

        self._load()
        is_context_model = self.model.do_context
        if is_context_model and not is_context_input:
            raise ValueError(
                "Context model requires frame_rgb of shape (T, H, W, 3) "
                "where T is the temporal context length (typically 5). "
                "Use predict_on_video_file for single-frame input."
            )

        # --- Preprocess ---
        resize_h = self.cfg.data.image_resize_dims.height
        resize_w = self.cfg.data.image_resize_dims.width
        mean = np.array(_IMAGENET_MEAN, dtype=np.float32)
        std = np.array(_IMAGENET_STD, dtype=np.float32)

        def _preprocess_single(img):
            resized = cv2.resize(
                img, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR,
            )
            t = resized.astype(np.float32) / 255.0
            t = (t - mean) / std
            return np.transpose(t, (2, 0, 1))  # (3, H, W)

        if is_context_input:
            tensor = np.stack([_preprocess_single(img) for img in crop])  # (T, 3, H, W)
        else:
            tensor = _preprocess_single(crop)  # (3, H, W)
        device = self.model.device
        # Leading batch axis of size 1: (1, 3, H, W) or (1, T, 3, H, W).
        tensor_t = torch.from_numpy(tensor).unsqueeze(0).to(device)

        # --- Build batch dict ---
        # Bbox in LP format: [x, y, height, width]
        # Use actual crop dims for remap -- numpy clips silently when
        # bbox extends beyond frame boundaries. Without a bbox the crop is the
        # full frame, so the same expression yields the frame dims.
        crop_h, crop_w = crop.shape[-3], crop.shape[-2]
        bbox_lp = torch.tensor(
            [[origin_x, origin_y, crop_h, crop_w]], dtype=torch.float32, device=device,
        )
        num_kp = self.model.num_keypoints
        # The zero-filled entries are placeholders -- presumably only present because
        # get_loss_inputs_labeled expects a labeled-batch layout; TODO confirm.
        batch_dict = {
            "images": tensor_t,
            "keypoints": torch.zeros(1, num_kp * 2, dtype=torch.float32, device=device),
            "bbox": bbox_lp,
            "idxs": torch.zeros(1, dtype=torch.long, device=device),
            "heatmaps": torch.zeros(1, num_kp, 1, 1, dtype=torch.float32, device=device),
        }

        # --- Inference via get_loss_inputs_labeled ---
        self.model.eval()
        with torch.inference_mode():
            result = self.model.get_loss_inputs_labeled(batch_dict)

        # --- Extract predictions ---
        kp_pred = result["keypoints_pred"]
        if is_context_model:
            # Context model's get_loss_inputs_labeled concatenates [sf; mf] along batch dim
            n = kp_pred.shape[0] // 2
            kp_sf = kp_pred[:n].reshape(n, -1, 2)
            kp_mf = kp_pred[n:].reshape(n, -1, 2)
            conf_sf = result["confidences"][:n]
            conf_mf = result["confidences"][n:]
            # Merge: pick higher-confidence prediction per keypoint.
            # torch.where is used instead of masked in-place assignment because
            # tensors created under inference_mode cannot be modified in place
            # once that context has exited.
            mf_better = conf_mf > conf_sf
            kp_merged = torch.where(mf_better.unsqueeze(-1), kp_mf, kp_sf)
            conf_merged = torch.where(mf_better, conf_mf, conf_sf)
            kp = kp_merged[0].cpu().numpy().astype(np.float32)
            conf = conf_merged[0].cpu().numpy().astype(np.float32)
        elif "confidences" in result:
            # Heatmap model — keypoints already in original frame coords
            # (get_loss_inputs_labeled calls convert_bbox_coords internally)
            kp = kp_pred[0].cpu().numpy().reshape(-1, 2).astype(np.float32)
            conf = result["confidences"][0].cpu().numpy().astype(np.float32)
        else:
            # Regression model — get_loss_inputs_labeled does not call
            # convert_bbox_coords, so we apply the remap ourselves.
            kp_pred = convert_bbox_coords(batch_dict, kp_pred, in_place=False)
            kp = kp_pred[0].cpu().numpy().reshape(-1, 2).astype(np.float32)
            conf = np.ones(num_kp, dtype=np.float32)
        return {"keypoints": kp, "confidence": conf}

    def predict_on_label_csv(
        self,
        csv_file: str | Path,
        data_dir: str | Path | None = None,
        compute_metrics: bool = True,
        add_train_val_test_set: bool = False,
    ) -> PredictionResult:
        """Predicts on a labeled dataset and computes error/loss metrics if applicable.

        Predictions are always written to
        ``{model_dir}/image_preds/{csv_file_name}/predictions.csv``.

        Args:
            csv_file (str | Path): Path to the CSV file of images, keypoint locations.
            data_dir (str | Path, optional): Root path for relative paths in the CSV file.
                Defaults to the data_dir originally used when training.
            compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on
                predictions.
            add_train_val_test_set (bool): When predicting on training dataset, set to true to add
                the `set` column to the prediction output.

        Returns:
            PredictionResult: A PredictionResult object containing the predictions and metrics
            (metrics is None when ``compute_metrics`` is False).
        """
        self._load()

        # Convert this to absolute, because if relative, downstream will
        # assume its relative to the data_dir.
        csv_file = Path(csv_file).absolute()
        if data_dir is None:
            data_dir = self.config.cfg.data.data_dir

        output_dir = self.image_preds_dir() / csv_file.name
        output_dir.mkdir(parents=True, exist_ok=True)

        # Point predict_dataset to the csv_file and data_dir.
        cfg_overrides = {
            "data": {
                "data_dir": str(data_dir),
                "csv_file": str(csv_file),
            }
        }
        # Avoid annotating set=train/val/test for CSV file other than the training CSV file.
        # NOTE(review): these keys are merged at the top level of the config, not under
        # a sub-section -- confirm that downstream code reads them from there.
        if not add_train_val_test_set:
            cfg_overrides.update({"train_prob": 1, "val_prob": 0, "train_frames": 1})
        cfg_pred = OmegaConf.merge(self.cfg, cfg_overrides)

        # HACK: For true multi-view model, trick predict_dataset and compute_metrics
        # into thinking this is a single-view model.
        if self.config.is_multi_view():
            del cfg_pred.data.view_names
            # HACK: If we don't delete mirrored_column_matches, downstream
            # interprets this as a mirrored multiview model, and compute_metrics fails.
            # Guarded so that a config without this key does not fail on deletion.
            if "mirrored_column_matches" in cfg_pred.data:
                del cfg_pred.data.mirrored_column_matches

        data_module_pred = _build_datamodule_pred(cfg_pred)
        preds_file = str(output_dir / "predictions.csv")
        df = predict_dataset(
            cfg_pred, data_module_pred, model=self.model, preds_file=preds_file
        )

        metrics = None
        if compute_metrics:
            metrics = compute_metrics_single(
                cfg=cfg_pred,
                labels_file=str(csv_file),
                preds_file=preds_file,
                data_module=data_module_pred,
            )
        return PredictionResult(predictions=df, metrics=metrics)

    def predict_on_label_csv_multiview(
        self,
        csv_file_per_view: list[str] | list[Path],
        bbox_file_per_view: list[str] | list[Path] | None = None,
        camera_params_file: str | Path | None = None,
        data_dir: str | Path | None = None,
        compute_metrics: bool = True,
        add_train_val_test_set: bool = False,
    ) -> MultiviewPredictionResult:
        """Version of `predict_on_label_csv` that gives models access to all views of each frame.

        Predictions for each view are written to
        ``{model_dir}/image_preds/{csv_file_name}/predictions.csv``. The output directory
        is derived from the CSV file name alone, so the per-view CSV files need distinct
        file names to avoid writing to the same location.

        Arguments:
            csv_file_per_view (list[str] | list[Path]): A list of csv files each from a different
                view of the same session. Order must match the `view_names` in the config file.
            bbox_file_per_view (list[str] | list[Path], optional): Bounding-box files, one per
                view. When omitted, ``data.bbox_file`` is set to None for prediction.
            camera_params_file (str | Path, optional): Overrides ``data.camera_params_file``
                when given.

        See `predict_on_label_csv` docstring for other arguments.

        Returns:
            MultiviewPredictionResult: predictions and (optionally) metrics keyed by view.
        """
        assert self.config.is_multi_view()
        self._load()

        view_names = self.config.cfg.data.view_names
        assert len(csv_file_per_view) == len(
            view_names
        ), f"{len(csv_file_per_view)} != {len(view_names)}"

        # Convert this to absolute, because if relative, downstream will
        # assume its relative to the data_dir.
        csv_file_per_view: list[Path] = [Path(f).absolute() for f in csv_file_per_view]
        if data_dir is None:
            data_dir = self.config.cfg.data.data_dir

        # Point predict_dataset to the csv_file and data_dir.
        cfg_overrides = {
            "data": {
                "data_dir": str(data_dir),
                "csv_file": [str(p) for p in csv_file_per_view],
            }
        }
        if camera_params_file:
            # Stringified like every other path placed in the config overrides.
            cfg_overrides["data"]["camera_params_file"] = str(camera_params_file)
        if bbox_file_per_view:
            cfg_overrides["data"]["bbox_file"] = [str(p) for p in bbox_file_per_view]
        else:
            cfg_overrides["data"]["bbox_file"] = None
        # Avoid annotating set=train/val/test for CSV file other than the training CSV file.
        # NOTE(review): these keys are merged at the top level of the config, not under
        # a sub-section -- confirm that downstream code reads them from there.
        if not add_train_val_test_set:
            cfg_overrides.update({"train_prob": 1, "val_prob": 0, "train_frames": 1})
        cfg_pred = OmegaConf.merge(self.cfg, cfg_overrides)

        data_module_pred = _build_datamodule_pred(cfg_pred)

        preds_files = []
        for _view_name, csv_path in zip(view_names, csv_file_per_view, strict=True):
            output_dir = self.image_preds_dir() / csv_path.name
            output_dir.mkdir(parents=True, exist_ok=True)
            preds_files.append(str(output_dir / "predictions.csv"))

        # Outputs dict[str, pd.DataFrame] because inputs indicate multiview.
        view_to_df_dict = predict_dataset(
            cfg_pred, data_module_pred, model=self.model, preds_file=preds_files
        )

        metrics = None
        if compute_metrics:
            # NOTE(review): metrics use self.cfg here, whereas the single-view method
            # passes cfg_pred -- confirm this difference is intentional.
            metrics = {
                view_name: compute_metrics_single(
                    cfg=self.cfg,
                    labels_file=str(labels_file),
                    preds_file=view_preds_file,
                    data_module=data_module_pred,
                )
                for view_name, labels_file, view_preds_file in zip(
                    view_names, csv_file_per_view, preds_files, strict=True
                )
            }
        return MultiviewPredictionResult(predictions=view_to_df_dict, metrics=metrics)

    def _resolve_video_output_dir(self, output_dir: str | Path | None) -> Path:
        """Resolve and create the output directory used by the video prediction methods."""
        if output_dir == self.__class__.UNSPECIFIED:
            output_dir = self.video_preds_dir()
        elif output_dir is None:
            raise NotImplementedError("Currently we must save predictions")
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    def _write_labeled_video(self, video_file: Path, preds_df: pd.DataFrame) -> None:
        """Render ``{video stem}_labeled.mp4`` into `labeled_videos_dir` from predictions."""
        labeled_mp4_file = str(self.labeled_videos_dir() / f"{video_file.stem}_labeled.mp4")
        generate_labeled_video_fn(
            video_file=str(video_file),
            preds_df=preds_df,
            output_mp4_file=labeled_mp4_file,
            confidence_thresh_for_vid=self.cfg.eval.confidence_thresh_for_vid,
            colormap=self.cfg.eval.get("colormap", "cool"),
        )

    def predict_on_video_file(
        self,
        video_file: str | Path,
        output_dir: str | Path | None = UNSPECIFIED,
        compute_metrics: bool = True,
        generate_labeled_video: bool = False,
        progress_file: Path | None = None,
    ) -> PredictionResult:
        """Predicts on a video file and computes unsupervised loss metrics if applicable.

        Predictions are written to ``{output_dir}/{video_stem}.csv``.

        Args:
            video_file (str | Path): Path to the video file.
            output_dir (str | Path, optional): The directory to save outputs to.
                Defaults to `{model_dir}/video_preds`.
                Passing None is not supported yet and raises NotImplementedError.
            compute_metrics (bool, optional): Whether to compute metrics on the predictions.
            generate_labeled_video (bool, optional): Whether to save a labeled video.
                Defaults to False. The labeled video is written to
                `{model_dir}/video_preds/labeled_videos`, independent of `output_dir`.
            progress_file (Path, optional): Path to a file to save progress information for the
                App. Defaults to None.

        Returns:
            PredictionResult: A PredictionResult object containing the predictions and metrics
            (metrics is None when ``compute_metrics`` is False).

        Raises:
            NotImplementedError: If ``output_dir`` is None.
        """
        self._load()
        video_file = Path(video_file)
        output_dir = self._resolve_video_output_dir(output_dir)

        prediction_csv_file = output_dir / f"{video_file.stem}.csv"
        df: pd.DataFrame = predict_video(
            video_file=str(video_file),
            model=self,
            output_pred_file=str(prediction_csv_file),
            progress_file=progress_file,
        )

        if generate_labeled_video:
            self._write_labeled_video(video_file, df)

        metrics = None
        if compute_metrics:
            # FIXME: Data module is only used for computing PCA metrics.
            data_module = _build_datamodule_pred(self.cfg)
            metrics = compute_metrics_single(
                cfg=self.cfg,
                labels_file=None,
                preds_file=str(prediction_csv_file),
                data_module=data_module,
            )
        return PredictionResult(predictions=df, metrics=metrics)

    def predict_on_video_file_multiview(
        self,
        video_file_per_view: list[str] | list[Path],
        output_dir: str | Path | None = UNSPECIFIED,
        compute_metrics: bool = True,
        generate_labeled_video: bool = False,
        progress_file: Path | None = None,
    ) -> MultiviewPredictionResult:
        """Version of `predict_on_video_file` that accesses to multiple camera views of each frame.

        Predictions for each view are written to ``{output_dir}/{video_stem}.csv``, so the
        video files need distinct stems to avoid writing to the same location.

        Arguments:
            video_file_per_view (list[str] | list[Path]): A list of video files each from a
                different view of the same session.
                Number of video files must match the `view_names` in the config file.
                Order of the list does not matter: video files are intelligently matched to views
                by their filename using `utils.io.collect_video_files_by_view`.
            output_dir (str | Path, optional): The directory to save outputs to.
                Defaults to `{model_dir}/video_preds`.
                Passing None is not supported yet and raises NotImplementedError.
            compute_metrics (bool, optional): Whether to compute metrics on the predictions.
            generate_labeled_video (bool, optional): Whether to save a labeled video.
                Defaults to False. Labeled videos are written to
                `{model_dir}/video_preds/labeled_videos`, independent of `output_dir`.
            progress_file (Path, optional): Path to a file to save progress information for
                the App.

        Returns:
            MultiviewPredictionResult: object containing the predictions and metrics for each view.

        Raises:
            NotImplementedError: If ``output_dir`` is None.
        """
        assert self.config.is_multi_view()
        self._load()

        view_names = self.config.cfg.data.view_names
        assert len(video_file_per_view) == len(
            view_names
        ), f"{len(video_file_per_view)} != {len(view_names)}"
        video_file_per_view: list[Path] = [Path(f) for f in video_file_per_view]
        output_dir = self._resolve_video_output_dir(output_dir)

        # Arranges video_file_per_view to be in the same order as cfg.data.view_names.
        _view_to_video_file: dict[str, Path] = io_utils.collect_video_files_by_view(
            video_file_per_view, view_names
        )
        video_file_per_view = [_view_to_video_file[view_name] for view_name in view_names]

        prediction_csv_file_list = [
            str(output_dir / f"{video_file.stem}.csv")
            for video_file in video_file_per_view
        ]
        df_list: list[pd.DataFrame] = predict_video(
            video_file=list(map(str, video_file_per_view)),
            model=self,
            output_pred_file=prediction_csv_file_list,
            progress_file=progress_file,
        )

        if generate_labeled_video:
            for video_file, preds_df in zip(video_file_per_view, df_list, strict=True):
                self._write_labeled_video(video_file, preds_df)

        metrics = None
        if compute_metrics:
            # Built only when metrics are requested, so prediction-only calls do not
            # pay for (or depend on) constructing the dataset.
            data_module = _build_datamodule_pred(self.cfg)
            metrics = {
                view_name: compute_metrics_single(
                    cfg=self.cfg,
                    labels_file=None,
                    preds_file=preds_file,
                    data_module=data_module,
                )
                for view_name, preds_file in zip(
                    view_names, prediction_csv_file_list, strict=True
                )
            }

        df_dict = dict(zip(view_names, df_list, strict=True))
        return MultiviewPredictionResult(predictions=df_dict, metrics=metrics)
def _build_datamodule_pred(cfg: DictConfig):
    """Build a data module for prediction from a private copy of ``cfg``.

    The caller's config is left untouched: a deep copy is made and its
    ``training.imgaug`` entry is set to ``"default"`` before the augmentation
    pipeline, the dataset and the data module are constructed from it.
    """
    pred_cfg = copy.deepcopy(cfg)
    pred_cfg.training.imgaug = "default"

    transform = get_imgaug_transform(cfg=pred_cfg)
    dataset = get_dataset(
        cfg=pred_cfg,
        data_dir=pred_cfg.data.data_dir,
        imgaug_transform=transform,
    )
    return get_data_module(
        cfg=pred_cfg,
        dataset=dataset,
        video_dir=pred_cfg.data.video_dir,
    )