"""Classes to streamline data typechecking."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TypedDict
import pandas as pd
import torch
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from torchtyping import TensorType
# Explicit public API. Listing the module's own names here keeps Sphinx
# auto-API documentation (and ``from ... import *``) from picking up imported
# names such as ``pd``, ``torch`` or ``TensorType``.
__all__ = [
    # prediction / metric result containers
    "PredictionResult",
    "MultiviewPredictionResult",
    "ComputeMetricsSingleResult",
    # single examples returned by dataset __getitem__()
    "BaseLabeledExampleDict",
    "HeatmapLabeledExampleDict",
    "MultiviewLabeledExampleDict",
    "MultiviewHeatmapLabeledExampleDict",
    # labeled batches
    "BaseLabeledBatchDict",
    "HeatmapLabeledBatchDict",
    "MultiviewLabeledBatchDict",
    "MultiviewHeatmapLabeledBatchDict",
    # unlabeled batches
    "UnlabeledBatchDict",
    "MultiviewUnlabeledBatchDict",
    # combined labeled + unlabeled batches / loaders
    "SemiSupervisedBatchDict",
    "SemiSupervisedHeatmapBatchDict",
    "SemiSupervisedDataLoaderDict",
]
# Keep a non-empty docstring: without one, @dataclass generates __doc__ from
# the __init__ signature, which Sphinx would then render as the class doc.
@dataclass
class PredictionResult:
    """Container for a table of predictions plus optional metrics.

    Attributes:
        predictions: table of predictions.
        metrics: metrics computed for these predictions; defaults to ``None``
            (presumably meaning metrics were not computed -- confirm).
    """

    predictions: pd.DataFrame
    metrics: ComputeMetricsSingleResult | None = None
# Keep a non-empty docstring: without one, @dataclass generates __doc__ from
# the __init__ signature, which Sphinx would then render as the class doc.
@dataclass
class MultiviewPredictionResult:
    """Container for string-keyed prediction tables plus optional metrics.

    Attributes:
        predictions: mapping from a string key to a predictions table (keys
            are presumably view names -- confirm with the producer).
        metrics: mapping from the same kind of string key to computed
            metrics; defaults to ``None``.
    """

    predictions: dict[str, pd.DataFrame]
    metrics: dict[str, ComputeMetricsSingleResult] | None = None
# Keep a non-empty docstring: without one, @dataclass generates __doc__ from
# the __init__ signature, which Sphinx would then render as the class doc.
@dataclass
class ComputeMetricsSingleResult:
    """Optional metric tables computed for one set of predictions.

    Every field defaults to ``None`` (presumably: that metric was not
    computed -- confirm with the metrics code).

    Attributes:
        pixel_error_df: pixel-error table.
        temporal_norm_df: temporal-norm table.
        pca_sv_df: PCA-based table; ``sv`` presumably means single-view.
        pca_mv_df: PCA-based table; ``mv`` presumably means multiview.
    """

    pixel_error_df: pd.DataFrame | None = None
    temporal_norm_df: pd.DataFrame | None = None
    pca_sv_df: pd.DataFrame | None = None
    pca_mv_df: pd.DataFrame | None = None
class BaseLabeledExampleDict(TypedDict):
    """Return type when calling __getitem__() on BaseTrackingDataset."""

    # a single image (RGB, H, W) or a stack of frames (frames, RGB, H, W)
    images: (
        TensorType["RGB":3, "image_height", "image_width", float]
        | TensorType["frames", "RGB":3, "image_height", "image_width", float]
    )
    # 1-D vector of keypoint targets
    keypoints: TensorType["num_targets", float]
    # 4 bounding-box values, ordered as the dim label "xyhw" indicates
    bbox: TensorType["xyhw":4, float]
    # index of this example -- presumably the dataset index; confirm
    idxs: int
class HeatmapLabeledExampleDict(BaseLabeledExampleDict):
    """Return type when calling __getitem__() on HeatmapTrackingDataset."""

    # one heatmap per keypoint, in addition to the base example fields
    heatmaps: TensorType["num_keypoints", "heatmap_height", "heatmap_width", float]
class MultiviewLabeledExampleDict(TypedDict):
    """Return type when calling __getitem__() on MultiviewDataset.

    The camera-calibration fields (``keypoints_3d``, ``intrinsic_matrix``,
    ``extrinsic_matrix``, ``distortions``) hold real data when calibration
    info is available. Their ``"null":1`` alternative suggests a one-element
    placeholder tensor otherwise (confirm with MultiviewDataset).
    """

    # per-view images, or per-view frame stacks
    images: (
        TensorType["num_views", "RGB":3, "image_height", "image_width", float]
        | TensorType["num_views", "frames", "RGB":3, "image_height", "image_width", float]
    )
    # 1-D vector of keypoint targets
    keypoints: TensorType["num_targets", float]
    # one box per view, 4 values each (dim label "xyhw")
    bbox: TensorType["num_views", "xyhw":4, float]
    idxs: int
    num_views: int
    # presumably the order in which views were concatenated -- confirm
    concat_order: list[str]
    view_names: list[str]
    # --- camera calibration (see class docstring) ---
    # the trailing ``torch.Tensor`` alternative is a permissive fallback
    keypoints_3d: TensorType["num_keypoints", 3] | TensorType["null":1] | torch.Tensor
    intrinsic_matrix: TensorType["num_views", 3, 3] | TensorType["null":1] | torch.Tensor
    extrinsic_matrix: TensorType["num_views", 3, 4] | TensorType["null":1] | torch.Tensor
    # for the distortion-parameter layout see
    # https://kornia.readthedocs.io/en/latest/geometry.calibration.html
    distortions: (
        TensorType["num_views", "num_distortion_params"]
        | TensorType["null":1]
        | torch.Tensor
    )
class MultiviewHeatmapLabeledExampleDict(MultiviewLabeledExampleDict):
    """Return type when calling __getitem__() on MultiviewHeatmapDataset."""

    # NOTE(review): this annotation has no num_views axis -- confirm whether
    # per-view heatmaps are folded into the num_keypoints axis.
    heatmaps: TensorType["num_keypoints", "heatmap_height", "heatmap_width", float]
class BaseLabeledBatchDict(TypedDict):
    """Batch type for base labeled data (batched BaseLabeledExampleDict fields)."""

    # batch of images, or batch of frame stacks
    images: (
        TensorType["batch", "RGB":3, "image_height", "image_width", float]
        | TensorType["batch", "frames", "RGB":3, "image_height", "image_width", float]
    )
    keypoints: TensorType["batch", "num_targets", float]
    # 4 box values per example (dim label "xyhw")
    bbox: TensorType["batch", "xyhw":4, float]
    idxs: TensorType["batch", int]
class HeatmapLabeledBatchDict(BaseLabeledBatchDict):
    """Batch type for heatmap labeled data."""

    # one heatmap per keypoint for every example in the batch
    heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width", float]
class MultiviewLabeledBatchDict(TypedDict):
    """Batch type for multiview labeled data.

    The calibration fields (``keypoints_3d``, ``intrinsic_matrix``,
    ``extrinsic_matrix``, ``distortions``) carry data when camera calibration
    info is available. The ``["batch", 1]`` alternative appears to be the
    batched placeholder otherwise (confirm with the dataset/collate code).
    """

    # per-view images, or per-view frame stacks, for every example
    images: (
        TensorType["batch", "num_views", "RGB":3, "image_height", "image_width", float]
        | TensorType["batch", "num_views", "frames", "RGB":3, "image_height", "image_width", float]
    )
    keypoints: TensorType["batch", "num_targets", float]
    # per-view boxes flattened into one axis (dim label "num_views * xyhw")
    bbox: TensorType["batch", "num_views * xyhw", float]
    idxs: TensorType["batch", int]
    num_views: TensorType["batch", int]
    # element type hinted as Tuple[str]; presumably what batching produces
    # from per-example list[str] -- confirm
    concat_order: list  # [Tuple[str]]
    view_names: list  # [Tuple[str]]
    # --- camera calibration (see class docstring) ---
    # NOTE(review): unlike MultiviewLabeledExampleDict, these unions carry no
    # ``torch.Tensor`` fallback -- confirm whether that is intentional.
    keypoints_3d: TensorType["batch", "num_keypoints", 3] | TensorType["batch", 1]
    intrinsic_matrix: TensorType["batch", "num_views", 3, 3] | TensorType["batch", 1]
    extrinsic_matrix: TensorType["batch", "num_views", 3, 4] | TensorType["batch", 1]
    distortions: TensorType["batch", "num_views", "num_distortion_params"] | TensorType["batch", 1]
class MultiviewHeatmapLabeledBatchDict(MultiviewLabeledBatchDict):
    """Batch type for multiview heatmap labeled data."""

    # NOTE(review): no num_views axis in this annotation -- confirm whether
    # per-view heatmaps are folded into the num_keypoints axis.
    heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width", float]
class UnlabeledBatchDict(TypedDict):
    """Batch type for unlabeled data."""

    frames: TensorType["seq_len", "RGB":3, "image_height", "image_width", float]
    # Accepted ``transforms`` shapes:
    #   (seq_len, 2, 3): one transform per element of the sequence
    #   (2, 3):          the same transform for all returned frames/keypoints
    #   (seq_len, 1):    no transforms
    #   (1,):            no transforms
    #   torch.Tensor:    permissive fallback; the original author added it to
    #                    avoid an error mentioning ``AnnotatedAlias`` (root
    #                    cause not investigated)
    transforms: (
        TensorType["seq_len", "h":2, "w":3, float]
        | TensorType["h":2, "w":3, float]
        | TensorType["seq_len", "null":1, float]
        | TensorType["null":1, float]
        | torch.Tensor
    )
    # 4 box values per frame (dim label "xyhw")
    bbox: TensorType["seq_len", "xyhw":4, float]
    # Marker for downstream logic, because isinstance() cannot be used with
    # TypedDicts.
    # NOTE(review): TypedDict keys cannot have default values (PEP 589).
    # ``= False`` only creates the class attribute
    # ``UnlabeledBatchDict.is_multiview``; dicts of this type do NOT get the
    # key automatically, and static type checkers reject this assignment.
    # Confirm that producers always set "is_multiview" explicitly.
    is_multiview: bool = False
class MultiviewUnlabeledBatchDict(TypedDict):
    """Batch type for multiview unlabeled data."""

    frames: TensorType["seq_len", "num_views", "RGB":3, "image_height", "image_width", float]
    # Accepted ``transforms`` shapes:
    #   (num_views, 2, 3): one transform per view
    #   (num_views, 1, 1): presumably "no transforms" placeholder -- confirm
    #   torch.Tensor:      permissive fallback
    transforms: (
        TensorType["num_views", "h":2, "w":3, float]
        | TensorType["num_views", "null":1, "null":1, float]
        | torch.Tensor
    )
    # per-view boxes flattened into one axis (dim label "num_views * xyhw")
    bbox: TensorType["seq_len", "num_views * xyhw", float]
    # Marker for downstream logic, because isinstance() cannot be used with
    # TypedDicts.
    # NOTE(review): TypedDict keys cannot have default values (PEP 589).
    # ``= True`` only creates the class attribute
    # ``MultiviewUnlabeledBatchDict.is_multiview``; dicts of this type do NOT
    # get the key automatically, and static type checkers reject this
    # assignment. Confirm that producers always set "is_multiview" explicitly.
    is_multiview: bool = True
class SemiSupervisedBatchDict(TypedDict):
    """Batch type for base labeled+unlabeled data.

    Each side may be either the single-view or the multiview variant.
    """

    labeled: BaseLabeledBatchDict | MultiviewLabeledBatchDict
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedHeatmapBatchDict(TypedDict):
    """Batch type for heatmap labeled+unlabeled data.

    Each side may be either the single-view or the multiview variant.
    """

    labeled: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedDataLoaderDict(TypedDict):
    """Return type when calling train/val/test_dataloader() on semi-supervised models."""

    # PyTorch DataLoader over the labeled data
    labeled: torch.utils.data.DataLoader
    # DALI iterator over the unlabeled data
    unlabeled: DALIGenericIterator