Source code for lightning_pose.metrics

import numpy as np
import torch
from typeguard import typechecked

from lightning_pose.utils.pca import KeypointPCA

# restrict the exported names so that sphinx-autoapidoc ignores the imports above
__all__ = [
    "pixel_error",
    "temporal_norm",
    "pca_singleview_reprojection_error",
    "pca_multiview_reprojection_error",
]


@typechecked
def pixel_error(keypoints_true: np.ndarray, keypoints_pred: np.ndarray) -> np.ndarray:
    """Euclidean (L2) pixel distance between true and predicted keypoints.

    The distance is computed separately for every keypoint of every sample. Nothing is
    averaged over samples or keypoints, so the result is a per-keypoint distance and not a
    root mean square error.

    Args:
        keypoints_true: shape (samples, n_keypoints, 2)
        keypoints_pred: shape (samples, n_keypoints, 2)

    Returns:
        shape (samples, n_keypoints)

    """
    # reduce over the trailing (x, y) coordinate axis; for the documented 3-d inputs this is
    # the same as axis=2
    return np.linalg.norm(keypoints_true - keypoints_pred, axis=-1)
@typechecked
def temporal_norm(keypoints_pred: np.ndarray | torch.Tensor) -> np.ndarray:
    """Norm of difference between keypoints on successive time bins.

    Args:
        keypoints_pred: shape (samples, n_keypoints * 2) or (samples, n_keypoints, 2)

    Returns:
        shape (samples, n_keypoints); the first row is all NaN because there is no temporal
        norm for the very first frame

    """
    # NOTE(review): imported inside the function rather than at module level, presumably to
    # avoid a circular import with the losses module - confirm before moving
    from lightning_pose.losses.losses import TemporalLoss

    t_loss = TemporalLoss()

    if not isinstance(keypoints_pred, torch.Tensor):
        keypoints_pred = torch.tensor(keypoints_pred, dtype=torch.float32)

    # (samples, n_keypoints, 2) -> (samples, n_keypoints * 2)
    if len(keypoints_pred.shape) != 2:
        keypoints_pred = keypoints_pred.reshape(keypoints_pred.shape[0], -1)

    # compute loss with already-implemented class
    t_norm = t_loss.compute_loss(keypoints_pred)

    # Tensor.numpy() only works for cpu tensors that do not require grad; detach and move to
    # cpu first so that gpu / autograd-tracked predictions are also supported
    t_norm = t_norm.detach().cpu().numpy()

    # prepend nan vector; no temporal norm for the very first frame
    nan_row = np.full((1, t_norm.shape[1]), np.nan)
    return np.vstack([nan_row, t_norm])
@typechecked
def pca_singleview_reprojection_error(
    keypoints_pred: np.ndarray | torch.Tensor,
    pca: KeypointPCA,
) -> np.ndarray:
    """Pixel distance between predicted keypoints and their singleview PCA reprojection.

    Args:
        keypoints_pred: shape (samples, n_keypoints, 2)
        pca: pca object that contains info about pca subspace

    Returns:
        shape (samples, n_keypoints); keypoints that are not listed in
        ``pca.columns_for_singleview_pca`` are set to nan

    """
    preds = keypoints_pred
    if not isinstance(preds, torch.Tensor):
        preds = torch.tensor(preds, device=pca.device, dtype=torch.float32)
    input_shape = preds.shape
    used_cols = pca.columns_for_singleview_pca

    # flatten to (samples, n_keypoints * 2) before handing the data to the pca object
    flat_preds = preds.reshape(preds.shape[0], -1)
    formatted = pca._format_data(data_arr=flat_preds)

    # project onto the pca subspace and back, then restore the trailing (x, y) axis
    reprojected = pca.reproject(data_arr=formatted)
    reprojected = reprojected.reshape(reprojected.shape[0], reprojected.shape[1] // 2, 2)

    # distance is only defined for the keypoints that take part in the pca
    errors_used = pixel_error(
        preds[:, used_cols, :].cpu().numpy(),
        reprojected.cpu().numpy(),
    )

    # scatter back into an array covering every keypoint; unused keypoints stay nan
    errors_full = np.full((input_shape[0], input_shape[1]), np.nan)
    errors_full[:, used_cols] = errors_used
    return errors_full
@typechecked
def pca_multiview_reprojection_error(
    keypoints_pred: np.ndarray | torch.Tensor,
    pca: KeypointPCA,
) -> np.ndarray:
    """Pixel distance between predicted keypoints and their multiview PCA reprojection.

    Args:
        keypoints_pred: shape (samples, n_keypoints, 2)
        pca: pca object that contains info about pca subspace

    Returns:
        shape (samples, n_keypoints); keypoints that are not listed in
        ``pca.mirrored_column_matches`` are set to nan

    """
    preds = keypoints_pred
    if not isinstance(preds, torch.Tensor):
        preds = torch.tensor(preds, device=pca.device, dtype=torch.float32)
    input_shape = preds.shape
    view_columns = list(pca.mirrored_column_matches)

    # flatten to (samples, n_keypoints * 2) before handing the data to the pca object
    formatted = pca._format_data(data_arr=preds.reshape(preds.shape[0], -1))

    # project onto the pca subspace and back, then restore the trailing (x, y) axis
    reprojected = pca.reproject(data_arr=formatted)
    reprojected = reprojected.reshape(reprojected.shape[0], reprojected.shape[1] // 2, 2)

    # put the original keypoints in the same format as the reprojection
    # NOTE(review): this repeats the _format_data call above on the same input; the two could
    # share one result if _format_data is pure and reproject does not modify its argument -
    # confirm before merging them
    preds_formatted = pca._format_data(data_arr=preds.reshape(preds.shape[0], -1))
    preds_formatted = preds_formatted.reshape(
        preds_formatted.shape[0],
        preds_formatted.shape[1] // 2,
        2,
    )

    errors = pixel_error(preds_formatted.cpu().numpy(), reprojected.cpu().numpy())

    # batch X num_used_keypoints X num_views
    n_used_keypoints = len(view_columns[0])
    n_views = len(view_columns)
    errors = errors.reshape(-1, n_used_keypoints, n_views)

    # scatter back into an array covering every keypoint; unused keypoints stay nan
    errors_full = np.full((input_shape[0], input_shape[1]), np.nan)
    for view_idx, cols in enumerate(view_columns):
        # just the columns belonging to this view
        errors_full[:, cols] = errors[:, :, view_idx]
    return errors_full