from __future__ import annotations
import copy
from pathlib import Path
from typing import TypedDict
import pandas as pd
from omegaconf import DictConfig, OmegaConf
from lightning_pose.model_config import ModelConfig
from lightning_pose.models import ALLOWED_MODELS
from lightning_pose.utils.io import ckpt_path_from_base_path
from lightning_pose.utils.predictions import (
export_predictions_and_labeled_video,
load_model_from_checkpoint,
predict_dataset,
)
# Import as different name to avoid naming conflict with the kwarg `compute_metrics`.
from lightning_pose.utils.scripts import compute_metrics as compute_metrics_fn
from lightning_pose.utils.scripts import get_data_module, get_dataset, get_imgaug_transform
__all__ = ["Model"]
class Model:
    """A trained Lightning Pose model stored in a directory, with prediction helpers.

    Create instances with `Model.from_dir(model_dir)`. The network itself is loaded
    lazily from its checkpoint the first time a `predict_on_*` method is called.
    """

    model_dir: Path
    """Directory the model is stored in."""

    config: ModelConfig
    """The model configuration stored as a `ModelConfig` object.

    `ModelConfig` wraps the `omegaconf.DictConfig` and provides util functions
    over it.
    """

    # The loaded network; None until `_load` runs (triggered by prediction calls).
    model: ALLOWED_MODELS | None = None

    # Just a constant we can use as a default value for kwargs,
    # to differentiate between user omitting a kwarg, vs explicitly passing None.
    UNSPECIFIED = "unspecified"

    @classmethod
    def from_dir(cls, model_dir: str | Path) -> Model:
        """Create a `Model` instance for a model stored at `model_dir`.

        Reads the configuration from `{model_dir}/config.yaml`.
        """
        model_dir = Path(model_dir)
        config = ModelConfig.from_yaml_file(model_dir / "config.yaml")
        return cls(model_dir, config)

    def __init__(self, model_dir: str | Path, config: ModelConfig):
        """Store the (absolutized) model directory and its configuration.

        Does not load the network; see `_load`.
        """
        self.model_dir = Path(model_dir).absolute()
        self.config = config

    @property
    def cfg(self) -> DictConfig:
        """The model configuration as an `omegaconf.DictConfig`."""
        return self.config.cfg

    def _load(self) -> None:
        """Load the network from its checkpoint in eval mode, unless already loaded."""
        if self.model is not None:
            return
        ckpt_file = ckpt_path_from_base_path(
            base_path=str(self.model_dir), model_name=self.cfg.model.model_name
        )
        self.model = load_model_from_checkpoint(
            cfg=self.cfg,
            ckpt_file=ckpt_file,
            eval=True,
            skip_data_module=True,
        )

    def image_preds_dir(self) -> Path:
        """Default parent directory for predictions on labeled CSV files."""
        return self.model_dir / "image_preds"

    def video_preds_dir(self) -> Path:
        """Default output directory for predictions on video files."""
        return self.model_dir / "video_preds"

    def labeled_videos_dir(self) -> Path:
        """Labeled-video directory inside the default video output directory."""
        return self.video_preds_dir() / "labeled_videos"

    class PredictionResult(TypedDict):
        # Return value of the `predict_on_*` methods.
        # NOTE: `metrics` is declared but not currently populated; the prediction
        # methods only set `predictions`.
        predictions: pd.DataFrame
        metrics: pd.DataFrame

    def _resolve_output_dir(
        self, output_dir: str | Path | None, default: Path
    ) -> Path:
        """Turn the `output_dir` kwarg of the `predict_on_*` methods into a Path.

        Args:
            output_dir: The caller's value; `UNSPECIFIED` selects `default`.
                Both `str` and `Path` are accepted.
            default: Directory used when the caller omitted `output_dir`.

        Returns:
            The output directory as a `Path` (not created here).

        Raises:
            NotImplementedError: If `output_dir` is None (not saving outputs is not
                supported yet).
        """
        if output_dir == type(self).UNSPECIFIED:
            return default
        if output_dir is None:
            raise NotImplementedError("Currently we must save predictions")
        # Normalize so callers may pass plain strings (a str has no .mkdir / `/`).
        return Path(output_dir)

    def predict_on_label_csv(
        self,
        csv_file: str | Path,
        data_dir: str | Path | None = None,
        compute_metrics: bool = True,
        generate_labeled_images: bool = False,
        output_dir: str | Path | None = UNSPECIFIED,
    ) -> PredictionResult:
        """Predicts on a labeled dataset and computes error/loss metrics if applicable.

        Args:
            csv_file (str | Path): Path to the CSV file of images, keypoint locations.
            data_dir (str | Path, optional): Root path for relative paths in the CSV file.
                Defaults to the parent directory of the CSV file.
            compute_metrics (bool, optional): Whether to compute pixel error and loss
                metrics on predictions. Defaults to True.
            generate_labeled_images (bool, optional): Whether to save labeled images.
                Defaults to False. Not implemented yet: True raises NotImplementedError.
            output_dir (str | Path, optional): The directory to save outputs to.
                Defaults to `{model_dir}/image_preds/{csv_file_name}`. Passing None
                (do not save outputs) is not supported yet and raises
                NotImplementedError.

        Returns:
            PredictionResult: A dict holding the `predictions` DataFrame
            (`metrics` is not populated yet).
        """
        return self.predict_on_label_csv_internal(
            csv_file=csv_file,
            data_dir=data_dir,
            compute_metrics=compute_metrics,
            generate_labeled_images=generate_labeled_images,
            output_dir=output_dir,
            output_filename_stem="predictions",
            add_train_val_test_set=False,
        )

    def predict_on_label_csv_internal(
        self,
        csv_file: str | Path,
        data_dir: str | Path | None = None,
        compute_metrics: bool = True,
        generate_labeled_images: bool = False,
        output_dir: str | Path | None = UNSPECIFIED,
        output_filename_stem: str = "predictions",
        add_train_val_test_set: bool = False,
    ) -> PredictionResult:
        """Implementation of `predict_on_label_csv` with extra internal options.

        See predict_on_label_csv for the rest of the arguments. The following are the
        arguments specific to the internal function.

        Args:
            output_filename_stem (str): The stem of the output filename. Defaults to
                'predictions'. Used to generate predictions_new for OOD, and
                predictions_{view_name} for multi-view, in the model_dir.
            add_train_val_test_set (bool): When predicting on training dataset, set to
                true to add the `set` column to the prediction output.

        Returns:
            PredictionResult: A dict holding the `predictions` DataFrame
            (`metrics` is not populated yet).

        Raises:
            NotImplementedError: If `output_dir` is None or `generate_labeled_images`
                is True.
        """
        # Validate arguments before any expensive or side-effecting work
        # (checkpoint loading, directory creation).
        csv_file = Path(csv_file)
        if data_dir is None:
            data_dir = csv_file.parent
        output_dir = self._resolve_output_dir(
            output_dir, default=self.image_preds_dir() / csv_file.name
        )
        if generate_labeled_images:
            raise NotImplementedError()

        self._load()
        output_dir.mkdir(parents=True, exist_ok=True)

        # Point predict_dataset to the csv_file and data_dir.
        cfg_overrides = {
            "data": {
                "data_dir": str(data_dir),
                "csv_file": str(csv_file),
            }
        }
        # Avoid annotating set=train/val/test for CSV file other than the training
        # CSV file, by sending every frame to the train split. These keys belong to
        # the `training` section of the config; merged at the top level (as before)
        # they were silently ignored by the data module.
        if not add_train_val_test_set:
            cfg_overrides["training"] = {
                "train_prob": 1,
                "val_prob": 0,
                "train_frames": 1,
            }
        cfg_pred = OmegaConf.merge(self.cfg, cfg_overrides)

        # HACK: For true multi-view model, trick predict_dataset and compute_metrics
        # into thinking this is a single-view model.
        if self.config.is_multi_view():
            del cfg_pred.data.view_names
            # HACK: If we don't delete mirrored_column_matches, downstream
            # interprets this as a mirrored multiview model, and compute_metrics fails.
            # Guarded: deleting an absent key from a DictConfig raises.
            if "mirrored_column_matches" in cfg_pred.data:
                del cfg_pred.data.mirrored_column_matches

        data_module_pred = _build_datamodule_pred(cfg_pred)

        preds_file = str(output_dir / f"{output_filename_stem}.csv")
        df = predict_dataset(
            cfg_pred, data_module_pred, model=self.model, preds_file=preds_file
        )

        if compute_metrics:
            # HACK: True multi-view model treated as single-view model, so preds_file is
            # a string, not a list per-view. This means we can't yet compute pca_multiview.
            compute_metrics_fn(
                cfg=cfg_pred,
                preds_file=preds_file,
                data_module=data_module_pred,
            )

        # TODO: Generate detector outputs.
        return self.PredictionResult(predictions=df)

    def predict_on_video_file(
        self,
        video_file: str | Path,
        output_dir: str | Path | None = UNSPECIFIED,
        compute_metrics: bool = True,
        generate_labeled_video: bool = False,
    ) -> PredictionResult:
        """Predicts on a video file and computes unsupervised loss metrics if applicable.

        Args:
            video_file (str | Path): Path to the video file.
            output_dir (str | Path, optional): The directory to save outputs to.
                Defaults to `{model_dir}/video_preds`. Predictions are written to
                `{output_dir}/{video_stem}.csv`, and the labeled video (if requested)
                to `{output_dir}/labeled_videos/{video_stem}_labeled.mp4`. Passing
                None (do not save outputs) is not supported yet and raises
                NotImplementedError.
            compute_metrics (bool, optional): Whether to compute metrics on the
                predictions. Defaults to True.
            generate_labeled_video (bool, optional): Whether to save a labeled video.
                Defaults to False.

        Returns:
            PredictionResult: A dict holding the `predictions` DataFrame
            (`metrics` is not populated yet).

        Raises:
            NotImplementedError: If `output_dir` is None, or if the config's
                `eval.predict_vids_after_training_save_heatmaps` is truthy.
        """
        video_file = Path(video_file)
        output_dir = self._resolve_output_dir(
            output_dir, default=self.video_preds_dir()
        )

        self._load()
        output_dir.mkdir(parents=True, exist_ok=True)

        prediction_csv_file = output_dir / f"{video_file.stem}.csv"
        labeled_mp4_file = None
        if generate_labeled_video:
            # Kept under output_dir so a custom output_dir holds all outputs; for the
            # default output_dir this equals `labeled_videos_dir()`.
            labeled_mp4_file = str(
                output_dir / "labeled_videos" / f"{video_file.stem}_labeled.mp4"
            )

        if self.cfg.eval.get("predict_vids_after_training_save_heatmaps", False):
            raise NotImplementedError(
                "Implement this after cleaning up _predict_frames: "
                "Set a flag on the model to return heatmaps. "
                "Use trainer.predict instead of side-stepping it."
            )

        df = export_predictions_and_labeled_video(
            video_file=str(video_file),
            cfg=self.cfg,
            prediction_csv_file=str(prediction_csv_file),
            labeled_mp4_file=labeled_mp4_file,
            model=self.model,
        )

        if compute_metrics:
            # FIXME: This is only used for computing PCA metrics.
            data_module = _build_datamodule_pred(self.cfg)
            compute_metrics_fn(
                cfg=self.cfg,
                preds_file=str(prediction_csv_file),
                data_module=data_module,
            )

        return self.PredictionResult(predictions=df)
def _build_datamodule_pred(cfg: DictConfig):
    """Build and set up a data module for prediction from a copy of `cfg`.

    The caller's `cfg` is not modified: the function works on a deep copy whose
    `training.imgaug` is overridden to "default" (presumably the non-augmenting
    pipeline used for prediction; verify in `get_imgaug_transform`).

    Returns:
        The data module from `get_data_module`, after `setup()` has been called.
    """
    pred_cfg = copy.deepcopy(cfg)
    pred_cfg.training.imgaug = "default"
    data_section = pred_cfg.data

    transform = get_imgaug_transform(cfg=pred_cfg)
    dataset = get_dataset(
        cfg=pred_cfg,
        data_dir=data_section.data_dir,
        imgaug_transform=transform,
    )

    module = get_data_module(
        cfg=pred_cfg,
        dataset=dataset,
        video_dir=data_section.video_dir,
    )
    module.setup()
    return module