"""Classes to streamline data typechecking."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, TypedDict
import numpy as np
import pandas as pd
import torch
from jaxtyping import Float, Int
if TYPE_CHECKING:
    # Needed for annotations only. Because of ``from __future__ import annotations``
    # the name is never evaluated at runtime, so this import is skipped outside of
    # static type checking.
    from nvidia.dali.plugin.pytorch import DALIGenericIterator
# Explicit public API, so that imported names are ignored by sphinx-autoapidoc.
# NOTE(review): PredictionResult, MultiviewPredictionResult and
# ComputeMetricsSingleResult are defined in this module but are not listed here;
# confirm whether leaving them out is intentional.
__all__ = [
    "BaseLabeledExampleDict",
    "HeatmapLabeledExampleDict",
    "MultiviewLabeledExampleDict",
    "MultiviewHeatmapLabeledExampleDict",
    "BaseLabeledBatchDict",
    "HeatmapLabeledBatchDict",
    "MultiviewLabeledBatchDict",
    "MultiviewHeatmapLabeledBatchDict",
    "UnlabeledBatchDict",
    "MultiviewUnlabeledBatchDict",
    "SemiSupervisedBatchDict",
    "SemiSupervisedHeatmapBatchDict",
    "SemiSupervisedDataLoaderDict",
]
@dataclass
class PredictionResult:
    """ """  # suppresses sphinx class doc from getting autogenerated from __init__.

    # Predictions table. ``to_dict`` reads it with multi-level columns: keypoint
    # names on level 1 and the ``x`` / ``y`` / ``likelihood`` entries on level 2.
    predictions: pd.DataFrame
    # Optional metrics computed for the same predictions; ``None`` when absent.
    metrics: ComputeMetricsSingleResult | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return predictions and metrics as a flat dict of named numpy arrays.

        All arrays have shape ``(n_frames, n_keypoints)`` and share the same row
        order. Metric arrays are ``None`` when the metric was not computed.

        Returns:
            dict with keys:

            - ``keypoint_names``: list of keypoint name strings.
            - ``index``: list of frame identifiers (file paths or integer indices).
            - ``x``: float array of predicted x coordinates.
            - ``y``: float array of predicted y coordinates.
            - ``confidence``: float array of per-keypoint likelihood in [0, 1].
            - ``pixel_error``: float array or None.
            - ``temporal_norm``: float array or None.
            - ``pca_singleview_error``: float array or None.
            - ``pca_multiview_error``: float array or None.

        """

        def _metric(df: pd.DataFrame | None) -> np.ndarray | None:
            # A metric that was not computed stays ``None``.
            if df is None:
                return None
            # Exclude the ``set`` column (when present) so that only the remaining
            # columns end up in the array.
            cols = [c for c in df.columns if c != 'set']
            return df[cols].to_numpy()

        m = self.metrics
        return {
            'keypoint_names': list(self.predictions.columns.get_level_values(1).unique()),
            'index': list(self.predictions.index),
            'x': self.predictions.xs('x', level=2, axis=1).to_numpy(),
            'y': self.predictions.xs('y', level=2, axis=1).to_numpy(),
            'confidence': self.predictions.xs('likelihood', level=2, axis=1).to_numpy(),
            # every metric entry is ``None`` when no metrics object is attached
            'pixel_error': _metric(m.pixel_error_df) if m else None,
            'temporal_norm': _metric(m.temporal_norm_df) if m else None,
            'pca_singleview_error': _metric(m.pca_sv_df) if m else None,
            'pca_multiview_error': _metric(m.pca_mv_df) if m else None,
        }
@dataclass
class MultiviewPredictionResult:
    """ """  # suppresses sphinx class doc from getting autogenerated from __init__.

    # One predictions table per view, keyed by view name.
    predictions: dict[str, pd.DataFrame]
    # Optional metrics per view, keyed by view name; views without an entry are
    # treated as having no metrics.
    metrics: dict[str, ComputeMetricsSingleResult] | None = None

    def to_dict(self) -> dict[str, dict[str, Any]]:
        """Return predictions and metrics for each view as a flat dict of named numpy arrays.

        Wraps :meth:`PredictionResult.to_dict` for each view.

        Returns:
            dict keyed by view name, where each value is the ``to_dict()`` output for that view.

        """
        # A missing (or empty) metrics mapping yields ``None`` metrics for every view.
        metrics_by_view = self.metrics or {}
        return {
            view: PredictionResult(
                predictions=df,
                metrics=metrics_by_view.get(view),
            ).to_dict()
            for view, df in self.predictions.items()
        }
@dataclass
class ComputeMetricsSingleResult:
    """ """  # suppresses sphinx class doc from getting autogenerated from __init__.

    # Each field holds one metric table, or ``None`` when that metric was not computed.
    # pixel error table
    pixel_error_df: pd.DataFrame | None = None
    # temporal norm table
    temporal_norm_df: pd.DataFrame | None = None
    # PCA single-view error table
    pca_sv_df: pd.DataFrame | None = None
    # PCA multi-view error table
    pca_mv_df: pd.DataFrame | None = None
class BaseLabeledExampleDict(TypedDict):
    """Return type when calling __getitem__() on BaseTrackingDataset."""

    # a single image, or a stack of images with a leading ``frames`` dimension
    images: (
        Float[torch.Tensor, "RGB image_height image_width"]
        | Float[torch.Tensor, "frames RGB image_height image_width"]
    )
    # 1D tensor with ``num_targets`` entries
    keypoints: Float[torch.Tensor, "num_targets"]
    # bounding box; the dimension name suggests (x, y, height, width) -- TODO confirm
    bbox: Float[torch.Tensor, "xyhw"]
    # integer index of this example
    idxs: int
    # one integer entry per keypoint
    visibility: Int[torch.Tensor, "num_keypoints"]
class HeatmapLabeledExampleDict(BaseLabeledExampleDict):
    """Return type when calling __getitem__() on HeatmapTrackingDataset."""

    # one 2D heatmap per keypoint, in addition to the inherited fields
    heatmaps: Float[torch.Tensor, "num_keypoints heatmap_height heatmap_width"]
class MultiviewLabeledExampleDict(TypedDict):
    """Return type when calling __getitem__() on MultiviewDataset."""

    # one image per view, or one stack of frames per view
    images: (
        Float[torch.Tensor, "num_views RGB image_height image_width"]
        | Float[torch.Tensor, "num_views frames RGB image_height image_width"]
    )
    # 1D tensor with ``num_targets`` entries
    keypoints: Float[torch.Tensor, "num_targets"]
    # one bounding box per view; dimension name suggests (x, y, height, width) -- TODO confirm
    bbox: Float[torch.Tensor, "num_views xyhw"]
    # integer index of this example
    idxs: int
    num_views: int
    concat_order: list[str]
    view_names: list[str]

    # these attributes exist if camera calibration info is available
    # NOTE(review): the single-element alternatives below look like placeholders for
    # the case without calibration info -- confirm against the dataset code.
    keypoints_3d: Float[torch.Tensor, "num_keypoints 3"] | Float[torch.Tensor, "1"] | torch.Tensor
    intrinsic_matrix: (
        Float[torch.Tensor, "num_views 3 3"] | Float[torch.Tensor, "1"] | torch.Tensor
    )
    extrinsic_matrix: (
        Float[torch.Tensor, "num_views 3 4"] | Float[torch.Tensor, "1"] | torch.Tensor
    )
    # for distortion params info see
    # https://kornia.readthedocs.io/en/latest/geometry.calibration.html
    distortions: (
        Float[torch.Tensor, "num_views num_distortion_params"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
class MultiviewHeatmapLabeledExampleDict(MultiviewLabeledExampleDict):
    """Return type when calling __getitem__() on MultiviewHeatmapDataset."""

    # one 2D heatmap per keypoint, in addition to the inherited fields
    heatmaps: Float[torch.Tensor, "num_keypoints heatmap_height heatmap_width"]
class BaseLabeledBatchDict(TypedDict):
    """Batch type for base labeled data."""

    # batch of single images, or batch of frame stacks
    images: (
        Float[torch.Tensor, "batch RGB image_height image_width"]
        | Float[torch.Tensor, "batch frames RGB image_height image_width"]
    )
    # ``num_targets`` entries per batch element
    keypoints: Float[torch.Tensor, "batch num_targets"]
    # one bounding box per batch element; dimension name suggests
    # (x, y, height, width) -- TODO confirm
    bbox: Float[torch.Tensor, "batch xyhw"]
    # one integer index per batch element
    idxs: Int[torch.Tensor, "batch"]
class HeatmapLabeledBatchDict(BaseLabeledBatchDict):
    """Batch type for heatmap labeled data."""

    # one 2D heatmap per keypoint for every batch element, in addition to the
    # inherited fields
    heatmaps: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"]
class MultiviewLabeledBatchDict(TypedDict):
    """Batch type for multiview labeled data."""

    # per batch element: one image per view, or one stack of frames per view
    images: (
        Float[torch.Tensor, "batch num_views RGB image_height image_width"]
        | Float[torch.Tensor, "batch num_views frames RGB image_height image_width"]
    )
    # ``num_targets`` entries per batch element
    keypoints: Float[torch.Tensor, "batch num_targets"]
    # NOTE(review): the dimension name suggests the per-view boxes are flattened into
    # a single axis of size num_views * 4 -- confirm
    bbox: Float[torch.Tensor, "batch num_views_x_xyhw"]
    # one integer entry per batch element
    idxs: Int[torch.Tensor, "batch"]
    num_views: Int[torch.Tensor, "batch"]
    concat_order: list  # [Tuple[str]]
    view_names: list  # [Tuple[str]]

    # these attributes exist if camera calibration info is available
    keypoints_3d: Float[torch.Tensor, "batch num_keypoints 3"] | Float[torch.Tensor, "batch 1"]
    intrinsic_matrix: Float[torch.Tensor, "batch num_views 3 3"] | Float[torch.Tensor, "batch 1"]
    extrinsic_matrix: Float[torch.Tensor, "batch num_views 3 4"] | Float[torch.Tensor, "batch 1"]
    distortions: (
        Float[torch.Tensor, "batch num_views num_distortion_params"]
        | Float[torch.Tensor, "batch 1"]
    )
class MultiviewHeatmapLabeledBatchDict(MultiviewLabeledBatchDict):
    """Batch type for multiview heatmap labeled data."""

    # one 2D heatmap per keypoint for every batch element, in addition to the
    # inherited fields
    heatmaps: Float[torch.Tensor, "batch num_keypoints heatmap_height heatmap_width"]
class UnlabeledBatchDict(TypedDict):
    """Batch type for unlabeled data."""

    # sequence of ``seq_len`` frames
    frames: Float[torch.Tensor, "seq_len RGB image_height image_width"]
    # transforms shapes
    # (seq_len, 2, 3): different transform for each sequence
    # (2, 3): same transform for all returned frames/keypoints
    # (seq_len, 1): no transforms
    # (1,): no transforms
    # torch.Tensor: necessary, getting error about torch.AnnotatedAlias that I don't understand
    transforms: (
        Float[torch.Tensor, "seq_len h w"]
        | Float[torch.Tensor, "h w"]
        | Float[torch.Tensor, "seq_len 1"]
        | Float[torch.Tensor, "1"]
        | torch.Tensor
    )
    # one bounding box per frame; dimension name suggests (x, y, height, width) -- TODO confirm
    bbox: Float[torch.Tensor, "seq_len xyhw"]
    is_multiview: bool  # always False for this type; isinstance fails on TypedDicts
class MultiviewUnlabeledBatchDict(TypedDict):
    """Batch type for multiview unlabeled data."""

    # sequence of ``seq_len`` frames for each of ``num_views`` views
    frames: Float[torch.Tensor, "seq_len num_views RGB image_height image_width"]
    # one transform per view, a (num_views, 1, 1) tensor, or an unannotated tensor
    transforms: (
        Float[torch.Tensor, "num_views h w"]
        | Float[torch.Tensor, "num_views 1 1"]
        | torch.Tensor
    )
    # NOTE(review): the dimension name suggests the per-view boxes are flattened into
    # a single axis of size num_views * 4 -- confirm
    bbox: Float[torch.Tensor, "seq_len num_views_x_xyhw"]
    is_multiview: bool  # always True for this type; isinstance fails on TypedDicts
class SemiSupervisedBatchDict(TypedDict):
    """Batch type for base labeled+unlabeled data."""

    # labeled part of the batch (single-view or multiview)
    labeled: BaseLabeledBatchDict | MultiviewLabeledBatchDict
    # unlabeled part of the batch (single-view or multiview)
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedHeatmapBatchDict(TypedDict):
    """Batch type for heatmap labeled+unlabeled data."""

    # labeled part of the batch, including heatmaps (single-view or multiview)
    labeled: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict
    # unlabeled part of the batch (single-view or multiview)
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedDataLoaderDict(TypedDict):
    """Return type when calling train/val/test_dataloader() on semi-supervised models."""

    # loader for the labeled data
    labeled: torch.utils.data.DataLoader
    # iterator for the unlabeled data; the type is only imported under TYPE_CHECKING
    unlabeled: DALIGenericIterator