HeatmapTrackerMultiviewTransformer

class lightning_pose.models.HeatmapTrackerMultiviewTransformer

Bases: BaseSupervisedTracker

Transformer network that handles multi-view datasets.

Methods Summary

forward(batch_dict)

Forward pass through the network.

forward_vit(images)

Override forward pass through the vision encoder to add view embeddings.

get_loss_inputs_labeled(batch_dict)

Return predicted heatmaps and their softmaxes (estimated keypoints).

get_parameters()

predict_step(batch_dict, batch_idx[, ...])

Predict heatmaps and keypoints for a batch of video frames.

Methods Documentation

forward(batch_dict: MultiviewHeatmapLabeledBatchDict) → TensorType["num_valid_outputs", "num_keypoints", "heatmap_height", "heatmap_width"]

Forward pass through the network.

Batch options

  • TensorType["batch", "view", "channels":3, "image_height", "image_width"]: a multi-view labeled batch, or an unlabeled batch from DALI
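The batch above is 5-D, while forward_vit (documented below) takes a 4-D tensor whose leading axis is "view * batch". The NumPy sketch below shows one way to fold the view axis into the batch axis and undo it afterwards. It assumes view-major ordering (all frames of view 0, then all frames of view 1, and so on), which the "view * batch" label suggests but does not confirm; the helper names flatten_views and unflatten_views are hypothetical and not part of lightning_pose.

```python
import numpy as np


def flatten_views(images: np.ndarray) -> np.ndarray:
    """Collapse (batch, view, C, H, W) into (view * batch, C, H, W).

    Assumes view-major ordering: row v * batch + b holds view v of frame b.
    """
    batch, view, channels, height, width = images.shape
    # Move the view axis in front of the batch axis, then merge the two.
    return images.transpose(1, 0, 2, 3, 4).reshape(view * batch, channels, height, width)


def unflatten_views(flat: np.ndarray, num_views: int) -> np.ndarray:
    """Inverse of flatten_views: (view * batch, ...) -> (batch, view, ...)."""
    batch = flat.shape[0] // num_views
    # Split the leading axis into (view, batch), then swap back to (batch, view).
    return flat.reshape(num_views, batch, *flat.shape[1:]).swapaxes(0, 1)


images = np.random.rand(4, 2, 3, 8, 8)  # batch=4, view=2, RGB, 8x8 pixels
flat = flatten_views(images)
print(flat.shape)  # (8, 3, 8, 8)
```

In PyTorch the same round trip would use Tensor.permute and Tensor.reshape; the ordering choice matters because any per-view operation applied later must index rows the same way.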

forward_vit(images: TensorType["view * batch", "channels":3, "image_height", "image_width"])

Override forward pass through the vision encoder to add view embeddings.
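One common way to give a shared vision encoder knowledge of which camera an image came from is to add a per-view embedding vector to every patch token of that image. The NumPy sketch below illustrates that idea only; where the actual forward_vit adds its view embeddings (for example, relative to positional embeddings) and how they are parameterized is not stated here. The function name, the (view * batch, num_tokens, embed_dim) token layout, and view-major row ordering are all assumptions.

```python
import numpy as np


def add_view_embeddings(tokens: np.ndarray, view_embeddings: np.ndarray) -> np.ndarray:
    """Add one embedding vector per camera view to every token of that view.

    tokens:          (view * batch, num_tokens, embed_dim), view-major rows (assumed)
    view_embeddings: (num_views, embed_dim), e.g. learned parameters (assumed)
    """
    num_views, _ = view_embeddings.shape
    batch = tokens.shape[0] // num_views
    # Repeat each view's vector `batch` times so row i gets view i // batch.
    per_image = np.repeat(view_embeddings, batch, axis=0)  # (view * batch, embed_dim)
    # Broadcast the same vector across all tokens of an image.
    return tokens + per_image[:, None, :]


tokens = np.zeros((6, 5, 4))  # 3 views x 2 frames, 5 tokens, embed_dim 4
emb = np.arange(12, dtype=float).reshape(3, 4)
out = add_view_embeddings(tokens, emb)
```

Because the same vector is added to every token of an image, attention layers can distinguish views without changing the token count.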

get_loss_inputs_labeled(batch_dict: MultiviewHeatmapLabeledBatchDict) → dict

Return predicted heatmaps and their softmaxes (estimated keypoints).
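Turning a heatmap into a keypoint estimate via a softmax is typically done with a spatial softmax over the pixels followed by the expected (x, y) location, often called soft-argmax. The NumPy sketch below shows that technique for heatmaps shaped (N, num_keypoints, heatmap_height, heatmap_width). It is an illustration under that assumption, not lightning_pose's implementation, which may use a temperature, subpixel refinement, or a different library routine; spatial_softmax and soft_argmax are hypothetical names.

```python
import numpy as np


def spatial_softmax(heatmaps: np.ndarray) -> np.ndarray:
    """Softmax over the H * W pixels of each heatmap; input (N, K, H, W)."""
    n, k, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, k, h * w)
    flat = flat - flat.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    probs = np.exp(flat)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs.reshape(n, k, h, w)


def soft_argmax(heatmaps: np.ndarray) -> np.ndarray:
    """Expected (x, y) pixel location under the spatial softmax; returns (N, K, 2)."""
    probs = spatial_softmax(heatmaps)
    h, w = heatmaps.shape[-2:]
    # ys[i, j] = i (row index), xs[i, j] = j (column index)
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    x = (probs * xs).sum(axis=(-2, -1))
    y = (probs * ys).sum(axis=(-2, -1))
    return np.stack([x, y], axis=-1)


hm = np.zeros((1, 2, 8, 10))
hm[0, 0, 3, 7] = 50.0  # keypoint 0 peaks at row 3, column 7
hm[0, 1, 6, 1] = 50.0  # keypoint 1 peaks at row 6, column 1
keypoints = soft_argmax(hm)
```

Unlike a hard argmax, the expectation is differentiable and can land between pixels, which is one reason heatmap trackers can predict at subpixel precision on downsampled heatmaps.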

get_parameters()

predict_step(batch_dict: MultiviewHeatmapLabeledBatchDict | UnlabeledBatchDict, batch_idx: int, return_heatmaps: bool = False) → Tuple[Tensor, Tensor] | Tuple[Tensor, Tensor, Tensor]

Predict heatmaps and keypoints for a batch of video frames.

Example, assuming a DALI video loader is passed in:

    >>> trainer = Trainer(devices=8, accelerator="gpu")
    >>> predictions = trainer.predict(model, data_loader)

__init__(num_keypoints: int, num_views: int, loss_factory: LossFactory | None = None, backbone: Literal['vits_dino', 'vits_dinov2', 'vits_dinov3', 'vitb_dino', 'vitb_dinov2', 'vitb_dinov3', 'vitb_imagenet', 'vitb_sam'] = 'vits_dino', pretrained: bool = True, head: Literal['heatmap_cnn'] = 'heatmap_cnn', downsample_factor: Literal[1, 2, 3] = 2, torch_seed: int = 123, optimizer: str = 'Adam', optimizer_params: DictConfig | dict | None = None, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, image_size: int = 256, **kwargs: Any)

Initialize a multi-view model with a transformer backbone.

Parameters:
  • num_keypoints – number of body parts

  • num_views – number of camera views

  • loss_factory – object to orchestrate loss computation

  • backbone – transformer variant to use; convnet backbones cannot be used with this model

  • pretrained – True to load pretrained ImageNet weights

  • head – architecture used to project per-view information to 2D heatmaps; supported value: heatmap_cnn

  • downsample_factor – make heatmaps smaller than the original frames to save memory; subpixel operations are performed for increased precision

  • torch_seed – make weight initialization reproducible

  • lr_scheduler – how to schedule the learning rate

  • lr_scheduler_params – parameters for specific learning rate schedulers; multisteplr: milestones, gamma

  • image_size – size of input images (height = width for ViT models)

  • **kwargs – additional arguments
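Two of the constructor options reduce to simple arithmetic. The sketch below assumes that heatmaps are shrunk by a factor of 2 ** downsample_factor per side (so the default of 2 would give quarter-resolution heatmaps), and that "multisteplr" follows the semantics of PyTorch's torch.optim.lr_scheduler.MultiStepLR, which multiplies the learning rate by gamma once for each milestone epoch already reached. Both mappings are assumptions; heatmap_shape and multistep_lr are hypothetical helpers, and the milestone values are made up for illustration.

```python
def heatmap_shape(image_size: int, downsample_factor: int) -> tuple[int, int]:
    """Assumed heatmap (height, width) for square inputs of side image_size."""
    side = image_size // (2 ** downsample_factor)
    return side, side


def multistep_lr(base_lr: float, epoch: int, milestones: list[int], gamma: float) -> float:
    """Learning rate at `epoch` under MultiStepLR-style decay.

    The rate is multiplied by gamma once for every milestone <= epoch.
    """
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** num_decays


# Hypothetical settings, not library defaults.
milestones = [150, 200, 250]
for epoch in (0, 150, 260):
    print(epoch, multistep_lr(1e-3, epoch, milestones, gamma=0.5))
```

With these made-up values the rate halves at epochs 150, 200, and 250; the image_size=256 default with downsample_factor=2 would yield 64 x 64 heatmaps under the stated assumption.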

__new__(**kwargs)