Source code for lightning_pose.metrics

"""Evaluation metrics for assessing pose estimation model quality.

CSV format conventions
----------------------
Labels CSV (DLC format, 3-row MultiIndex header: scorer / bodyparts / coords):
  - coord values are ``x``, ``y``, and optionally ``visible``.
  - ``visible`` encodes per-keypoint visibility: 2 = labeled, 1 = present but
    unlabeled in this dataset, 0 = keypoint does not belong to this dataset.
  - Functions that consume labels filter coords to ``x``/``y`` before any
    reshape, so CSVs with a ``visible`` column are handled transparently.

Predictions CSV (same 3-row MultiIndex header):
  - coord values are ``x``, ``y``, ``likelihood``.
  - An optional trailing column whose first MultiIndex level is ``'set'``
    signals that the file comes from a labeled dataset (not a video).
    Its presence sets ``is_video = False``, which triggers pixel-error
    computation instead of temporal-norm computation in
    ``compute_metrics_single``.
  - ``get_keypoint_names`` identifies keypoints by finding columns whose
    coord level equals ``'x'``, so the ``'set'`` column is automatically
    excluded from the returned keypoint list.
"""

from pathlib import Path

import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig, ListConfig

from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.datasets import MultiviewHeatmapDataset
from lightning_pose.data.datatypes import ComputeMetricsSingleResult
from lightning_pose.utils.io import fix_empty_first_row, get_keypoint_names
from lightning_pose.utils.pca import KeypointPCA

# Restrict the public API to the metric functions so that sphinx-autoapidoc
# ignores the names imported above.
__all__ = [
    "pixel_error",
    "temporal_norm",
    "pca_singleview_reprojection_error",
    "pca_multiview_reprojection_error",
    "compute_metrics_single",
]


def pixel_error(keypoints_true: np.ndarray, keypoints_pred: np.ndarray) -> np.ndarray:
    """Per-keypoint Euclidean distance between true and predicted keypoints.

    The distance is the L2 norm of the (x, y) difference, taken over the last
    axis of the inputs. NaN coordinates in either input propagate to the
    corresponding output entry.

    Args:
        keypoints_true: shape (samples, n_keypoints, 2)
        keypoints_pred: shape (samples, n_keypoints, 2)

    Returns:
        shape (samples, n_keypoints)

    """
    displacement = keypoints_true - keypoints_pred
    # collapse the coordinate axis (axis 2) into a single distance per keypoint
    return np.linalg.norm(displacement, axis=2)
def temporal_norm(keypoints_pred: np.ndarray | torch.Tensor) -> np.ndarray:
    """Norm of difference between keypoints on successive time bins.

    The computation is delegated to ``TemporalLoss.compute_loss``; a row of
    NaNs is prepended to its output because the first frame has no predecessor
    to compare against.

    Args:
        keypoints_pred: shape (samples, n_keypoints * 2) or (samples, n_keypoints, 2);
            non-tensor input is converted to a float32 tensor

    Returns:
        shape (samples, n_keypoints)

    """
    # NOTE(review): imported at call time rather than module level, presumably
    # to avoid a circular import - confirm before moving it.
    from lightning_pose.losses.losses import TemporalLoss

    if not isinstance(keypoints_pred, torch.Tensor):
        keypoints_pred = torch.tensor(keypoints_pred, dtype=torch.float32)

    # (samples, n_keypoints, 2) -> (samples, n_keypoints * 2)
    if keypoints_pred.ndim != 2:
        keypoints_pred = keypoints_pred.reshape(keypoints_pred.shape[0], -1)

    # compute loss with already-implemented class
    t_norm = TemporalLoss().compute_loss(keypoints_pred)

    # Tensor.numpy() raises for tensors on a non-CPU device or tensors that
    # require grad, so detach and move to host memory before converting.
    t_norm = t_norm.detach().cpu().numpy()

    # prepend nan vector; no temporal norm for the very first frame
    first_frame = np.full((1, t_norm.shape[1]), np.nan)
    return np.vstack([first_frame, t_norm])
def pca_singleview_reprojection_error(
    keypoints_pred: np.ndarray | torch.Tensor,
    pca: KeypointPCA,
) -> np.ndarray:
    """Singleview PCA reprojection error, in pixels, per keypoint.

    Predictions are formatted by the PCA object, reprojected through its
    subspace, and compared against the original predictions with
    ``pixel_error``. Only the keypoints listed in
    ``pca.columns_for_singleview_pca`` receive a value; all other keypoints
    are NaN in the output.

    Args:
        keypoints_pred: shape (samples, n_keypoints, 2)
        pca: pca object that contains info about pca subspace

    Returns:
        shape (samples, n_keypoints)

    """
    if not isinstance(keypoints_pred, torch.Tensor):
        keypoints_pred = torch.tensor(keypoints_pred, device=pca.device, dtype=torch.float32)

    dims = keypoints_pred.shape
    selected_cols = pca.columns_for_singleview_pca

    # the pca object expects a single trailing dim of size n_keypoints * 2
    flattened = keypoints_pred.reshape(keypoints_pred.shape[0], -1)
    formatted = pca._format_data(data_arr=flattened)

    # project onto the pca subspace and back, then restore the (x, y) axis
    reprojected = pca.reproject(data_arr=formatted)
    reprojected_xy = reprojected.reshape(reprojected.shape[0], reprojected.shape[1] // 2, 2)

    # distance between each selected keypoint and its reprojection
    selected_pred = keypoints_pred[:, selected_cols, :].cpu().numpy()
    selected_error = pixel_error(selected_pred, reprojected_xy.cpu().numpy())

    # scatter into a full-size array; keypoints not used for pca stay nan
    full_error = np.full((dims[0], dims[1]), np.nan)
    full_error[:, selected_cols] = selected_error
    return full_error
def pca_multiview_reprojection_error(
    keypoints_pred: np.ndarray | torch.Tensor,
    pca: KeypointPCA,
) -> np.ndarray:
    """Multiview PCA reprojection error, in pixels, per keypoint.

    Predictions are formatted by the PCA object, reprojected through its
    subspace, and compared against the identically formatted predictions with
    ``pixel_error``. Errors are written back to the columns listed in
    ``pca.mirrored_column_matches`` (one list of columns per view); every
    other keypoint is NaN in the output.

    Args:
        keypoints_pred: shape (samples, n_keypoints, 2)
        pca: pca object that contains info about pca subspace

    Returns:
        shape (samples, n_keypoints)

    """
    if not isinstance(keypoints_pred, torch.Tensor):
        keypoints_pred = torch.tensor(keypoints_pred, device=pca.device, dtype=torch.float32)

    dims = keypoints_pred.shape

    assert pca.mirrored_column_matches is not None
    view_columns = list(pca.mirrored_column_matches)

    # the pca object expects a single trailing dim of size n_keypoints * 2
    formatted = pca._format_data(data_arr=keypoints_pred.reshape(keypoints_pred.shape[0], -1))

    # project onto the pca subspace and back, then restore the (x, y) axis
    reprojected = pca.reproject(data_arr=formatted)
    reprojected_xy = reprojected.reshape(reprojected.shape[0], reprojected.shape[1] // 2, 2)

    # put the original predictions in the same layout as the reprojection
    pred_formatted = pca._format_data(
        data_arr=keypoints_pred.reshape(keypoints_pred.shape[0], -1),
    )
    pred_formatted_xy = pred_formatted.reshape(
        pred_formatted.shape[0],
        pred_formatted.shape[1] // 2,
        2,
    )

    per_point_error = pixel_error(
        pred_formatted_xy.cpu().numpy(),
        reprojected_xy.cpu().numpy(),
    )

    # batch X num_used_keypoints X num_views
    per_view_error = per_point_error.reshape(-1, len(view_columns[0]), len(view_columns))

    # scatter into a full-size array; keypoints not used for pca stay nan
    full_error = np.full((dims[0], dims[1]), np.nan)
    for view_idx, columns in enumerate(view_columns):
        # just the columns belonging to this view
        full_error[:, columns] = per_view_error[:, :, view_idx]
    return full_error
def _build_and_save_metric_df(
    values: np.ndarray,
    index: pd.Index,
    keypoint_names: list[str],
    set_labels: np.ndarray | None,
    save_file: Path,
) -> pd.DataFrame:
    """Wrap per-keypoint metric values in a dataframe and write it to csv.

    Args:
        values: metric values, one row per sample and one column per keypoint
        index: row index, taken from the predictions dataframe
        keypoint_names: column names
        set_labels: optional per-row values appended as a trailing ``'set'`` column
        save_file: destination csv path

    Returns:
        the dataframe that was written to ``save_file``

    """
    metric_df = pd.DataFrame(
        values,
        index=pd.Index(index),
        columns=pd.Index(keypoint_names),
    )
    if set_labels is not None:
        metric_df['set'] = set_labels
    metric_df.to_csv(save_file)
    return metric_df


def compute_metrics_single(
    cfg: DictConfig | ListConfig,
    labels_file: str | Path | None,
    preds_file: str | Path,
    data_module: BaseDataModule | UnlabeledDataModule | None = None,
) -> ComputeMetricsSingleResult:
    """Compute various metrics on a predictions csv file from a single view.

    Which metrics are computed:
      - ``temporal``: when the predictions file has no trailing ``'set'`` column
        (i.e. it is treated as a video).
      - ``pixel_error``: when the trailing ``'set'`` column is present.
      - ``pca_singleview`` / ``pca_multiview``: when ``data_module`` is given, the
        matching ``cfg.data`` entry is non-empty, and the data module's dataset
        is not a ``MultiviewHeatmapDataset``.

    Each computed metric is written next to ``preds_file`` as
    ``<preds stem>_<metric>.csv`` and stored on the returned result.

    Args:
        cfg: hydra config
        labels_file: path to the labels csv file; required for pixel error computation
        preds_file: path to the predictions csv file
        data_module: required for PCA metric computation

    Returns:
        ComputeMetricsSingleResult containing dataframes for each computed metric

    Raises:
        AssertionError: if pixel error is requested without ``labels_file``, or
            if the labels index does not match the predictions index.
        ValueError: re-raised from the singleview PCA branch unless its message
            contains ``'cannot fit PCA'``, in which case that metric is skipped.

    """
    pred_df = pd.read_csv(preds_file, header=[0, 1, 2], index_col=0)
    keypoint_names = get_keypoint_names(cfg, csv_file=str(preds_file), header_rows=[0, 1, 2])

    # keep only x/y/likelihood columns so a trailing 'set' column does not break the reshape
    xyl_mask = pred_df.columns.get_level_values('coords').isin(['x', 'y', 'likelihood'])
    xyl = pred_df.loc[:, xyl_mask].to_numpy().reshape(pred_df.shape[0], -1, 3)
    index = pred_df.index

    # a trailing 'set' column marks predictions on labeled data rather than a video
    if pred_df.keys()[-1][0] == 'set':
        is_video = False
        set_labels = pred_df.iloc[:, -1].to_numpy()
    else:
        is_video = True
        set_labels = None

    keypoints_pred = xyl[:, :, :2]  # shape (samples, n_keypoints, 2)

    if is_video:
        metrics_to_compute = ['temporal']
    else:
        assert labels_file is not None, '"pixel_error" metric requires labels_file'
        metrics_to_compute = ['pixel_error']

    if (
        data_module is not None
        and cfg.data.get('columns_for_singleview_pca', None) is not None
        and len(cfg.data.columns_for_singleview_pca) != 0
        and not isinstance(data_module.dataset, MultiviewHeatmapDataset)
    ):
        metrics_to_compute.append('pca_singleview')
    if (
        data_module is not None
        and cfg.data.get('mirrored_column_matches', None) is not None
        and len(cfg.data.mirrored_column_matches) != 0
        and not isinstance(data_module.dataset, MultiviewHeatmapDataset)
    ):
        metrics_to_compute.append('pca_multiview')

    result = ComputeMetricsSingleResult()
    preds_file_path = Path(preds_file)

    if 'pixel_error' in metrics_to_compute:
        assert labels_file is not None, '"pixel_error" metric requires labels_file'
        labels_df = pd.read_csv(labels_file, header=[0, 1, 2], index_col=0)
        labels_df = fix_empty_first_row(labels_df)
        assert labels_df.index.equals(index), (
            'labels_file index does not match preds_file index'
        )
        # drop any non-x/y coords (e.g. 'visible') before reshaping
        xy_mask = labels_df.columns.get_level_values('coords').isin(['x', 'y'])
        labels_df = labels_df.loc[:, xy_mask]
        keypoints_true = labels_df.to_numpy().reshape(labels_df.shape[0], -1, 2)
        result.pixel_error_df = _build_and_save_metric_df(
            pixel_error(keypoints_true, keypoints_pred),
            index,
            keypoint_names,
            set_labels,
            preds_file_path.with_name(preds_file_path.stem + '_pixel_error.csv'),
        )

    if 'temporal' in metrics_to_compute:
        result.temporal_norm_df = _build_and_save_metric_df(
            temporal_norm(keypoints_pred),
            index,
            keypoint_names,
            set_labels,
            preds_file_path.with_name(preds_file_path.stem + '_temporal_norm.csv'),
        )

    if 'pca_singleview' in metrics_to_compute:
        try:
            assert data_module is not None
            pca = KeypointPCA(
                loss_type='pca_singleview',
                data_module=data_module,
                components_to_keep=cfg.losses.pca_singleview.components_to_keep,
                empirical_epsilon_percentile=cfg.losses.pca_singleview.get(
                    'empirical_epsilon_percentile', 1.0,
                ),
                columns_for_singleview_pca=cfg.data.columns_for_singleview_pca,
                centering_method=cfg.losses.pca_singleview.get('centering_method', None),
            )
            pca()
            result.pca_sv_df = _build_and_save_metric_df(
                pca_singleview_reprojection_error(keypoints_pred, pca),
                index,
                keypoint_names,
                set_labels,
                preds_file_path.with_name(
                    preds_file_path.stem + '_pca_singleview_error.csv',
                ),
            )
        except ValueError as e:
            # best effort: skip this metric when the PCA cannot be fit; anything
            # else is a real error and propagates with its original traceback
            if 'cannot fit PCA' not in str(e):
                raise

    if 'pca_multiview' in metrics_to_compute:
        assert data_module is not None
        # NOTE(review): components_to_keep and empirical_epsilon_percentile are read
        # from cfg.losses.pca_singleview even though loss_type is 'pca_multiview';
        # confirm this is intentional before changing it.
        pca = KeypointPCA(
            loss_type='pca_multiview',
            data_module=data_module,
            components_to_keep=cfg.losses.pca_singleview.components_to_keep,
            empirical_epsilon_percentile=cfg.losses.pca_singleview.get(
                'empirical_epsilon_percentile', 1.0,
            ),
            mirrored_column_matches=cfg.data.mirrored_column_matches,
        )
        pca()
        result.pca_mv_df = _build_and_save_metric_df(
            pca_multiview_reprojection_error(keypoints_pred, pca),
            index,
            keypoint_names,
            set_labels,
            preds_file_path.with_name(preds_file_path.stem + '_pca_multiview_error.csv'),
        )

    return result