Source code for lightning_pose.data.datatypes

"""Classes to streamline data typechecking."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, TypedDict

import numpy as np
import pandas as pd
import torch
from jaxtyping import Float, Int
from nvidia.dali.plugin.pytorch import DALIGenericIterator

# Explicit public API of this module. Declaring ``__all__`` keeps sphinx-autoapi
# from documenting names that are merely imported here (e.g. ``dataclass``,
# ``TypedDict``, ``DALIGenericIterator``), and controls what ``import *`` exports.
__all__ = [
    "PredictionResult",
    "MultiviewPredictionResult",
    "ComputeMetricsSingleResult",
    "BaseLabeledExampleDict",
    "HeatmapLabeledExampleDict",
    "MultiviewLabeledExampleDict",
    "MultiviewHeatmapLabeledExampleDict",
    "BaseLabeledBatchDict",
    "HeatmapLabeledBatchDict",
    "MultiviewLabeledBatchDict",
    "MultiviewHeatmapLabeledBatchDict",
    "UnlabeledBatchDict",
    "MultiviewUnlabeledBatchDict",
    "SemiSupervisedBatchDict",
    "SemiSupervisedHeatmapBatchDict",
    "SemiSupervisedDataLoaderDict",
]


@dataclass
class PredictionResult:
    """Predictions for a single view, optionally paired with evaluation metrics.

    Attributes:
        predictions: DataFrame whose columns are a MultiIndex with keypoint names
            on level 1 and coordinate labels (``x``, ``y``, ``likelihood``) on
            level 2.
        metrics: metric DataFrames for the same frames, or ``None`` when no
            metrics were computed.
    """

    predictions: pd.DataFrame
    metrics: ComputeMetricsSingleResult | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return predictions and metrics as a flat dict of named numpy arrays.

        All arrays have shape ``(n_frames, n_keypoints)`` and share the same row
        order. Metric arrays are ``None`` when the metric was not computed.

        Returns:
            dict with keys:
                - ``keypoint_names``: list of keypoint name strings.
                - ``index``: list of frame identifiers (file paths or integer indices).
                - ``x``: float array of predicted x coordinates.
                - ``y``: float array of predicted y coordinates.
                - ``confidence``: float array of per-keypoint likelihood in [0, 1].
                - ``pixel_error``: float array or None.
                - ``temporal_norm``: float array or None.
                - ``pca_singleview_error``: float array or None.
                - ``pca_multiview_error``: float array or None.
        """
        frame = self.predictions

        def _coord(label: str) -> np.ndarray:
            # One column per keypoint for the requested coordinate label.
            return frame.xs(label, level=2, axis=1).to_numpy()

        def _metric(attr_name: str) -> np.ndarray | None:
            # No metrics object at all -> every metric is reported as None.
            if not self.metrics:
                return None
            metric_df = getattr(self.metrics, attr_name)
            if metric_df is None:
                return None
            # Keep every column except the 'set' column.
            value_cols = [col for col in metric_df.columns if col != 'set']
            return metric_df[value_cols].to_numpy()

        out: dict[str, Any] = {
            'keypoint_names': list(frame.columns.get_level_values(1).unique()),
            'index': list(frame.index),
            'x': _coord('x'),
            'y': _coord('y'),
            'confidence': _coord('likelihood'),
        }
        # Output key -> attribute name on the metrics object.
        for out_key, attr_name in (
            ('pixel_error', 'pixel_error_df'),
            ('temporal_norm', 'temporal_norm_df'),
            ('pca_singleview_error', 'pca_sv_df'),
            ('pca_multiview_error', 'pca_mv_df'),
        ):
            out[out_key] = _metric(attr_name)
        return out
@dataclass
class MultiviewPredictionResult:
    """Predictions for every view of a multiview model, keyed by view name.

    Attributes:
        predictions: mapping from view name to that view's predictions DataFrame.
        metrics: mapping from view name to that view's metrics, or ``None`` when
            no metrics were computed. Views missing from the mapping get no metrics.
    """

    predictions: dict[str, pd.DataFrame]
    metrics: dict[str, ComputeMetricsSingleResult] | None = None

    def to_dict(self) -> dict[str, dict[str, Any]]:
        """Return predictions and metrics for each view as a flat dict of named numpy arrays.

        Wraps :meth:`PredictionResult.to_dict` for each view.

        Returns:
            dict keyed by view name, where each value is the ``to_dict()`` output
            for that view.
        """
        # An absent (or empty) metrics mapping means "no metrics for any view".
        metrics_by_view = self.metrics or {}
        per_view: dict[str, dict[str, Any]] = {}
        for view_name, view_df in self.predictions.items():
            view_result = PredictionResult(
                predictions=view_df,
                metrics=metrics_by_view.get(view_name),
            )
            per_view[view_name] = view_result.to_dict()
        return per_view
@dataclass
class ComputeMetricsSingleResult:
    """Per-metric DataFrames computed for a single view.

    Every attribute is ``None`` when the corresponding metric was not computed.
    """

    # pixel error of predictions
    pixel_error_df: pd.DataFrame | None = None
    # temporal norm of predictions
    temporal_norm_df: pd.DataFrame | None = None
    # single-view PCA error
    pca_sv_df: pd.DataFrame | None = None
    # multiview PCA error
    pca_mv_df: pd.DataFrame | None = None
class BaseLabeledExampleDict(TypedDict):
    """Single labeled example, as returned by ``BaseTrackingDataset.__getitem__()``."""

    # either one RGB image or a stack of ``frames`` RGB images
    images: (
        Float[torch.Tensor, "RGB image_height image_width"]
        | Float[torch.Tensor, "frames RGB image_height image_width"]
    )
    # 1-D tensor of keypoint targets
    keypoints: Float[torch.Tensor, "num_targets"]
    # bounding box; its single dimension is labelled ``xyhw``
    bbox: Float[torch.Tensor, "xyhw"]
    # integer index of this example
    idxs: int
class HeatmapLabeledExampleDict(BaseLabeledExampleDict):
    """Single labeled example, as returned by ``HeatmapTrackingDataset.__getitem__()``.

    Extends the base example with target heatmaps, one per keypoint.
    """

    heatmaps: Float[torch.Tensor, "num_keypoints heatmap_height heatmap_width"]
class MultiviewLabeledExampleDict(TypedDict):
    """Single labeled example, as returned by ``MultiviewDataset.__getitem__()``.

    The camera-calibration keys (``keypoints_3d``, ``intrinsic_matrix``,
    ``extrinsic_matrix``, ``distortions``) are declared as required keys of this
    (total) TypedDict. Their annotations also admit a shape-``(1,)`` tensor,
    presumably used as a placeholder when no camera calibration info is
    available (TODO: confirm against ``MultiviewDataset``).
    """

    # per-view images, optionally with a ``frames`` dimension for stacked frames
    images: (
        Float[torch.Tensor, "num_views RGB image_height image_width"]
        | Float[torch.Tensor, "num_views frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "num_targets"]
    # one ``xyhw`` box per view
    bbox: Float[torch.Tensor, "num_views xyhw"]
    idxs: int
    num_views: int
    concat_order: list[str]
    view_names: list[str]

    # --- camera-calibration entries (real values only when calibration is available) ---
    keypoints_3d: Float[torch.Tensor, "num_keypoints 3"] | Float[torch.Tensor, "1"] | torch.Tensor
    intrinsic_matrix: (
        Float[torch.Tensor, "num_views 3 3"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
    extrinsic_matrix: (
        Float[torch.Tensor, "num_views 3 4"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
    # for the meaning/ordering of distortion parameters see
    # https://kornia.readthedocs.io/en/latest/geometry.calibration.html
    distortions: (
        Float[torch.Tensor, "num_views num_distortion_params"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
class MultiviewHeatmapLabeledExampleDict(MultiviewLabeledExampleDict):
    """Single labeled example, as returned by ``MultiviewHeatmapDataset.__getitem__()``.

    Extends the multiview example with target heatmaps, one per keypoint.
    """

    heatmaps: Float[torch.Tensor, "num_keypoints heatmap_height heatmap_width"]
class BaseLabeledBatchDict(TypedDict):
    """Collated batch of base labeled examples (leading ``batch`` dimension)."""

    images: (
        Float[torch.Tensor, "batch RGB image_height image_width"]
        | Float[torch.Tensor, "batch frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "batch num_targets"]
    bbox: Float[torch.Tensor, "batch xyhw"]
    # integer indices, one per example in the batch
    idxs: Int[torch.Tensor, "batch"]
class HeatmapLabeledBatchDict(BaseLabeledBatchDict):
    """Collated batch of heatmap labeled examples.

    Adds batched target heatmaps, one per keypoint, to the base batch fields.
    """

    heatmaps: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"]
class MultiviewLabeledBatchDict(TypedDict):
    """Collated batch of multiview labeled examples.

    ``bbox`` uses a single trailing dimension labelled ``num_views_x_xyhw``,
    which appears to pack the boxes of all views together. The calibration keys
    are declared as required; their annotations also admit a ``(batch, 1)``
    tensor, presumably a placeholder when no calibration is available
    (TODO: confirm).
    """

    images: (
        Float[torch.Tensor, "batch num_views RGB image_height image_width"]
        | Float[torch.Tensor, "batch num_views frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "batch num_targets"]
    bbox: Float[torch.Tensor, "batch num_views_x_xyhw"]
    idxs: Int[torch.Tensor, "batch"]
    num_views: Int[torch.Tensor, "batch"]
    # element type left open; presumably tuples of str after collation
    concat_order: list
    view_names: list

    # --- camera-calibration entries (real values only when calibration is available) ---
    keypoints_3d: Float[torch.Tensor, "batch num_keypoints 3"] | Float[torch.Tensor, "batch 1"]
    intrinsic_matrix: Float[torch.Tensor, "batch num_views 3 3"] | Float[torch.Tensor, "batch 1"]
    extrinsic_matrix: Float[torch.Tensor, "batch num_views 3 4"] | Float[torch.Tensor, "batch 1"]
    distortions: (
        Float[torch.Tensor, "batch num_views num_distortion_params"]
        | Float[torch.Tensor, "batch 1"]
    )
class MultiviewHeatmapLabeledBatchDict(MultiviewLabeledBatchDict):
    """Collated batch of multiview heatmap labeled examples.

    Adds batched target heatmaps, one per keypoint, to the multiview batch fields.
    """

    heatmaps: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"]
class UnlabeledBatchDict(TypedDict):
    """Batch of unlabeled single-view frames (a sequence of ``seq_len`` frames)."""

    frames: Float[torch.Tensor, "seq_len RGB image_height image_width"]
    # Accepted ``transforms`` layouts:
    #   (seq_len, 2, 3): a different transform for each entry of the sequence
    #   (2, 3):          one transform shared by all returned frames/keypoints
    #   (seq_len, 1):    no transforms
    #   (1,):            no transforms
    # The bare ``torch.Tensor`` member is required to avoid an error involving
    # ``torch.AnnotatedAlias`` whose root cause has not been tracked down.
    transforms: (
        Float[torch.Tensor, "seq_len h w"]
        | Float[torch.Tensor, "h w"]
        | Float[torch.Tensor, "seq_len 1"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
    bbox: Float[torch.Tensor, "seq_len xyhw"]
    # Always False for this type; acts as a runtime tag because isinstance()
    # cannot be used with TypedDicts.
    is_multiview: bool
class MultiviewUnlabeledBatchDict(TypedDict):
    """Batch of unlabeled multiview frames (``seq_len`` frames for each of ``num_views`` views)."""

    frames: Float[torch.Tensor, "seq_len num_views RGB image_height image_width"]
    # one transform per view; the bare torch.Tensor member mirrors the
    # single-view batch type
    transforms: (
        Float[torch.Tensor, "num_views h w"]
        | Float[torch.Tensor, "num_views 1 1"]
        | torch.Tensor
    )
    bbox: Float[torch.Tensor, "seq_len num_views_x_xyhw"]
    # Always True for this type; acts as a runtime tag because isinstance()
    # cannot be used with TypedDicts.
    is_multiview: bool
class SemiSupervisedBatchDict(TypedDict):
    """Combined labeled + unlabeled batch for (non-heatmap) semi-supervised training.

    Both halves may be either the single-view or the multiview variant.
    """

    labeled: BaseLabeledBatchDict | MultiviewLabeledBatchDict
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedHeatmapBatchDict(TypedDict):
    """Combined labeled + unlabeled batch for heatmap-based semi-supervised training.

    The labeled half carries heatmap targets; either half may be the
    single-view or the multiview variant.
    """

    labeled: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedDataLoaderDict(TypedDict):
    """Loaders returned by ``train/val/test_dataloader()`` on semi-supervised models.

    Labeled data comes from a PyTorch ``DataLoader``; unlabeled data comes from
    a DALI iterator.
    """

    labeled: torch.utils.data.DataLoader
    unlabeled: DALIGenericIterator