Source code for lightning_pose.metrics

"""Evaluation metrics for assessing pose estimation model quality."""

from pathlib import Path

import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig, ListConfig

from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.datasets import MultiviewHeatmapDataset
from lightning_pose.data.datatypes import ComputeMetricsSingleResult
from lightning_pose.utils.io import fix_empty_first_row, get_keypoint_names
from lightning_pose.utils.pca import KeypointPCA

# limit the exported names so that imported objects are ignored by sphinx-autoapidoc
__all__ = [
    "pixel_error",
    "temporal_norm",
    "pca_singleview_reprojection_error",
    "pca_multiview_reprojection_error",
    "compute_metrics_single",
]


def pixel_error(keypoints_true: np.ndarray, keypoints_pred: np.ndarray) -> np.ndarray:
    """Euclidean (L2) distance between true and predicted keypoints.

    The error is computed independently for every sample/keypoint pair as the
    norm of the (x, y) difference vector. No averaging over samples or
    keypoints is performed, so the result is a per-keypoint pixel distance
    rather than a root-mean-square statistic.

    Args:
        keypoints_true: shape (samples, n_keypoints, 2); array-like inputs such
            as nested lists are converted with ``np.asarray``
        keypoints_pred: shape (samples, n_keypoints, 2); array-like inputs such
            as nested lists are converted with ``np.asarray``

    Returns:
        shape (samples, n_keypoints); an entry is NaN whenever one of the
        coordinates it is computed from is NaN

    """
    keypoints_true = np.asarray(keypoints_true)
    keypoints_pred = np.asarray(keypoints_pred)
    # reduce over the coordinate axis only, keeping one value per sample/keypoint
    return np.linalg.norm(keypoints_true - keypoints_pred, axis=2)
def temporal_norm(keypoints_pred: np.ndarray | torch.Tensor) -> np.ndarray:
    """Norm of difference between keypoints on successive time bins.

    Args:
        keypoints_pred: shape (samples, n_keypoints * 2) or (samples, n_keypoints, 2);
            non-tensor inputs are converted to a float32 tensor

    Returns:
        shape (samples, n_keypoints); the first row is all NaN because no
        temporal norm exists for the very first frame

    """
    # NOTE(review): imported at call time rather than module level, presumably to
    # avoid a circular import between metrics and losses - confirm
    from lightning_pose.losses.losses import TemporalLoss

    if not isinstance(keypoints_pred, torch.Tensor):
        keypoints_pred = torch.tensor(keypoints_pred, dtype=torch.float32)

    # (samples, n_keypoints, 2) -> (samples, n_keypoints * 2)
    if keypoints_pred.ndim != 2:
        keypoints_pred = keypoints_pred.reshape(keypoints_pred.shape[0], -1)

    # compute loss with already-implemented class
    t_norm = TemporalLoss().compute_loss(keypoints_pred)

    # Tensor.numpy() only supports CPU tensors that do not require grad, so detach
    # and move to the CPU first; this keeps GPU / autograd inputs from raising
    t_norm = t_norm.detach().cpu().numpy()

    # prepend nan vector; no temporal norm for the very first frame
    first_frame = np.full((1, t_norm.shape[1]), np.nan)
    return np.vstack([first_frame, t_norm])
def pca_singleview_reprojection_error(
    keypoints_pred: np.ndarray | torch.Tensor,
    pca: KeypointPCA,
) -> np.ndarray:
    """Per-keypoint pixel distance between predictions and their singleview PCA reprojection.

    Args:
        keypoints_pred: shape (samples, n_keypoints, 2)
        pca: pca object that contains info about pca subspace

    Returns:
        shape (samples, n_keypoints); keypoints that are not listed in
        ``pca.columns_for_singleview_pca`` are filled with NaN

    """
    if isinstance(keypoints_pred, torch.Tensor):
        preds = keypoints_pred
    else:
        preds = torch.tensor(keypoints_pred, device=pca.device, dtype=torch.float32)

    n_samples, n_keypoints = preds.shape[0], preds.shape[1]
    used_cols = pca.columns_for_singleview_pca

    # the pca object expects a single trailing dim of size n_keypoints * 2
    flat_preds = preds.reshape(n_samples, -1)
    formatted = pca._format_data(data_arr=flat_preds)

    # project onto the pca subspace and back, then restore the trailing (x, y) dim
    reprojected = pca.reproject(data_arr=formatted)
    reprojected = reprojected.reshape(reprojected.shape[0], reprojected.shape[1] // 2, 2)

    # distance between each used keypoint and its reprojection
    used_errors = pixel_error(
        preds[:, used_cols, :].cpu().numpy(),
        reprojected.cpu().numpy(),
    )

    # scatter back into a full-size array; keypoints excluded from the pca stay NaN
    errors = np.full((n_samples, n_keypoints), np.nan)
    errors[:, used_cols] = used_errors
    return errors
def pca_multiview_reprojection_error(
    keypoints_pred: np.ndarray | torch.Tensor,
    pca: KeypointPCA,
) -> np.ndarray:
    """Per-keypoint pixel distance between predictions and their multiview PCA reprojection.

    Args:
        keypoints_pred: shape (samples, n_keypoints, 2)
        pca: pca object that contains info about pca subspace; its
            ``mirrored_column_matches`` attribute must not be None

    Returns:
        shape (samples, n_keypoints); columns that do not appear in
        ``pca.mirrored_column_matches`` are filled with NaN

    """
    if isinstance(keypoints_pred, torch.Tensor):
        preds = keypoints_pred
    else:
        preds = torch.tensor(keypoints_pred, device=pca.device, dtype=torch.float32)

    n_samples, n_keypoints = preds.shape[0], preds.shape[1]

    assert pca.mirrored_column_matches is not None
    view_columns = list(pca.mirrored_column_matches)

    # the pca object expects a single trailing dim of size n_keypoints * 2
    formatted = pca._format_data(data_arr=preds.reshape(preds.shape[0], -1))

    # project onto the pca subspace and back, then restore the trailing (x, y) dim
    reprojected = pca.reproject(data_arr=formatted)
    reprojected = reprojected.reshape(reprojected.shape[0], reprojected.shape[1] // 2, 2)

    # bring the original predictions into the same layout as the reprojection
    preds_formatted = pca._format_data(data_arr=preds.reshape(preds.shape[0], -1))
    preds_formatted = preds_formatted.reshape(
        preds_formatted.shape[0],
        preds_formatted.shape[1] // 2,
        2,
    )

    flat_errors = pixel_error(preds_formatted.cpu().numpy(), reprojected.cpu().numpy())

    # batch X num_used_keypoints X num_views
    n_used_keypoints = len(view_columns[0])
    n_views = len(view_columns)
    view_errors = flat_errors.reshape(-1, n_used_keypoints, n_views)

    # scatter back into a full-size array, one view at a time
    errors = np.full((n_samples, n_keypoints), np.nan)
    for view_idx, cols in enumerate(view_columns):
        errors[:, cols] = view_errors[:, :, view_idx]  # just the columns belonging to this view
    return errors
def _save_metric_df(
    values: np.ndarray,
    index: pd.Index,
    keypoint_names,
    set_labels: np.ndarray | None,
    preds_file_path: Path,
    suffix: str,
) -> pd.DataFrame:
    """Wrap per-keypoint metric values in a dataframe and save it next to the predictions file."""
    metric_df = pd.DataFrame(
        values,
        index=pd.Index(index),
        columns=pd.Index(keypoint_names),
    )
    if set_labels is not None:
        metric_df['set'] = set_labels
    save_file = preds_file_path.with_name(preds_file_path.stem + suffix)
    metric_df.to_csv(save_file)
    return metric_df


def compute_metrics_single(
    cfg: DictConfig | ListConfig,
    labels_file: str | Path | None,
    preds_file: str | Path,
    data_module: BaseDataModule | UnlabeledDataModule | None = None,
) -> ComputeMetricsSingleResult:
    """Compute various metrics on a predictions csv file from a single view.

    Which metrics are computed depends on the predictions file and the config:

    * if the last column of the predictions file belongs to a 'set' column group the file is
      treated as labeled-frame predictions and the pixel error is computed; otherwise it is
      treated as video predictions and the temporal norm is computed
    * the singleview / multiview PCA reprojection errors are computed when ``data_module`` is
      given, the corresponding ``cfg.data`` entry (``columns_for_singleview_pca`` /
      ``mirrored_column_matches``) is non-empty, and the dataset is not a
      ``MultiviewHeatmapDataset``

    Every computed metric is written to a csv file next to ``preds_file``, named
    ``<preds stem>_<metric>.csv``.

    Args:
        cfg: hydra config
        labels_file: path to the labels csv file; required for pixel error computation
        preds_file: path to the predictions csv file
        data_module: required for PCA metric computation

    Returns:
        ComputeMetricsSingleResult containing dataframes for each computed metric

    Raises:
        AssertionError: if pixel error is requested without ``labels_file``, or if the labels
            and predictions indices differ
        ValueError: re-raised from the singleview PCA computation unless its message contains
            'cannot fit PCA', in which case that metric is skipped

    """
    pred_df = pd.read_csv(preds_file, header=[0, 1, 2], index_col=0)
    keypoint_names = get_keypoint_names(cfg, csv_file=str(preds_file), header_rows=[0, 1, 2])
    xyl_mask = pred_df.columns.get_level_values('coords').isin(['x', 'y', 'likelihood'])
    tmp = pred_df.loc[:, xyl_mask].to_numpy().reshape(pred_df.shape[0], -1, 3)
    index = pred_df.index
    keypoints_pred = tmp[:, :, :2]  # shape (samples, n_keypoints, 2)

    # a trailing 'set' column marks predictions on labeled frames rather than on a video
    if pred_df.keys()[-1][0] == 'set':
        is_video = False
        set_labels = pred_df.iloc[:, -1].to_numpy()
    else:
        is_video = True
        set_labels = None

    if is_video:
        metrics_to_compute = ['temporal']
    else:
        assert labels_file is not None, '"pixel_error" metric requires labels_file'
        metrics_to_compute = ['pixel_error']

    if (
        data_module is not None
        and cfg.data.get('columns_for_singleview_pca', None) is not None
        and len(cfg.data.columns_for_singleview_pca) != 0
        and not isinstance(data_module.dataset, MultiviewHeatmapDataset)
    ):
        metrics_to_compute.append('pca_singleview')
    if (
        data_module is not None
        and cfg.data.get('mirrored_column_matches', None) is not None
        and len(cfg.data.mirrored_column_matches) != 0
        and not isinstance(data_module.dataset, MultiviewHeatmapDataset)
    ):
        metrics_to_compute.append('pca_multiview')

    result = ComputeMetricsSingleResult()
    preds_file_path = Path(preds_file)

    if 'pixel_error' in metrics_to_compute:
        labels_df = pd.read_csv(labels_file, header=[0, 1, 2], index_col=0)
        labels_df = fix_empty_first_row(labels_df)
        assert labels_df.index.equals(index), 'labels and predictions must share the same index'
        keypoints_true = labels_df.to_numpy().reshape(labels_df.shape[0], -1, 2)
        error_per_keypoint = pixel_error(keypoints_true, keypoints_pred)
        result.pixel_error_df = _save_metric_df(
            error_per_keypoint, index, keypoint_names, set_labels,
            preds_file_path, '_pixel_error.csv',
        )

    if 'temporal' in metrics_to_compute:
        temporal_norm_per_keypoint = temporal_norm(keypoints_pred)
        result.temporal_norm_df = _save_metric_df(
            temporal_norm_per_keypoint, index, keypoint_names, set_labels,
            preds_file_path, '_temporal_norm.csv',
        )

    if 'pca_singleview' in metrics_to_compute:
        try:
            assert data_module is not None
            pca = KeypointPCA(
                loss_type='pca_singleview',
                data_module=data_module,
                components_to_keep=cfg.losses.pca_singleview.components_to_keep,
                empirical_epsilon_percentile=cfg.losses.pca_singleview.get(
                    'empirical_epsilon_percentile', 1.0,
                ),
                columns_for_singleview_pca=cfg.data.columns_for_singleview_pca,
                centering_method=cfg.losses.pca_singleview.get('centering_method', None),
            )
            pca()
            pcasv_error_per_keypoint = pca_singleview_reprojection_error(keypoints_pred, pca)
        except ValueError as e:
            # best effort: skip this metric when the PCA cannot be fit, surface anything else
            if 'cannot fit PCA' not in str(e):
                raise
        else:
            result.pca_sv_df = _save_metric_df(
                pcasv_error_per_keypoint, index, keypoint_names, set_labels,
                preds_file_path, '_pca_singleview_error.csv',
            )

    if 'pca_multiview' in metrics_to_compute:
        assert data_module is not None
        # the multiview subspace must be built from the multiview hyperparameters; the
        # singleview section is only used as a fallback when no multiview section exists
        pcamv_cfg = cfg.losses.get('pca_multiview', None)
        if pcamv_cfg is None:
            pcamv_cfg = cfg.losses.pca_singleview
        pca = KeypointPCA(
            loss_type='pca_multiview',
            data_module=data_module,
            components_to_keep=pcamv_cfg.components_to_keep,
            empirical_epsilon_percentile=pcamv_cfg.get('empirical_epsilon_percentile', 1.0),
            mirrored_column_matches=cfg.data.mirrored_column_matches,
        )
        pca()
        pcamv_error_per_keypoint = pca_multiview_reprojection_error(keypoints_pred, pca)
        result.pca_mv_df = _save_metric_df(
            pcamv_error_per_keypoint, index, keypoint_names, set_labels,
            preds_file_path, '_pca_multiview_error.csv',
        )

    return result