"""Functions for predicting keypoints on labeled datasets and unlabeled videos."""
from __future__ import annotations
import datetime
import gc
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, cast, overload
import cv2
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from jaxtyping import Float
from moviepy import VideoFileClip
from omegaconf import DictConfig, ListConfig
from lightning_pose.callbacks import JSONInferenceProgressTracker
from lightning_pose.data.dali import PrepareDALI
from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.utils import count_frames
if TYPE_CHECKING:
from lightning_pose.api import Model
# Explicit public API of this module. Listing it here also keeps sphinx autoapi docs from
# picking up the names imported above as if they were defined in this module.
__all__ = [
    "PredictionHandler",
    "predict_dataset",
    "predict_video",
    "make_dlc_pandas_index",
    "generate_labeled_video",
]
class PredictionHandler:
    """Convert batches of model outputs into a prediction dataframe."""

    def __init__(
        self,
        cfg: DictConfig | ListConfig,
        data_module: BaseDataModule | UnlabeledDataModule | None = None,
        video_file: str | None = None,
    ) -> None:
        """
        Args:
            cfg: config; ``cfg.data.keypoint_names`` must be set.
            data_module: Only required for prediction of CSV files (labeled datasets).
            video_file: For prediction on video, path to the video file.
                Used to get frame_count.

        Raises:
            ValueError: if neither ``data_module`` nor ``video_file`` is given, or if
                ``cfg.data.keypoint_names`` is missing.
        """
        if data_module is None and video_file is None:
            raise ValueError("must pass either data_module or video_file")
        if cfg.data.get("keypoint_names", None) is None:
            raise ValueError("must include `keypoint_names` field in cfg.data")
        self.cfg = cfg
        self.data_module = data_module
        self.video_file = video_file

    @property
    def frame_count(self) -> int:
        """Number of frames in the video, or number of items in the labeled dataset.

        For videos this re-runs ``count_frames`` on every access, so methods that need the
        value more than once read it into a local first.
        """
        if self.video_file is not None:
            return count_frames(self.video_file)
        assert self.data_module is not None
        return len(self.data_module.dataset)  # type: ignore[arg-type]

    @property
    def keypoint_names(self) -> list[str]:
        """List of keypoint name strings from the config.

        Returns:
            List of keypoint names.
        """
        return list(self.cfg.data.keypoint_names)

    @property
    def do_context(self) -> bool:
        """Whether the model/loader uses 5-frame context.

        Taken from the datamodule's dataset when one was given; otherwise inferred from
        ``cfg.model.model_type == "heatmap_mhcrnn"``.

        Returns:
            True if context frames are used, otherwise False.
        """
        if self.data_module is not None:
            return self.data_module.dataset.do_context  # type: ignore[union-attr]
        return self.cfg.model.model_type == "heatmap_mhcrnn"

    def unpack_preds(
        self,
        preds: list[
            tuple[
                Float[torch.Tensor, "batch two_times_num_keypoints"],
                Float[torch.Tensor, "batch num_keypoints"],
            ]
        ],
    ) -> tuple[
        Float[torch.Tensor, "num_frames two_times_num_keypoints"],
        Float[torch.Tensor, "num_frames num_keypoints"],
    ]:
        """Stack the (predictions, confidences) batches from ``pl.Trainer.predict``.

        For video input (``video_file`` set), rows beyond the video's frame count are
        dropped. For context models, rows are then realigned with
        :meth:`fix_context_preds_confs`. For labeled datasets the stacked tensors are
        returned as-is.

        Args:
            preds: one ``(predictions, confidences)`` tuple per batch.

        Returns:
            Tuple of stacked predictions and stacked confidences.
        """
        stacked_preds = torch.vstack([pred[0] for pred in preds])
        stacked_confs = torch.vstack([pred[1] for pred in preds])
        if self.video_file is not None:  # dealing with dali loaders
            # The final sequence may contain extra rows that are not part of the video.
            num_rows_to_discard = stacked_preds.shape[0] - self.frame_count
            if num_rows_to_discard > 0:
                stacked_preds = stacked_preds[:-num_rows_to_discard]
                stacked_confs = stacked_confs[:-num_rows_to_discard]
            if self.do_context:
                # Context predictions are shifted by two frames; realign them.
                stacked_preds = self.fix_context_preds_confs(stacked_preds)
                # heatmap_mhcrnn keeps its edge confidences; other context models zero them.
                zero_pad = self.cfg.model.model_type != "heatmap_mhcrnn"
                stacked_confs = self.fix_context_preds_confs(
                    stacked_confs, zero_pad_confidence=zero_pad
                )
        return stacked_preds, stacked_confs

    def fix_context_preds_confs(
        self, stacked_preds: torch.Tensor, zero_pad_confidence: bool = False
    ) -> torch.Tensor:
        """Shift context-model outputs by two rows so row ``i`` corresponds to frame ``i``.

        In the context model, ind=0 is associated with image[2], and ind=1 is associated with
        image[3], so the outputs are shifted by two and the edges are filled in:

        * rows 0 and 1 are copies of the first input row;
        * if the shifted tensor already has ``frame_count`` rows, its last two rows are
          overwritten with row -3;
        * otherwise the missing trailing rows are filled with the last available row.

        Args:
            stacked_preds: stacked per-frame outputs (predictions or confidences).
            zero_pad_confidence: if True, set the first and last two rows to 0.

        Returns:
            Tensor with ``frame_count`` rows.
        """
        frame_count = self.frame_count  # read once: may re-count the video's frames
        # first pad the first two rows for which we have no valid preds.
        preds_1 = torch.tile(stacked_preds[0], (2, 1))  # copy the prediction for image[2] twice
        preds_2 = stacked_preds[0:-2]  # throw out the last two rows.
        preds_combined = torch.vstack([preds_1, preds_2])
        if preds_combined.shape[0] == frame_count:
            # i.e., after concat this has the length of the video.
            # we don't have valid predictions for the last two elements, so we pad with element -3
            preds_combined[-2:, :] = preds_combined[-3, :]
        else:
            # we don't have as many predictions as frames; pad with the final (valid) entry.
            n_pad = frame_count - preds_combined.shape[0]
            preds_combined = torch.vstack(
                [preds_combined, torch.tile(preds_combined[-1], (n_pad, 1))]
            )
        if zero_pad_confidence:
            # zeroing out those first and last two rows (after we've shifted everything above)
            preds_combined[:2, :] = 0.0
            preds_combined[-2:, :] = 0.0
        return preds_combined

    @staticmethod
    def make_pred_arr_undo_resize(
        keypoints_np: np.ndarray,
        confidence_np: np.ndarray,
    ) -> np.ndarray:
        """Interleave keypoints and confidences into one numpy array.

        Despite the name, no resizing is performed here; values are copied as-is.

        Args:
            keypoints_np: shape (n_frames, n_keypoints * 2), columns (x0, y0, x1, y1, ...)
            confidence_np: shape (n_frames, n_keypoints)

        Returns:
            np.ndarray: cols are (bp0_x, bp0_y, bp0_likelihood, bp1_x, bp1_y, ...)
        """
        # check num frames in the dataset
        assert keypoints_np.shape[0] == confidence_np.shape[0]
        # check we have two (x,y) coordinates and a single likelihood value
        assert keypoints_np.shape[1] == confidence_np.shape[1] * 2
        num_joints = confidence_np.shape[-1]  # model.num_keypoints
        predictions = np.zeros((keypoints_np.shape[0], num_joints * 3))
        predictions[:, 0::3] = keypoints_np[:, 0::2]
        predictions[:, 1::3] = keypoints_np[:, 1::2]
        predictions[:, 2::3] = confidence_np
        return predictions

    def make_dlc_pandas_index(self, keypoint_names: list | None = None) -> pd.MultiIndex:
        """Build a DLC-style pandas MultiIndex for labelling prediction DataFrames.

        Args:
            keypoint_names: optional override for the list of keypoint names; defaults to
                ``self.keypoint_names`` (also used when an empty list is passed).

        Returns:
            ``pd.MultiIndex`` with levels ``["scorer", "bodyparts", "coords"]``.
        """
        return make_dlc_pandas_index(
            cfg=self.cfg, keypoint_names=keypoint_names or self.keypoint_names
        )

    def add_split_indices_to_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add a ``("set", "", "")`` column labelling each row's data split.

        Rows listed in the datamodule's train/val/test subsets' ``indices`` are labelled
        ``"train"``, ``"validation"`` or ``"test"``; all others stay ``"unused"``. ``df`` is
        modified in place and also returned.
        """
        assert self.data_module is not None
        df["set"] = np.array(["unused"] * df.shape[0])
        assert self.data_module.train_dataset is not None
        assert self.data_module.val_dataset is not None
        assert self.data_module.test_dataset is not None
        dataset_split_indices = {
            "train": self.data_module.train_dataset.indices,
            "validation": self.data_module.val_dataset.indices,
            "test": self.data_module.test_dataset.indices,
        }
        for key, val in dataset_split_indices.items():
            df.loc[val, ("set", "", "")] = np.repeat(key, len(val))
        return df

    @overload
    def __call__(
        self,
        preds: list[
            tuple[
                Float[torch.Tensor, "batch two_times_num_keypoints"],
                Float[torch.Tensor, "batch num_keypoints"],
            ]
        ],
        is_multiview_video: Literal[False] = ...,
    ) -> pd.DataFrame: ...

    @overload
    def __call__(
        self,
        preds: list[
            tuple[
                Float[torch.Tensor, "batch two_times_num_keypoints"],
                Float[torch.Tensor, "batch num_keypoints"],
            ]
        ],
        is_multiview_video: Literal[True],
    ) -> dict[str, pd.DataFrame]: ...

    def __call__(
        self,
        preds: list[
            tuple[
                Float[torch.Tensor, "batch two_times_num_keypoints"],
                Float[torch.Tensor, "batch num_keypoints"],
            ]
        ],
        is_multiview_video: bool = False,
    ) -> pd.DataFrame | dict[str, pd.DataFrame]:
        """Convert the output of ``trainer.predict()`` into prediction DataFrame(s).

        Args:
            preds: list of tuples of (predictions, confidences)
            is_multiview_video: specify True when you are using multiview video prediction
                dataloader, i.e. for heatmap_multiview.

        Returns:
            A DataFrame whose columns are the DLC MultiIndex (scorer, bodypart, x/y/likelihood)
            with one row per frame/image. If ``cfg.data.view_names`` has more than one view
            and either no video is used (labeled data) or ``is_multiview_video`` is True,
            a dict mapping each view name to such a DataFrame is returned instead. For
            labeled data, a ``"set"`` split column is added and the index is the image names.
        """
        stacked_preds, stacked_confs = self.unpack_preds(preds=preds)
        if (
            self.cfg.data.get("view_names", None)
            and len(self.cfg.data.view_names) > 1
            and (self.video_file is None or is_multiview_video)
        ):
            # Output columns hold each view's keypoints back-to-back, in view_names order.
            # A single-view video of a multiview model falls through to the `else` path.
            num_keypoints = len(self.keypoint_names)
            # every view shares the same keypoint names, so build the column index once
            pdindex = self.make_dlc_pandas_index(self.keypoint_names)
            view_to_df = {}
            for view_idx, view_name in enumerate(self.cfg.data.view_names):
                idx_beg = view_idx * num_keypoints
                idx_end = idx_beg + num_keypoints
                stacked_preds_single = stacked_preds[:, idx_beg * 2:idx_end * 2]
                stacked_confs_single = stacked_confs[:, idx_beg:idx_end]
                pred_arr = self.make_pred_arr_undo_resize(
                    stacked_preds_single.cpu().numpy(), stacked_confs_single.cpu().numpy()
                )
                df = pd.DataFrame(pred_arr, columns=pdindex)
                if self.video_file is None:
                    # specify which image is train/test/val/unused
                    df = self.add_split_indices_to_df(df)
                    assert self.data_module is not None
                    view_dataset = self.data_module.dataset.dataset  # type: ignore[index]
                    df.index = view_dataset[view_name].image_names
                view_to_df[view_name] = df
            return view_to_df

        pred_arr = self.make_pred_arr_undo_resize(
            stacked_preds.cpu().numpy(), stacked_confs.cpu().numpy()
        )
        df = pd.DataFrame(pred_arr, columns=self.make_dlc_pandas_index())
        if self.video_file is None:
            # specify which image is train/test/val/unused
            df = self.add_split_indices_to_df(df)
            assert self.data_module is not None
            df.index = self.data_module.dataset.image_names  # type: ignore[union-attr]
        return df
def predict_dataset(
    model: Model,
    data_module: BaseDataModule,
    preds_file: str | list[str],
    cfg: DictConfig | ListConfig | None = None,
) -> pd.DataFrame | dict[str, pd.DataFrame]:
    """Save predicted keypoints for a labeled dataset.

    Args:
        model: API model wrapper; its underlying lightning module is used for inference.
        data_module: data module that contains dataloaders for train, val, test splits.
        preds_file: path for the predictions .csv file. For multiview models this is either
            a single path used as a template (each view is written to
            ``<path without extension>_<view_name><extension>``), or a sequence of paths
            in the same order as ``cfg.data.view_names``.
        cfg: hydra config; if None, falls back to ``model.config.cfg``.

    Returns:
        pandas dataframe with predictions or dict with dataframe of predictions for each view
    """
    cfg_eff = cfg if cfg is not None else model.config.cfg
    trainer = pl.Trainer(devices=1, accelerator="gpu", logger=False)
    labeled_preds = trainer.predict(
        model=model.model,
        dataloaders=data_module.full_labeled_dataloader(),
        return_predictions=True,
    )
    assert labeled_preds is not None
    pred_handler = PredictionHandler(cfg=cfg_eff, data_module=data_module, video_file=None)
    labeled_preds_typed = cast(list[tuple[torch.Tensor, torch.Tensor]], labeled_preds)
    labeled_preds_df = pred_handler(preds=labeled_preds_typed)

    if isinstance(labeled_preds_df, dict):
        if isinstance(preds_file, str):
            # Legacy naming: <predictions>_<view_name>.csv. Only the file extension is split
            # off, so directory names containing ".csv" are left intact and paths without a
            # ".csv" suffix no longer make every view overwrite the same file.
            root, ext = os.path.splitext(preds_file)
            for view_name, df in labeled_preds_df.items():
                df.to_csv(f"{root}_{view_name}{ext}")
        else:
            # preds_file is a sequence of paths corresponding to cfg.data.view_names;
            # this allows the caller to specify the output locations more flexibly.
            # Check the order of labeled_preds_df keys matches the order of the views in the cfg.
            assert list(labeled_preds_df.keys()) == list(cfg_eff.data.view_names)
            for df, view_pred_file in zip(labeled_preds_df.values(), preds_file, strict=True):
                df.to_csv(view_pred_file)
    else:
        assert isinstance(preds_file, str), "preds_file must be a str for single-view predictions"
        labeled_preds_df.to_csv(preds_file)

    # clear up memory
    del trainer
    gc.collect()
    torch.cuda.empty_cache()
    return labeled_preds_df
@overload
def predict_video(
    video_file: str,
    model: Model,
    output_pred_file: str | None = None,
    progress_file: Path | None = None,
) -> pd.DataFrame: ...


@overload
def predict_video(
    video_file: list[str],
    model: Model,
    output_pred_file: list[str] | None = None,
    progress_file: Path | None = None,
) -> list[pd.DataFrame]: ...


def predict_video(
    video_file: str | list[str],
    model: Model,
    output_pred_file: str | list[str] | None = None,
    progress_file: Path | None = None,
) -> pd.DataFrame | list[pd.DataFrame]:
    """Predict keypoints on a video (or, for true multiview models, a list of videos).

    Args:
        video_file: Predict on a video, or for true multiview models, a list of videos
            (order: 1-1 correspondence with cfg.data.view_names).
        model: The model to predict with.
        output_pred_file: (optional) File to save predictions in.
            For multiview, a list of files (1-1 correspondence to cfg.data.view_names).
            Missing parent directories are created.
        progress_file: (optional) path handed to a ``JSONInferenceProgressTracker``
            callback for reporting inference progress.

    Returns:
        The prediction DataFrame for a single video, or a list of DataFrames ordered like
        ``cfg.data.view_names`` for multiview input.

    Raises:
        ValueError: for multiview input when ``output_pred_file`` is given but is not a
            list, or when the number of videos differs from the number of view names.
    """
    is_multiview = not isinstance(video_file, str)
    if is_multiview:
        # Validate output_pred_file is a list
        if output_pred_file is not None and not isinstance(output_pred_file, list):
            raise ValueError(
                "for multiview prediction, 'output_pred_file' should be a list corresponding to "
                "view_names"
            )
        # sanity check 1-1 correspondence of video_file to cfg.data.view_names;
        # PredictionHandler relies on this correspondence to organize the outputted dict
        for single_video_file, view_name in zip(
            video_file, model.config.cfg.data.view_names, strict=True
        ):
            assert (
                view_name in Path(single_video_file).stem
            ), "expected video_file to correspond 1-1 with cfg.data.view_name"

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        logger=False,
        callbacks=(
            [JSONInferenceProgressTracker(progress_file)] if progress_file is not None else None
        ),
    )
    model_type: Literal["base", "context"] = (
        "context" if model.config.cfg.model.model_type == "heatmap_mhcrnn" else "base"
    )
    filenames = [video_file] if not is_multiview else [[f] for f in video_file]
    vid_pred_class = PrepareDALI(
        train_stage="predict",
        model_type=model_type,
        dali_config=model.config.cfg.dali,
        # Important: This will be a list of lists for multiview.
        # This will trigger dali to return multiview batches to predict_step.
        filenames=filenames,
        resize_dims=[
            model.config.cfg.data.image_resize_dims.height,
            model.config.cfg.data.image_resize_dims.width,
        ],
    )
    # get loader
    predict_loader = vid_pred_class()
    # initialize prediction handler class; frame count is taken from the first video
    pred_handler = PredictionHandler(
        cfg=model.config.cfg,
        video_file=video_file[0] if is_multiview else video_file,
    )
    # compute predictions
    preds = trainer.predict(
        model=model.model,
        dataloaders=predict_loader,
        return_predictions=True,
    )
    assert preds is not None
    preds_typed = cast(list[tuple[torch.Tensor, torch.Tensor]], preds)
    preds_df = pred_handler(preds=preds_typed, is_multiview_video=is_multiview)
    # Convert to a 1-1 correspondence list similar to video_files, for multiview.
    if isinstance(preds_df, dict):
        preds_df = [
            preds_df[view_name] for view_name in model.config.cfg.data.view_names
        ]

    def _write_csv(df: pd.DataFrame, path: str) -> None:
        # os.makedirs("") raises, so only create a directory when the path names one.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        df.to_csv(path)

    if output_pred_file is not None:
        # save the predictions to a csv; create directory if it doesn't exist
        if is_multiview:
            for df, output_file in zip(preds_df, output_pred_file, strict=True):
                _write_csv(df, output_file)
        else:
            assert isinstance(preds_df, pd.DataFrame)
            assert isinstance(output_pred_file, str)
            _write_csv(preds_df, output_pred_file)

    # clear up memory (drops only this function's references)
    del model
    del trainer
    del predict_loader
    gc.collect()
    torch.cuda.empty_cache()
    return preds_df
def make_dlc_pandas_index(
    cfg: DictConfig | ListConfig,
    keypoint_names: list[str],
) -> pd.MultiIndex:
    """Return the three-level DLC-style column index used by prediction DataFrames.

    Args:
        cfg: hydra config; ``cfg.model.model_type`` names the scorer (``<type>_tracker``).
        keypoint_names: body-part names, in column order.

    Returns:
        ``pd.MultiIndex`` named ``["scorer", "bodyparts", "coords"]``: the cartesian product
        of the single scorer, every keypoint, and the coords ``x``, ``y``, ``likelihood``.
    """
    scorer_name = f"{cfg.model.model_type}_tracker"
    coord_labels = ["x", "y", "likelihood"]
    level_names = ["scorer", "bodyparts", "coords"]
    return pd.MultiIndex.from_product(
        [[scorer_name], keypoint_names, coord_labels],
        names=level_names,
    )
def _make_cmap(number_colors: int, cmap: str) -> np.ndarray:
    """Return ``number_colors`` RGB colours sampled uniformly along a matplotlib colormap.

    Args:
        number_colors: how many colours to draw from the colormap.
        cmap: name of a matplotlib colormap, e.g. ``"cool"``.

    Returns:
        ``np.uint8`` array of shape ``(number_colors, 3)``. The alpha channel is dropped, and
        channel values are scaled from ``[0, 1]`` to ``[0, 255]`` (truncated, not rounded).
    """
    sample_positions = np.linspace(0, 1, number_colors)
    rgba = plt.cm.ScalarMappable(cmap=cmap).to_rgba(sample_positions)
    rgb_scaled = rgba[:, :3] * 255
    return rgb_scaled.astype(np.uint8)
def _create_labeled_video(
    clip: VideoFileClip,
    xs_arr: np.ndarray,
    ys_arr: np.ndarray,
    mask_array: np.ndarray | None = None,
    dotsize: int = 4,
    colormap: str | None = "cool",
    fps: float | None = None,
    output_video_path: str = "movie.mp4",
    start_time: float = 0.0,
) -> None:
    """Draw keypoint markers and a timestamp overlay on each frame of ``clip`` and save it.

    Small videos (either side <= 192 px) are upsampled first so the markers and text stay
    legible. Keypoint coordinates are scaled by the same factor.

    Args:
        clip: source video clip to annotate.
        xs_arr: shape T x n_joints; x pixel coordinates in the original (un-upsampled) frame.
        ys_arr: shape T x n_joints; y pixel coordinates in the original frame.
        mask_array: shape T x n_joints; timepoints/joints with a False entry will not be
            plotted. Defaults to ``~np.isnan(xs_arr)``.
        dotsize: radius (in pixels) of the filled marker circle.
        colormap: matplotlib color map for markers; ``None`` falls back to ``"cool"``.
        fps: output frame rate; None to default to fps of original video (20.0 if that is
            also unavailable).
        output_video_path: video file name; written with the libx264 codec.
        start_time: time (in seconds) of video start; only shifts the timestamp overlay.
    """
    if mask_array is None:
        mask_array = ~np.isnan(xs_arr)
    n_frames, n_keypoints = xs_arr.shape
    # set colormap for each color
    colors = _make_cmap(n_keypoints, cmap=colormap or "cool")
    # extract info from clip
    nx, ny = clip.size
    # NOTE(review): int() truncates to whole seconds, so np.round(dur, 2) below is a no-op.
    dur = int(clip.duration - clip.start)
    fps_og = clip.fps
    # upsample clip if low resolution; need to do this for dots and text to look nice
    if nx <= 100 or ny <= 100:
        upsample_factor = 2.5
    elif nx <= 192 or ny <= 192:
        upsample_factor = 2
    else:
        upsample_factor = 1
    if upsample_factor > 1:
        clip = cast(VideoFileClip, clip.resized((upsample_factor * nx, upsample_factor * ny)))
        # nx/ny now refer to the upsampled size; used below to clamp marker positions.
        nx, ny = clip.size
    print(f"Duration of video [s]: {np.round(dur, 2)}, recorded at {np.round(fps_og, 2)} fps!")

    def seconds_to_hms(seconds: float) -> str:
        """Format a duration in seconds as an ``HH:MM:SS`` string."""
        # Convert seconds to a timedelta object
        td = datetime.timedelta(seconds=seconds)
        # Extract hours, minutes, and seconds from the timedelta object
        hours = td // datetime.timedelta(hours=1)
        minutes = (td // datetime.timedelta(minutes=1)) % 60
        remainder = td % datetime.timedelta(minutes=1)
        # Format the hours, minutes, and seconds into a string
        hms_str = f"{hours:02}:{minutes:02}:{remainder.seconds:02}"
        return hms_str

    # add marker to each frame t, where t is in sec
    def add_marker_and_timestamps(get_frame: Any, t: float) -> np.ndarray:
        """Overlay keypoint markers and a timestamp on the frame at time ``t``."""
        image = get_frame(t)
        # copy so drawing does not modify the frame returned by get_frame;
        # presumably shaped [ny x nx x 3] (height x width x channels) — TODO confirm
        frame = image.copy()
        # convert from sec to indices (of the original recording's frame rate)
        index = int(np.round(t * fps_og))
        # ----------------
        # markers
        # ----------------
        if index >= n_frames:
            # more video frames than prediction rows: leave this frame unannotated
            print(f"add_marker_and_timestamps: Skipped frame {index}")
        else:
            for bpindex in range(n_keypoints):
                if mask_array[index, bpindex]:
                    # scale into the (possibly upsampled) frame; clamp to the right/bottom edge
                    xc = min(int(upsample_factor * xs_arr[index, bpindex]), nx - 1)
                    yc = min(int(upsample_factor * ys_arr[index, bpindex]), ny - 1)
                    frame = cv2.circle(
                        frame,
                        center=(xc, yc),
                        radius=dotsize,
                        color=colors[bpindex].tolist(),
                        thickness=-1,
                    )
        # ----------------
        # timestamps
        # ----------------
        seconds_from_start = t + start_time
        time_from_start = seconds_to_hms(seconds_from_start)
        idx_from_start = int(np.round(seconds_from_start * fps_og))
        text = f"t={time_from_start}, frame={idx_from_start}"
        # define text info
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        font_thickness = 1
        # calculate the size of the text
        text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
        # calculate the position of the text in the upper-left corner
        offset = 6
        text_x = offset  # offset from the left
        text_y = text_size[1] + offset  # baseline y: text height plus offset from the top
        # make black rectangle with a small padding of offset / 2 pixels
        cv2.rectangle(
            frame,
            (text_x - int(offset / 2), text_y + int(offset / 2)),
            (text_x + text_size[0] + int(offset / 2), text_y - text_size[1] - int(offset / 2)),
            (0, 0, 0),  # rectangle color
            cv2.FILLED,
        )
        cv2.putText(
            frame,
            text,
            (text_x, text_y),
            font,
            font_scale,
            (255, 255, 255),  # font color
            font_thickness,
            lineType=cv2.LINE_AA,
        )
        return frame

    clip_marked = clip.transform(add_marker_and_timestamps)
    clip_marked.write_videofile(
        output_video_path, codec="libx264", fps=fps or fps_og or 20.0
    )
    clip_marked.close()
def generate_labeled_video(
    video_file: str,
    preds_df: pd.DataFrame,
    output_mp4_file: str,
    confidence_thresh_for_vid: float,
    colormap: str,
) -> None:
    """Overlay keypoint markers on a video and write the result to disk.

    Args:
        video_file: path to the source video file.
        preds_df: predictions DataFrame with columns indexed as
            ``(scorer, bodypart, coord)`` where coord is x, y, or likelihood, in that
            repeating order (every column is read as part of an x/y/likelihood triple).
        output_mp4_file: path where the labeled video will be saved; missing parent
            directories are created.
        confidence_thresh_for_vid: keypoints with confidence at or below this value are
            not plotted.
        colormap: matplotlib colormap name used to colour each keypoint.
    """
    # os.makedirs("") raises, so only create a directory when the path names one.
    output_dir = os.path.dirname(output_mp4_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # transform df to numpy array of shape (n_frames, n_keypoints, 3)
    keypoints_arr = np.reshape(preds_df.to_numpy(), [preds_df.shape[0], -1, 3])
    xs_arr = keypoints_arr[:, :, 0]
    ys_arr = keypoints_arr[:, :, 1]
    mask_array = keypoints_arr[:, :, 2] > confidence_thresh_for_vid
    # video generation; always release the source clip's readers, even if writing fails
    video_clip = VideoFileClip(video_file)
    try:
        _create_labeled_video(
            clip=video_clip,
            xs_arr=xs_arr,
            ys_arr=ys_arr,
            mask_array=mask_array,
            output_video_path=output_mp4_file,
            colormap=colormap,
        )
    finally:
        video_clip.close()