Source code for lightning_pose.data.datatypes

"""Classes to streamline data typechecking."""
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, TypedDict

import numpy as np
import pandas as pd
import torch
from jaxtyping import Float, Int

if TYPE_CHECKING:
    from nvidia.dali.plugin.pytorch import DALIGenericIterator

# Explicit public API. Defining ``__all__`` keeps the names this module merely
# imports (np, pd, torch, Float, Int, ...) out of the sphinx-autoapi docs, and
# controls what ``from lightning_pose.data.datatypes import *`` exports.
__all__ = [
    "PredictionResult",
    "MultiviewPredictionResult",
    "ComputeMetricsSingleResult",
    "BaseLabeledExampleDict",
    "HeatmapLabeledExampleDict",
    "MultiviewLabeledExampleDict",
    "MultiviewHeatmapLabeledExampleDict",
    "BaseLabeledBatchDict",
    "HeatmapLabeledBatchDict",
    "MultiviewLabeledBatchDict",
    "MultiviewHeatmapLabeledBatchDict",
    "UnlabeledBatchDict",
    "MultiviewUnlabeledBatchDict",
    "SemiSupervisedBatchDict",
    "SemiSupervisedHeatmapBatchDict",
    "SemiSupervisedDataLoaderDict",
]


@dataclass
class PredictionResult:
    """Single-view keypoint predictions, optionally paired with evaluation metrics.

    Attributes:
        predictions: predicted keypoints. Columns form a three-level index: the
            second level holds keypoint names and the third level holds the
            coordinate name (``x``, ``y`` or ``likelihood``).
        metrics: per-keypoint metric tables, or ``None`` if no metrics were computed.
    """

    # NOTE: a non-empty class docstring stops @dataclass from synthesizing one
    # from the __init__ signature (which Sphinx would otherwise render).
    predictions: pd.DataFrame
    metrics: ComputeMetricsSingleResult | None = None

    def to_dict(self) -> dict[str, Any]:
        """Flatten predictions and metrics into a dict of named numpy arrays.

        Every array has shape ``(n_frames, n_keypoints)``, and all arrays share the
        row order of ``predictions``. A metric entry is ``None`` when that metric
        was not computed.

        Returns:
            dict with keys:
                - ``keypoint_names``: list of keypoint name strings.
                - ``index``: list of frame identifiers (file paths or integer indices).
                - ``x``: float array of predicted x coordinates.
                - ``y``: float array of predicted y coordinates.
                - ``confidence``: float array of per-keypoint likelihood in [0, 1].
                - ``pixel_error``: float array or None.
                - ``temporal_norm``: float array or None.
                - ``pca_singleview_error``: float array or None.
                - ``pca_multiview_error``: float array or None.
        """
        preds = self.predictions

        def _coords(coord_name: str) -> np.ndarray:
            # the third column level carries the coordinate name
            return preds.xs(coord_name, level=2, axis=1).to_numpy()

        def _metric(table: pd.DataFrame | None) -> np.ndarray | None:
            if table is None:
                return None
            # every column except 'set' is kept as a per-keypoint metric column
            keep = [name for name in table.columns if name != 'set']
            return table[keep].to_numpy()

        result: dict[str, Any] = {
            'keypoint_names': list(preds.columns.get_level_values(1).unique()),
            'index': list(preds.index),
            'x': _coords('x'),
            'y': _coords('y'),
            'confidence': _coords('likelihood'),
        }

        found = self.metrics
        if found:
            tables = (found.pixel_error_df, found.temporal_norm_df, found.pca_sv_df, found.pca_mv_df)
        else:
            tables = (None, None, None, None)

        metric_keys = ('pixel_error', 'temporal_norm', 'pca_singleview_error', 'pca_multiview_error')
        for key, table in zip(metric_keys, tables):
            result[key] = _metric(table)
        return result
@dataclass
class MultiviewPredictionResult:
    """Keypoint predictions for several views, optionally paired with per-view metrics.

    Attributes:
        predictions: maps each view name to that view's predictions DataFrame.
        metrics: maps view names to their metric results, or ``None`` if no
            metrics were computed.
    """

    # NOTE: a non-empty class docstring stops @dataclass from synthesizing one
    # from the __init__ signature (which Sphinx would otherwise render).
    predictions: dict[str, pd.DataFrame]
    metrics: dict[str, ComputeMetricsSingleResult] | None = None

    def to_dict(self) -> dict[str, dict[str, Any]]:
        """Flatten each view's predictions and metrics into named numpy arrays.

        Builds a :class:`PredictionResult` for every view and collects its
        :meth:`PredictionResult.to_dict` output. A view with no entry in
        ``metrics`` gets ``None`` metrics.

        Returns:
            dict keyed by view name; each value is that view's ``to_dict()`` output.
        """
        per_view_metrics = self.metrics
        flattened: dict[str, dict[str, Any]] = {}
        for view_name, view_df in self.predictions.items():
            view_metrics = per_view_metrics.get(view_name) if per_view_metrics else None
            single = PredictionResult(predictions=view_df, metrics=view_metrics)
            flattened[view_name] = single.to_dict()
        return flattened
@dataclass
class ComputeMetricsSingleResult:
    """Metric tables computed for a single view.

    Each attribute is ``None`` when the corresponding metric was not computed.

    Attributes:
        pixel_error_df: pixel-error table.
        temporal_norm_df: temporal-norm table.
        pca_sv_df: single-view PCA error table.
        pca_mv_df: multi-view PCA error table.
    """

    # NOTE: a non-empty class docstring stops @dataclass from synthesizing one
    # from the __init__ signature (which Sphinx would otherwise render).
    pixel_error_df: pd.DataFrame | None = None
    temporal_norm_df: pd.DataFrame | None = None
    pca_sv_df: pd.DataFrame | None = None  # single-view PCA
    pca_mv_df: pd.DataFrame | None = None  # multi-view PCA
class BaseLabeledExampleDict(TypedDict):
    """Single labeled example, as returned by ``BaseTrackingDataset.__getitem__()``.

    ``images`` is either one RGB image or a stack of ``frames`` RGB images.
    """

    images: (
        Float[torch.Tensor, "RGB image_height image_width"]
        | Float[torch.Tensor, "frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "num_targets"]
    bbox: Float[torch.Tensor, "xyhw"]
    idxs: int
    visibility: Int[torch.Tensor, "num_keypoints"]
class HeatmapLabeledExampleDict(BaseLabeledExampleDict):
    """Single labeled example from ``HeatmapTrackingDataset.__getitem__()``.

    Adds target ``heatmaps`` (one map per keypoint) to the base example fields.
    """

    heatmaps: Float[
        torch.Tensor, "num_keypoints heatmap_height heatmap_width"
    ]
class MultiviewLabeledExampleDict(TypedDict):
    """Single labeled example, as returned by ``MultiviewDataset.__getitem__()``.

    Image and bbox tensors carry a leading ``num_views`` axis.
    """

    images: (
        Float[torch.Tensor, "num_views RGB image_height image_width"]
        | Float[torch.Tensor, "num_views frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "num_targets"]
    bbox: Float[torch.Tensor, "num_views xyhw"]
    idxs: int
    num_views: int
    concat_order: list[str]
    view_names: list[str]

    # Camera-calibration fields. The original note says these exist when
    # calibration info is available; the "1"-shaped alternative is presumably a
    # placeholder otherwise (TODO confirm against MultiviewDataset).
    keypoints_3d: (
        Float[torch.Tensor, "num_keypoints 3"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
    intrinsic_matrix: Float[torch.Tensor, "num_views 3 3"] | Float[torch.Tensor, "1"] | torch.Tensor
    extrinsic_matrix: Float[torch.Tensor, "num_views 3 4"] | Float[torch.Tensor, "1"] | torch.Tensor
    # distortion parameter conventions are described at
    # https://kornia.readthedocs.io/en/latest/geometry.calibration.html
    distortions: (
        Float[torch.Tensor, "num_views num_distortion_params"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
class MultiviewHeatmapLabeledExampleDict(MultiviewLabeledExampleDict):
    """Single labeled example from ``MultiviewHeatmapDataset.__getitem__()``.

    Extends the multiview example fields with target ``heatmaps``.
    """

    heatmaps: Float[
        torch.Tensor, "num_keypoints heatmap_height heatmap_width"
    ]
class BaseLabeledBatchDict(TypedDict):
    """Batch type for base labeled data.

    Keys mirror :class:`BaseLabeledExampleDict`, with a leading ``batch``
    dimension on every tensor field.
    """

    images: (
        Float[torch.Tensor, "batch RGB image_height image_width"]
        | Float[torch.Tensor, "batch frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "batch num_targets"]
    bbox: Float[torch.Tensor, "batch xyhw"]
    idxs: Int[torch.Tensor, "batch"]
    # BaseLabeledExampleDict declares ``visibility``, so the batched form carries
    # it too (previously missing here, which left the two types inconsistent)
    visibility: Int[torch.Tensor, "batch num_keypoints"]
class HeatmapLabeledBatchDict(BaseLabeledBatchDict):
    """Batch type for heatmap labeled data: base batch fields plus ``heatmaps``."""

    heatmaps: Float[
        torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"
    ]
class MultiviewLabeledBatchDict(TypedDict):
    """Batch type for multiview labeled data.

    Tensor fields carry a leading ``batch`` dimension.
    """

    images: (
        Float[torch.Tensor, "batch num_views RGB image_height image_width"]
        | Float[torch.Tensor, "batch num_views frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "batch num_targets"]
    bbox: Float[torch.Tensor, "batch num_views_x_xyhw"]
    idxs: Int[torch.Tensor, "batch"]
    num_views: Int[torch.Tensor, "batch"]
    concat_order: list  # [Tuple[str]]
    view_names: list  # [Tuple[str]]

    # Camera-calibration fields. The original note says these exist when
    # calibration info is available; the "batch 1" alternative is presumably a
    # placeholder otherwise (TODO confirm).
    keypoints_3d: (
        Float[torch.Tensor, "batch num_keypoints 3"]
        | Float[torch.Tensor, "batch 1"]
    )
    intrinsic_matrix: (
        Float[torch.Tensor, "batch num_views 3 3"]
        | Float[torch.Tensor, "batch 1"]
    )
    extrinsic_matrix: (
        Float[torch.Tensor, "batch num_views 3 4"]
        | Float[torch.Tensor, "batch 1"]
    )
    distortions: Float[torch.Tensor, "batch num_views num_distortion_params"] | Float[torch.Tensor, "batch 1"]
class MultiviewHeatmapLabeledBatchDict(MultiviewLabeledBatchDict):
    """Batch type for multiview heatmap labeled data: multiview fields plus ``heatmaps``."""

    heatmaps: Float[
        torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"
    ]
class UnlabeledBatchDict(TypedDict):
    """Batch type for single-view unlabeled data (a sequence of frames)."""

    frames: Float[torch.Tensor, "seq_len RGB image_height image_width"]

    # Supported ``transforms`` shapes:
    #   (seq_len, 2, 3) -> one transform per element of the sequence
    #   (2, 3)          -> one transform shared by all returned frames/keypoints
    #   (seq_len, 1)    -> no transforms
    #   (1,)            -> no transforms
    # The bare ``torch.Tensor`` member is needed to avoid an error about
    # ``torch.AnnotatedAlias`` whose root cause is not yet understood.
    transforms: Float[torch.Tensor, "seq_len h w"] | Float[torch.Tensor, "h w"] | Float[torch.Tensor, "seq_len 1"] | Float[torch.Tensor, "1"] | torch.Tensor
    bbox: Float[torch.Tensor, "seq_len xyhw"]
    # always False for this type; serves as a type tag because isinstance()
    # cannot be used with TypedDicts
    is_multiview: bool
class MultiviewUnlabeledBatchDict(TypedDict):
    """Batch type for multiview unlabeled data.

    ``frames`` stacks every view along a ``num_views`` axis after ``seq_len``.
    """

    frames: Float[torch.Tensor, "seq_len num_views RGB image_height image_width"]
    # the bare ``torch.Tensor`` member mirrors UnlabeledBatchDict.transforms
    transforms: Float[torch.Tensor, "num_views h w"] | Float[torch.Tensor, "num_views 1 1"] | torch.Tensor
    bbox: Float[torch.Tensor, "seq_len num_views_x_xyhw"]
    # always True for this type; serves as a type tag because isinstance()
    # cannot be used with TypedDicts
    is_multiview: bool
class SemiSupervisedBatchDict(TypedDict):
    """Combined batch of base labeled data and unlabeled data.

    Either half may be the single-view or the multiview variant.
    """

    labeled: (
        BaseLabeledBatchDict
        | MultiviewLabeledBatchDict
    )
    unlabeled: (
        UnlabeledBatchDict
        | MultiviewUnlabeledBatchDict
    )
class SemiSupervisedHeatmapBatchDict(TypedDict):
    """Combined batch of heatmap labeled data and unlabeled data.

    Either half may be the single-view or the multiview variant.
    """

    labeled: (
        HeatmapLabeledBatchDict
        | MultiviewHeatmapLabeledBatchDict
    )
    unlabeled: (
        UnlabeledBatchDict
        | MultiviewUnlabeledBatchDict
    )
class SemiSupervisedDataLoaderDict(TypedDict):
    """Loaders returned by ``train/val/test_dataloader()`` on semi-supervised models.

    Labeled data comes from a PyTorch ``DataLoader``. Unlabeled data comes from a
    DALI iterator.
    """

    labeled: torch.utils.data.DataLoader
    unlabeled: DALIGenericIterator  # imported only under TYPE_CHECKING