# Source code for lightning_pose.losses.helpers

"""Helper functions for losses."""

from typing import Dict, Literal, Union

import numpy as np
import torch

# declare the public names so that sphinx-autoapidoc ignores the imports
__all__ = [
    "EmpiricalEpsilon",
    "convert_dict_values_to_tensors",
]


class EmpiricalEpsilon:
    """Find percentile value of a given loss tensor."""

    def __init__(self, percentile: float) -> None:
        # percentile to compute; passed to np.nanpercentile, which expects q in [0, 100]
        self.percentile = percentile

    def __call__(self, loss: Union[torch.Tensor, np.ndarray]) -> float:
        """Compute the percentile of some loss, to use as epsilon for an epsilon-insensitive loss.

        NaN entries are ignored (np.nanpercentile is used).

        Args:
            loss: tensor/array with scalar loss per term (e.g., loss per image, or loss per
                keypoint, etc.); any shape, it is flattened before the percentile is taken.
                Torch tensors (including subclasses such as nn.Parameter, tensors that require
                grad, or tensors on GPU) are detached and moved to CPU first.

        Returns:
            the percentile of the loss which we use as epsilon (a Python float when
            `self.percentile` is a scalar)

        """
        if isinstance(loss, torch.Tensor):
            # isinstance (not `type(...) is`) so Tensor subclasses are also converted;
            # detach() allows grad-requiring tensors to be converted, cpu() handles GPU tensors.
            # No clone needed: np.nanpercentile does not modify its input by default.
            values = loss.detach().cpu().numpy()
        else:
            values = np.asarray(loss)
        result = np.nanpercentile(values.ravel(), self.percentile, axis=0)
        # np.nanpercentile returns a numpy scalar (e.g. np.float32 for float32 input, which is
        # not a `float`); cast scalars so the declared return type holds, but keep array
        # results untouched in case a sequence of percentiles was configured
        return float(result) if np.ndim(result) == 0 else result
# @typechecked
def convert_dict_values_to_tensors(
    param_dict: Dict[str, Union[np.ndarray, float]],
    device: Union[Literal["cpu", "cuda"], torch.device],
) -> Dict[str, torch.Tensor]:
    """Convert every value of a parameter dict to a float32 tensor on `device`.

    Note: `param_dict` is modified in place, and the same dict object is returned.

    Args:
        param_dict: mapping from parameter name to a float, numpy array, or torch tensor
        device: device on which to place the resulting tensors

    Returns:
        `param_dict`, with each value replaced by a `torch.float` (float32) tensor

    """
    # assigning to existing keys does not change the dict size, so this is safe during iteration
    for key, val in param_dict.items():
        if isinstance(val, torch.Tensor):
            # torch.tensor(<tensor>) emits a UserWarning; detach().clone() is the recommended
            # equivalent copy (result does not require grad, like torch.tensor's output)
            param_dict[key] = val.detach().clone().to(dtype=torch.float, device=device)
        else:
            # floats and numpy arrays are copied into a new tensor
            param_dict[key] = torch.tensor(val, dtype=torch.float, device=device)
    return param_dict