"""Classes to streamline data typechecking."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, TypedDict
import numpy as np
import pandas as pd
import torch
from jaxtyping import Float, Int
from nvidia.dali.plugin.pytorch import DALIGenericIterator
# Explicit public API of this module. Besides controlling ``from ... import *``,
# this keeps sphinx-autoapi from documenting the third-party names imported
# above (np, pd, torch, Float, Int, DALIGenericIterator).
__all__ = [
    "PredictionResult",
    "MultiviewPredictionResult",
    "ComputeMetricsSingleResult",
    "BaseLabeledExampleDict",
    "HeatmapLabeledExampleDict",
    "MultiviewLabeledExampleDict",
    "MultiviewHeatmapLabeledExampleDict",
    "BaseLabeledBatchDict",
    "HeatmapLabeledBatchDict",
    "MultiviewLabeledBatchDict",
    "MultiviewHeatmapLabeledBatchDict",
    "UnlabeledBatchDict",
    "MultiviewUnlabeledBatchDict",
    "SemiSupervisedBatchDict",
    "SemiSupervisedHeatmapBatchDict",
    "SemiSupervisedDataLoaderDict",
]
@dataclass
class PredictionResult:
    """Keypoint predictions for one view, optionally paired with their metrics.

    Attributes:
        predictions: predictions table. ``to_dict`` expects a column MultiIndex
            whose level 1 holds keypoint names and whose level 2 holds the
            ``x``, ``y`` and ``likelihood`` coordinates.
        metrics: per-metric DataFrames, or ``None`` when no metrics are attached.
    """

    predictions: pd.DataFrame
    metrics: ComputeMetricsSingleResult | None = None

    def to_dict(self) -> dict[str, Any]:
        """Flatten predictions and metrics into a dict of named numpy arrays.

        Every array has shape ``(n_frames, n_keypoints)`` and all arrays share
        the same row order. A metric entry is ``None`` when that metric was not
        computed, or when no metrics are attached at all.

        Returns:
            dict with keys, in this order:
                - ``keypoint_names``: list of keypoint name strings.
                - ``index``: list of frame identifiers (file paths or integer indices).
                - ``x``: float array of predicted x coordinates.
                - ``y``: float array of predicted y coordinates.
                - ``confidence``: float array of per-keypoint likelihood in [0, 1].
                - ``pixel_error``: float array or None.
                - ``temporal_norm``: float array or None.
                - ``pca_singleview_error``: float array or None.
                - ``pca_multiview_error``: float array or None.
        """
        preds = self.predictions
        metrics = self.metrics

        def coordinate(name: str) -> np.ndarray:
            # One coordinate (column level 2) for every keypoint, as a 2-D array.
            return preds.xs(name, level=2, axis=1).to_numpy()

        result: dict[str, Any] = {
            'keypoint_names': list(preds.columns.get_level_values(1).unique()),
            'index': list(preds.index),
            'x': coordinate('x'),
            'y': coordinate('y'),
            'confidence': coordinate('likelihood'),
        }

        # (output key, attribute on the metrics container), in output order.
        metric_fields = (
            ('pixel_error', 'pixel_error_df'),
            ('temporal_norm', 'temporal_norm_df'),
            ('pca_singleview_error', 'pca_sv_df'),
            ('pca_multiview_error', 'pca_mv_df'),
        )
        for key, attr in metric_fields:
            frame = getattr(metrics, attr) if metrics else None
            if frame is None:
                result[key] = None
                continue
            # Drop the 'set' column; every remaining column goes into the array.
            kept_columns = [col for col in frame.columns if col != 'set']
            result[key] = frame[kept_columns].to_numpy()
        return result
@dataclass
class MultiviewPredictionResult:
    """Per-view keypoint predictions, optionally paired with per-view metrics.

    Attributes:
        predictions: mapping from view name to that view's predictions table.
        metrics: mapping from view name to that view's metrics, or ``None``.
            A view absent from this mapping is treated as having no metrics.
    """

    predictions: dict[str, pd.DataFrame]
    metrics: dict[str, ComputeMetricsSingleResult] | None = None

    def to_dict(self) -> dict[str, dict[str, Any]]:
        """Flatten each view's predictions and metrics into named numpy arrays.

        Delegates to :meth:`PredictionResult.to_dict` once per view.

        Returns:
            dict keyed by view name (in the iteration order of ``predictions``),
            where each value is the ``to_dict()`` output for that view.
        """
        # A missing or empty metrics mapping behaves as "no metrics for any view".
        metrics_by_view = self.metrics or {}
        per_view: dict[str, dict[str, Any]] = {}
        for view_name, view_df in self.predictions.items():
            single_view = PredictionResult(
                predictions=view_df,
                metrics=metrics_by_view.get(view_name),
            )
            per_view[view_name] = single_view.to_dict()
        return per_view
@dataclass
class ComputeMetricsSingleResult:
    """Metric DataFrames computed for a single set of predictions.

    Every field defaults to ``None``, meaning that metric was not computed.
    ``PredictionResult.to_dict`` converts each present DataFrame to a numpy
    array after dropping its ``'set'`` column (if any).
    """
    # A non-empty class docstring also stops @dataclass from auto-generating one
    # from the __init__ signature (which Sphinx would otherwise render).

    # exported by PredictionResult.to_dict as 'pixel_error'
    pixel_error_df: pd.DataFrame | None = None
    # exported by PredictionResult.to_dict as 'temporal_norm'
    temporal_norm_df: pd.DataFrame | None = None
    # single-view PCA error; exported by PredictionResult.to_dict as 'pca_singleview_error'
    pca_sv_df: pd.DataFrame | None = None
    # multi-view PCA error; exported by PredictionResult.to_dict as 'pca_multiview_error'
    pca_mv_df: pd.DataFrame | None = None
class BaseLabeledExampleDict(TypedDict):
    """Return type when calling __getitem__() on BaseTrackingDataset."""

    # a single image, or a stack of frames with a leading ``frames`` axis
    images: (
        Float[torch.Tensor, "RGB image_height image_width"]
        | Float[torch.Tensor, "frames RGB image_height image_width"]
    )
    # 1-D tensor of keypoint targets
    keypoints: Float[torch.Tensor, "num_targets"]
    # 1-D bounding-box tensor (axis labelled "xyhw")
    bbox: Float[torch.Tensor, "xyhw"]
    # integer index of this example (presumably its position in the dataset — confirm)
    idxs: int
class HeatmapLabeledExampleDict(BaseLabeledExampleDict):
    """Return type when calling __getitem__() on HeatmapTrackingDataset.

    Adds target heatmaps to the fields of :class:`BaseLabeledExampleDict`.
    """

    # one heatmap per keypoint
    heatmaps: Float[torch.Tensor, "num_keypoints heatmap_height heatmap_width"]
class MultiviewLabeledExampleDict(TypedDict):
    """Return type when calling __getitem__() on MultiviewDataset."""

    # one image (or one stack of frames) per view, along a leading view axis
    images: (
        Float[torch.Tensor, "num_views RGB image_height image_width"]
        | Float[torch.Tensor, "num_views frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "num_targets"]
    # Boxes of all views concatenated into one axis. This matches the
    # "num_views_x_xyhw" axis of MultiviewLabeledBatchDict.bbox and
    # MultiviewUnlabeledBatchDict.bbox. A 2-D (num_views, xyhw) example would
    # batch to a 3-D tensor, contradicting the 2-D batch annotation.
    bbox: Float[torch.Tensor, "num_views_x_xyhw"]
    idxs: int
    num_views: int
    concat_order: list[str]
    view_names: list[str]
    # These attributes exist if camera calibration info is available. The
    # "1"-shaped alternative is presumably a placeholder otherwise (confirm).
    # The bare torch.Tensor member makes each union accept any tensor.
    keypoints_3d: Float[torch.Tensor, "num_keypoints 3"] | Float[torch.Tensor, "1"] | torch.Tensor
    intrinsic_matrix: (
        Float[torch.Tensor, "num_views 3 3"] | Float[torch.Tensor, "1"] | torch.Tensor
    )
    extrinsic_matrix: (
        Float[torch.Tensor, "num_views 3 4"] | Float[torch.Tensor, "1"] | torch.Tensor
    )
    distortions: (
        Float[torch.Tensor, "num_views num_distortion_params"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
    # for distortion params info see
    # https://kornia.readthedocs.io/en/latest/geometry.calibration.html
class MultiviewHeatmapLabeledExampleDict(MultiviewLabeledExampleDict):
    """Return type when calling __getitem__() on MultiviewHeatmapDataset.

    Adds target heatmaps to the fields of :class:`MultiviewLabeledExampleDict`.
    """

    # NOTE(review): the leading axis is labelled num_keypoints; confirm whether it
    # spans the keypoints of all views combined.
    heatmaps: Float[torch.Tensor, "num_keypoints heatmap_height heatmap_width"]
class BaseLabeledBatchDict(TypedDict):
    """Batch type for base labeled data.

    Same fields as :class:`BaseLabeledExampleDict`, each with a leading ``batch``
    axis. ``idxs`` becomes an integer tensor.
    """

    # a batch of images, or a batch of frame stacks
    images: (
        Float[torch.Tensor, "batch RGB image_height image_width"]
        | Float[torch.Tensor, "batch frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "batch num_targets"]
    bbox: Float[torch.Tensor, "batch xyhw"]
    # one integer index per example in the batch
    idxs: Int[torch.Tensor, "batch"]
class HeatmapLabeledBatchDict(BaseLabeledBatchDict):
    """Batch type for heatmap labeled data.

    Adds batched target heatmaps to the fields of :class:`BaseLabeledBatchDict`.
    """

    heatmaps: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"]
class MultiviewLabeledBatchDict(TypedDict):
    """Batch type for multiview labeled data.

    The multiview example fields with a leading ``batch`` axis on each tensor.
    """

    images: (
        Float[torch.Tensor, "batch num_views RGB image_height image_width"]
        | Float[torch.Tensor, "batch num_views frames RGB image_height image_width"]
    )
    keypoints: Float[torch.Tensor, "batch num_targets"]
    # boxes of all views concatenated into one axis for each example
    bbox: Float[torch.Tensor, "batch num_views_x_xyhw"]
    idxs: Int[torch.Tensor, "batch"]
    num_views: Int[torch.Tensor, "batch"]
    concat_order: list  # intended element type: Tuple[str] (left unparameterized)
    view_names: list  # intended element type: Tuple[str] (left unparameterized)
    # These attributes exist if camera calibration info is available. The
    # "batch 1" alternative is presumably a placeholder otherwise (confirm).
    keypoints_3d: Float[torch.Tensor, "batch num_keypoints 3"] | Float[torch.Tensor, "batch 1"]
    intrinsic_matrix: Float[torch.Tensor, "batch num_views 3 3"] | Float[torch.Tensor, "batch 1"]
    extrinsic_matrix: Float[torch.Tensor, "batch num_views 3 4"] | Float[torch.Tensor, "batch 1"]
    distortions: (
        Float[torch.Tensor, "batch num_views num_distortion_params"]
        | Float[torch.Tensor, "batch 1"]
    )
class MultiviewHeatmapLabeledBatchDict(MultiviewLabeledBatchDict):
    """Batch type for multiview heatmap labeled data.

    Adds batched target heatmaps to the fields of :class:`MultiviewLabeledBatchDict`.
    """

    heatmaps: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"]
class UnlabeledBatchDict(TypedDict):
    """Batch type for unlabeled data."""

    frames: Float[torch.Tensor, "seq_len RGB image_height image_width"]
    # Accepted ``transforms`` shapes (h=2, w=3 when a transform is present):
    #   (seq_len, 2, 3): a separate transform for each element along seq_len
    #   (2, 3): the same transform for all returned frames/keypoints
    #   (seq_len, 1): no transforms
    #   (1,): no transforms
    # The trailing bare torch.Tensor member is required. Without it, an error
    # about torch.AnnotatedAlias was raised (root cause not yet understood).
    # Note that it also makes the union accept any tensor.
    transforms: (
        Float[torch.Tensor, "seq_len h w"]
        | Float[torch.Tensor, "h w"]
        | Float[torch.Tensor, "seq_len 1"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
    bbox: Float[torch.Tensor, "seq_len xyhw"]
    is_multiview: bool  # always False for this type; isinstance fails on TypedDicts
class MultiviewUnlabeledBatchDict(TypedDict):
    """Batch type for multiview unlabeled data."""

    frames: Float[torch.Tensor, "seq_len num_views RGB image_height image_width"]
    # NOTE(review): by analogy with UnlabeledBatchDict, presumably one (h, w)
    # transform per view, with (num_views, 1, 1) meaning "no transforms" (confirm).
    # The bare torch.Tensor member makes the union accept any tensor.
    transforms: (
        Float[torch.Tensor, "num_views h w"]
        | Float[torch.Tensor, "num_views 1 1"]
        | torch.Tensor
    )
    # boxes of all views concatenated into one axis for each seq_len element
    bbox: Float[torch.Tensor, "seq_len num_views_x_xyhw"]
    is_multiview: bool  # always True for this type; isinstance fails on TypedDicts
class SemiSupervisedBatchDict(TypedDict):
    """Batch type for base labeled+unlabeled data."""

    # labeled batch, single-view or multiview
    labeled: BaseLabeledBatchDict | MultiviewLabeledBatchDict
    # unlabeled batch; tell the variants apart via its ``is_multiview`` flag,
    # since isinstance does not work on TypedDicts
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedHeatmapBatchDict(TypedDict):
    """Batch type for heatmap labeled+unlabeled data."""

    # labeled batch with target heatmaps, single-view or multiview
    labeled: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict
    # unlabeled batch; tell the variants apart via its ``is_multiview`` flag,
    # since isinstance does not work on TypedDicts
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedDataLoaderDict(TypedDict):
    """Return type when calling train/val/test_dataloader() on semi-supervised models."""

    # PyTorch DataLoader for the labeled data
    labeled: torch.utils.data.DataLoader
    # NVIDIA DALI iterator for the unlabeled data
    unlabeled: DALIGenericIterator