Source code for lightning_pose.data.datatypes

"""Classes to streamline data typechecking."""
from __future__ import annotations

from dataclasses import dataclass
from typing import TypedDict, Union

import pandas as pd
import torch
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from torchtyping import TensorType

# Restrict the public API to the names below so that sphinx-autoapi does not
# document the objects this module merely imports.
# NOTE(review): PredictionResult, MultiviewPredictionResult and
# ComputeMetricsSingleResult are defined in this module but are not listed
# here -- confirm whether leaving them out is intentional.
__all__ = [
    "BaseLabeledExampleDict",
    "HeatmapLabeledExampleDict",
    "MultiviewLabeledExampleDict",
    "MultiviewHeatmapLabeledExampleDict",
    "BaseLabeledBatchDict",
    "HeatmapLabeledBatchDict",
    "MultiviewLabeledBatchDict",
    "MultiviewHeatmapLabeledBatchDict",
    "UnlabeledBatchDict",
    "MultiviewUnlabeledBatchDict",
    "SemiSupervisedBatchDict",
    "SemiSupervisedHeatmapBatchDict",
    "SemiSupervisedDataLoaderDict",
]


@dataclass
class PredictionResult:
    """Predictions table bundled with an optional metrics container.

    The docstring is deliberately non-empty: when a dataclass has no
    docstring, ``@dataclass`` fills ``__doc__`` with the generated
    ``__init__`` signature, and Sphinx would render that as the class
    documentation.

    Attributes:
        predictions: DataFrame holding the predictions.
        metrics: Optional ``ComputeMetricsSingleResult``; ``None`` (the
            default) when no metrics are attached.
    """

    predictions: pd.DataFrame
    metrics: ComputeMetricsSingleResult | None = None
@dataclass
class MultiviewPredictionResult:
    """Per-key prediction tables bundled with optional per-key metrics.

    The docstring is deliberately non-empty: when a dataclass has no
    docstring, ``@dataclass`` fills ``__doc__`` with the generated
    ``__init__`` signature, and Sphinx would render that as the class
    documentation.

    Attributes:
        predictions: Mapping from a string key to a predictions DataFrame
            (the key is presumably the view name -- confirm with callers).
        metrics: Optional mapping from a string key to a
            ``ComputeMetricsSingleResult``; ``None`` (the default) when no
            metrics are attached.
    """

    predictions: dict[str, pd.DataFrame]
    metrics: dict[str, ComputeMetricsSingleResult] | None = None
@dataclass
class ComputeMetricsSingleResult:
    """Optional metric tables; every field defaults to ``None``.

    The docstring is deliberately non-empty: when a dataclass has no
    docstring, ``@dataclass`` fills ``__doc__`` with the generated
    ``__init__`` signature, and Sphinx would render that as the class
    documentation.

    Attributes:
        pixel_error_df: Optional DataFrame (presumably per-keypoint pixel
            error -- confirm with the code that fills it).
        temporal_norm_df: Optional DataFrame (presumably temporal-norm
            metric -- confirm).
        pca_sv_df: Optional DataFrame (presumably single-view PCA metric --
            confirm).
        pca_mv_df: Optional DataFrame (presumably multi-view PCA metric --
            confirm).
    """

    pixel_error_df: pd.DataFrame | None = None
    temporal_norm_df: pd.DataFrame | None = None
    pca_sv_df: pd.DataFrame | None = None
    pca_mv_df: pd.DataFrame | None = None
class BaseLabeledExampleDict(TypedDict):
    """Return type when calling __getitem__() on BaseTrackingDataset.

    Describes a single labeled example (no batch dimension). At runtime a
    TypedDict is an ordinary ``dict``; the annotations are for readers and
    type checkers.
    """

    # either one RGB image or a stack of frames, each with 3 channels
    images: Union[
        TensorType["RGB":3, "image_height", "image_width", float],
        TensorType["frames", "RGB":3, "image_height", "image_width", float],
    ]
    # 1D float tensor with `num_targets` entries
    keypoints: TensorType["num_targets", float]
    # 4 floats per example (presumably x, y, height, width -- confirm)
    bbox: TensorType["xyhw":4, float]
    # plain integer (presumably the example's index in the dataset -- confirm)
    idxs: int
class HeatmapLabeledExampleDict(BaseLabeledExampleDict):
    """Return type when calling __getitem__() on HeatmapTrackingDataset.

    Extends ``BaseLabeledExampleDict`` with one extra key, ``heatmaps``.
    """

    # one 2D heatmap per keypoint
    heatmaps: TensorType["num_keypoints", "heatmap_height", "heatmap_width", float]
class MultiviewLabeledExampleDict(TypedDict):
    """Return type when calling __getitem__() on MultiviewDataset.

    Describes a single labeled multiview example (no batch dimension). At
    runtime a TypedDict is an ordinary ``dict``; the annotations are for
    readers and type checkers.
    """

    # per-view RGB images, optionally with an extra `frames` dimension
    images: Union[
        TensorType["num_views", "RGB":3, "image_height", "image_width", float],
        TensorType["num_views", "frames", "RGB":3, "image_height", "image_width", float],
    ]
    # 1D float tensor with `num_targets` entries
    keypoints: TensorType["num_targets", float]
    # 4 floats per view (presumably x, y, height, width -- confirm)
    bbox: TensorType["num_views", "xyhw":4, float]
    idxs: int
    num_views: int
    concat_order: list[str]
    view_names: list[str]
    # these attributes exist if camera calibration info is available
    # NOTE(review): the keys are required by the TypedDict; the size-1 "null"
    # alternatives suggest a placeholder tensor is stored when calibration is
    # unavailable -- confirm against the dataset code.
    keypoints_3d: Union[
        TensorType["num_keypoints", 3],
        TensorType["null":1],
        torch.Tensor,
    ]
    intrinsic_matrix: Union[
        TensorType["num_views", 3, 3],
        TensorType["null":1],
        torch.Tensor,
    ]
    extrinsic_matrix: Union[
        TensorType["num_views", 3, 4],
        TensorType["null":1],
        torch.Tensor,
    ]
    # for distortion params info see
    # https://kornia.readthedocs.io/en/latest/geometry.calibration.html
    distortions: Union[
        TensorType["num_views", "num_distortion_params"],
        TensorType["null":1],
        torch.Tensor,
    ]
class MultiviewHeatmapLabeledExampleDict(MultiviewLabeledExampleDict):
    """Return type when calling __getitem__() on MultiviewHeatmapDataset.

    Extends ``MultiviewLabeledExampleDict`` with one extra key, ``heatmaps``.
    """

    # one 2D heatmap per keypoint
    heatmaps: TensorType["num_keypoints", "heatmap_height", "heatmap_width", float]
class BaseLabeledBatchDict(TypedDict):
    """Batch type for base labeled data.

    Every tensor carries a leading ``batch`` dimension. At runtime a
    TypedDict is an ordinary ``dict``; the annotations are for readers and
    type checkers.
    """

    # batched RGB images, optionally with an extra `frames` dimension
    images: Union[
        TensorType["batch", "RGB":3, "image_height", "image_width", float],
        TensorType["batch", "frames", "RGB":3, "image_height", "image_width", float],
    ]
    keypoints: TensorType["batch", "num_targets", float]
    # 4 floats per example (presumably x, y, height, width -- confirm)
    bbox: TensorType["batch", "xyhw":4, float]
    # one integer per batch element
    idxs: TensorType["batch", int]
class HeatmapLabeledBatchDict(BaseLabeledBatchDict):
    """Batch type for heatmap labeled data.

    Extends ``BaseLabeledBatchDict`` with one extra key, ``heatmaps``.
    """

    # one 2D heatmap per keypoint, per batch element
    heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width", float]
class MultiviewLabeledBatchDict(TypedDict):
    """Batch type for multiview labeled data.

    Every tensor carries a leading ``batch`` dimension. At runtime a
    TypedDict is an ordinary ``dict``; the annotations are for readers and
    type checkers.
    """

    # batched per-view RGB images, optionally with an extra `frames` dimension
    images: Union[
        TensorType["batch", "num_views", "RGB":3, "image_height", "image_width", float],
        TensorType["batch", "num_views", "frames", "RGB":3, "image_height", "image_width", float],
    ]
    keypoints: TensorType["batch", "num_targets", float]
    # second dimension is named "num_views * xyhw" in the annotation
    bbox: TensorType["batch", "num_views * xyhw", float]
    idxs: TensorType["batch", int]
    num_views: TensorType["batch", int]
    concat_order: list  # [Tuple[str]]
    view_names: list  # [Tuple[str]]
    # these attributes exist if camera calibration info is available
    # NOTE(review): the keys are required by the TypedDict; the (batch, 1)
    # alternatives suggest a placeholder is stored when calibration is
    # unavailable -- confirm against the dataset/collate code.
    keypoints_3d: Union[
        TensorType["batch", "num_keypoints", 3],
        TensorType["batch", 1],
    ]
    intrinsic_matrix: Union[
        TensorType["batch", "num_views", 3, 3],
        TensorType["batch", 1],
    ]
    extrinsic_matrix: Union[
        TensorType["batch", "num_views", 3, 4],
        TensorType["batch", 1],
    ]
    distortions: Union[
        TensorType["batch", "num_views", "num_distortion_params"],
        TensorType["batch", 1],
    ]
class MultiviewHeatmapLabeledBatchDict(MultiviewLabeledBatchDict):
    """Batch type for multiview heatmap labeled data.

    Extends ``MultiviewLabeledBatchDict`` with one extra key, ``heatmaps``.
    """

    # one 2D heatmap per keypoint, per batch element
    heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width", float]
class UnlabeledBatchDict(TypedDict):
    """Batch type for unlabeled data.

    At runtime a TypedDict is an ordinary ``dict``; the annotations are for
    readers and type checkers.
    """

    frames: TensorType["seq_len", "RGB":3, "image_height", "image_width", float]
    transforms: Union[
        TensorType["seq_len", "h":2, "w":3, float],
        TensorType["h":2, "w":3, float],
        TensorType["seq_len", "null":1, float],
        TensorType["null":1, float],
        torch.Tensor,
    ]
    # transforms shapes
    # (seq_len, 2, 3): different transform for each sequence
    # (2, 3): same transform for all returned frames/keypoints
    # (seq_len, 1): no transforms
    # (1,): no transforms
    # torch.Tensor: necessary, getting error about torch.AnnotatedAlias that I don't understand
    bbox: TensorType["seq_len", "xyhw":4, float]
    # helps with downstream logic since isinstance fails on TypedDicts
    # NOTE: TypedDict has no notion of default values. The assignment below
    # only creates the class attribute ``UnlabeledBatchDict.is_multiview``; it
    # is NOT inserted into dicts built from this type, so producers must set
    # the "is_multiview" key explicitly. Kept as-is for backward compatibility.
    is_multiview: bool = False
class MultiviewUnlabeledBatchDict(TypedDict):
    """Batch type for multiview unlabeled data.

    At runtime a TypedDict is an ordinary ``dict``; the annotations are for
    readers and type checkers.
    """

    frames: TensorType["seq_len", "num_views", "RGB":3, "image_height", "image_width", float]
    # either one (2, 3) transform per view, or a (num_views, 1, 1) alternative
    # whose dimensions are named "null"
    transforms: Union[
        TensorType["num_views", "h":2, "w":3, float],
        TensorType["num_views", "null":1, "null":1, float],
        torch.Tensor,
    ]
    # second dimension is named "num_views * xyhw" in the annotation
    bbox: TensorType["seq_len", "num_views * xyhw", float]
    # helps with downstream logic since isinstance fails on TypedDicts
    # NOTE: TypedDict has no notion of default values. The assignment below
    # only creates the class attribute
    # ``MultiviewUnlabeledBatchDict.is_multiview``; it is NOT inserted into
    # dicts built from this type, so producers must set the "is_multiview"
    # key explicitly. Kept as-is for backward compatibility.
    is_multiview: bool = True
class SemiSupervisedBatchDict(TypedDict):
    """Batch type for base labeled+unlabeled data.

    Pairs one labeled batch with one unlabeled batch; each side may be the
    single-view or the multiview variant.
    """

    labeled: BaseLabeledBatchDict | MultiviewLabeledBatchDict
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedHeatmapBatchDict(TypedDict):
    """Batch type for heatmap labeled+unlabeled data.

    Pairs one heatmap-labeled batch with one unlabeled batch; each side may
    be the single-view or the multiview variant.
    """

    labeled: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict
    unlabeled: UnlabeledBatchDict | MultiviewUnlabeledBatchDict
class SemiSupervisedDataLoaderDict(TypedDict):
    """Return type when calling train/val/test_dataloader() on semi-supervised models.

    Holds the two loaders side by side: a PyTorch ``DataLoader`` for the
    labeled data and a DALI iterator for the unlabeled data.
    """

    labeled: torch.utils.data.DataLoader
    unlabeled: DALIGenericIterator