Source code for lightning_pose.data.cameras

"""Camera geometry utilities for multi-view 2D-to-3D projection and triangulation."""

from __future__ import annotations

import copy
import itertools
from pathlib import Path

import cv2
import numpy as np
import torch
from aniposelib.cameras import CameraGroup as CameraGroupAnipose
from jaxtyping import Float
from kornia.geometry.calibration import distort_points, undistort_points
from kornia.geometry.camera import PinholeCamera
from kornia.geometry.epipolar import triangulate_points

# to ignore imports for sphinx-autoapidoc
__all__ = [
    "project_camera_pairs_to_3d",
    "project_3d_to_2d",
    "CameraGroup",
]


def project_camera_pairs_to_3d(
    points: Float[torch.Tensor, "batch num_views num_keypoints 2"],
    intrinsics: Float[torch.Tensor, "batch num_views 3 3"],
    extrinsics: Float[torch.Tensor, "batch num_views 3 4"],
    dist: Float[torch.Tensor, "batch num_views num_params"],
) -> Float[torch.Tensor, "batch cam_pair num_keypoints 3"]:
    """Triangulate 2D keypoints into 3D for every unordered pair of camera views.

    The keypoints are first undistorted (with an identity matrix passed as ``new_K``),
    then each pair of views is triangulated using the extrinsic matrices as the
    projection matrices. Keypoints that contain a NaN in either view of a pair are
    not triangulated and stay NaN in the output for that pair.

    Args:
        points: 2D keypoints for each view
        intrinsics: Camera intrinsic matrices (3x3)
        extrinsics: Camera extrinsic matrices (3x4)
        dist: Camera distortion parameters

    Returns:
        3D points for each camera pair, ordered as
        ``itertools.combinations(range(num_views), 2)``

    """
    num_batch, num_views, num_keypoints, _ = points.shape

    identity_k = torch.eye(3, device=points.device).expand(num_batch, num_views, 3, 3)
    points = undistort_points(points=points, K=intrinsics, dist=dist, new_K=identity_k)

    pairwise_3d = []
    for view_a, view_b in itertools.combinations(range(num_views), 2):
        pts_a = points[:, view_a, ...]
        pts_b = points[:, view_b, ...]

        # a keypoint is usable only when neither view contains a NaN coordinate
        usable = ~torch.isnan(pts_a).any(dim=-1) & ~torch.isnan(pts_b).any(dim=-1)

        # start from all-NaN so that skipped keypoints remain NaN
        pair_3d = torch.full(
            (num_batch, num_keypoints, 3),
            float('nan'),
            device=points.device,
            dtype=points.dtype,
        )

        for batch_idx in range(num_batch):
            keep = usable[batch_idx]
            if not keep.any():
                # nothing to triangulate for this batch element
                continue
            # triangulate the usable keypoints and scatter them back into place
            pair_3d[batch_idx, keep] = triangulate_points(
                P1=extrinsics[batch_idx, view_a],
                P2=extrinsics[batch_idx, view_b],
                points1=pts_a[batch_idx][keep],
                points2=pts_b[batch_idx][keep],
            )

        pairwise_3d.append(pair_3d)

    return torch.stack(pairwise_3d, dim=1)
def project_3d_to_2d(
    points_3d: Float[torch.Tensor, "batch num_keypoints 3"],
    intrinsics: Float[torch.Tensor, "batch num_views 3 3"],
    extrinsics: Float[torch.Tensor, "batch num_views 3 4"],
    dist: Float[torch.Tensor, "batch num_views num_params"],
) -> Float[torch.Tensor, "batch num_views num_keypoints 2"]:
    """Project 3D keypoints into each camera view, applying lens distortion.

    Each view is handled in turn; within a view the whole batch and all keypoints
    are projected at once through a kornia ``PinholeCamera``. Distortion is applied
    only to batch elements whose distortion parameters are not all zero.

    Args:
        points_3d: 3D points in world coordinates
        intrinsics: Camera intrinsic matrices (3x3)
        extrinsics: Camera extrinsic matrices (3x4)
        dist: Camera distortion parameters

    Returns:
        2D projected points for each camera view

    """
    num_batch, num_keypoints, _ = points_3d.shape
    num_views = intrinsics.shape[1]
    tensor_opts = {"device": points_3d.device, "dtype": points_3d.dtype}

    # PinholeCamera takes 4x4 matrices: embed the 3x3 intrinsics in an identity
    K_4x4 = torch.zeros(num_batch, num_views, 4, 4, **tensor_opts)
    K_4x4[..., 3, 3] = 1.0
    K_4x4[:, :, :3, :3] = intrinsics

    # ... and append the homogeneous row [0, 0, 0, 1] to the 3x4 extrinsics
    E_4x4 = torch.zeros(num_batch, num_views, 4, 4, **tensor_opts)
    E_4x4[..., 3, 3] = 1.0
    E_4x4[:, :, :3, :4] = extrinsics

    # placeholder image size; PinholeCamera requires it
    height = torch.ones(num_batch, **tensor_opts)
    width = torch.ones(num_batch, **tensor_opts)

    points_2d = torch.full(
        (num_batch, num_views, num_keypoints, 2),
        float('nan'),
        **tensor_opts,
    )

    for view_idx in range(num_views):
        view_dist = dist[:, view_idx]  # (batch, num_params)

        # one camera per batch element for this view
        view_cameras = PinholeCamera(
            intrinsics=K_4x4[:, view_idx],
            extrinsics=E_4x4[:, view_idx],
            height=height,
            width=width,
        )
        ideal_points = view_cameras.project(points_3d)  # (batch, num_keypoints, 2)

        # (batch,) flag: True where any distortion parameter is non-zero
        needs_distortion = (view_dist != 0).any(dim=-1)
        if not needs_distortion.any():
            points_2d[:, view_idx] = ideal_points
            continue

        warped_points = distort_points(
            points=ideal_points,
            K=intrinsics[:, view_idx],
            dist=view_dist,
        )
        # keep the undistorted projection for batch elements with all-zero parameters
        points_2d[:, view_idx] = torch.where(
            needs_distortion[:, None, None],
            warped_points,
            ideal_points,
        )

    return points_2d
class CameraGroup(CameraGroupAnipose):
    """Anipose camera group extended with a non-jitted triangulation method for dataloaders."""

    def triangulate_fast(self, points: np.ndarray, undistort: bool = True) -> np.ndarray:
        """Triangulate points seen by all cameras by combining pairwise estimates.

        Every pair of cameras is triangulated with ``cv2.triangulatePoints`` and the
        per-pair 3D estimates are reduced with a NaN-ignoring median.

        Args:
            points: CxNx2 array, where C is the number of cameras and N is the number
                of points; a Cx2 array is treated as a single point.
            undistort: if True, pass each camera's points through that camera's
                ``undistort_points`` before triangulating.

        Returns:
            Nx3 array of 3D points, or a length-3 array when a single Cx2 point
            was given.

        """
        assert points.shape[0] == len(self.cameras), (
            f"Invalid points shape, first dim should be equal to"
            f" number of cameras ({len(self.cameras)}), but shape is {points.shape}"
        )

        # promote a lone (C, 2) point to (C, 1, 2) and remember to unwrap the result
        single_point = points.ndim == 2
        if single_point:
            points = points.reshape(-1, 1, 2)

        if undistort:
            undistorted = np.empty(points.shape)
            for cam_idx, cam in enumerate(self.cameras):
                # must copy in order to satisfy opencv underneath
                cam_points = np.copy(points[cam_idx])
                undistorted[cam_idx] = cam.undistort_points(cam_points)
            points = undistorted

        n_cams, _, _ = points.shape  # type: ignore[misc]
        cam_Rt_mats = np.array([cam.get_extrinsics_mat()[:3] for cam in self.cameras])

        pair_estimates = []
        for cam_a, cam_b in itertools.combinations(range(n_cams), 2):
            homogeneous = cv2.triangulatePoints(
                cam_Rt_mats[cam_a],
                cam_Rt_mats[cam_b],
                points[cam_a].T,
                points[cam_b].T,
            )
            # divide out the homogeneous coordinate and store as (N, 3)
            pair_estimates.append((homogeneous[:3] / homogeneous[3]).T)

        out = np.nanmedian(np.array(pair_estimates), axis=0)
        return out[0] if single_point else out

    def copy(self) -> CameraGroup:
        """Return a new CameraGroup built from copies of the cameras and metadata.

        Returns:
            A new ``CameraGroup`` whose cameras come from each camera's own ``copy()``
            and whose metadata is a shallow copy (``copy.copy``) of this group's metadata.

        """
        return CameraGroup(
            [cam.copy() for cam in self.cameras],
            copy.copy(self.metadata),
        )

    def copy_with_new_cameras(self, cameras: list) -> CameraGroup:
        """Return a deep copy of this group whose ``cameras`` attribute is replaced.

        Args:
            cameras: the cameras to install on the copied group.

        Returns:
            A deep copy of this ``CameraGroup`` with ``cameras`` set to the given list.

        """
        duplicate = copy.deepcopy(self)
        duplicate.cameras = cameras
        return duplicate

    @classmethod
    def load(cls, path: str | Path) -> CameraGroup:
        """Load a CameraGroup from a file.

        Args:
            path: path to the serialized camera group file.

        Returns:
            An instance of ``cls`` constructed from the attributes of the group
            returned by the parent class loader.

        """
        # let the parent class do the loading, then re-wrap its attributes in ``cls``
        base_group = super().load(path)  # type: ignore[arg-type]
        return cls(**vars(base_group))