HeatmapTrackerMultiviewTransformer
- class lightning_pose.models.HeatmapTrackerMultiviewTransformer
Bases: BaseSupervisedTracker
Transformer network that handles multi-view datasets.
Methods Summary
forward(batch_dict): Forward pass through the network.
forward_vit(images): Override forward pass through the vision encoder to add view embeddings.
get_loss_inputs_labeled(batch_dict): Return predicted heatmaps and their softmaxes (estimated keypoints).
get_parameters(): Return per-parameter-group optimizer configuration for backbone, head, and embeddings.
predict_step(batch_dict, batch_idx[, ...]): Predict heatmaps and keypoints for a batch of video frames.
Methods Documentation
- forward(batch_dict: MultiviewHeatmapLabeledBatchDict | UnlabeledBatchDict | MultiviewUnlabeledBatchDict) → Float[Tensor, 'num_valid_outputs num_keypoints heatmap_height heatmap_width']
Forward pass through the network.
Batch options:
Float[torch.Tensor, "batch view channels image_height image_width"]: multiview labeled batch or unlabeled batch from DALI
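`forward` receives multiview images shaped `batch view channels image_height image_width`, while `forward_vit` works on a flattened `view_x_batch channels image_height image_width` layout. A minimal NumPy sketch of that flattening and its inverse follows. The helper names `flatten_views` and `unflatten_views` are hypothetical, and the batch-major ordering (all views of the first frame, then all views of the second) is an assumption, not the library's confirmed convention.

```python
import numpy as np


def flatten_views(images: np.ndarray) -> np.ndarray:
    """Merge batch and view axes: (B, V, C, H, W) -> (B*V, C, H, W).

    Hypothetical helper; assumes batch-major ordering, so row b*V + v
    holds view v of frame b.
    """
    b, v, c, h, w = images.shape
    return images.reshape(b * v, c, h, w)


def unflatten_views(x: np.ndarray, num_views: int) -> np.ndarray:
    """Inverse of flatten_views: (B*V, ...) -> (B, V, ...)."""
    bv = x.shape[0]
    if bv % num_views != 0:
        raise ValueError("leading axis is not divisible by num_views")
    return x.reshape(bv // num_views, num_views, *x.shape[1:])


# usage: 2 frames, 3 camera views, RGB 8x8 images
images = np.random.rand(2, 3, 3, 8, 8)
flat = flatten_views(images)
print(flat.shape)  # (6, 3, 8, 8)
restored = unflatten_views(flat, num_views=3)
```

Keeping the two helpers as exact inverses means per-view outputs can be regrouped by frame after the encoder runs.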
- forward_vit(images: Float[Tensor, 'view_x_batch channels image_height image_width']) → Float[Tensor, 'view_x_batch embedding_dim height width']
Override forward pass through the vision encoder to add view embeddings.
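One common way to make a shared encoder view-aware is to add a learned vector per camera to every token of that camera's image, much like a positional embedding. The NumPy sketch below illustrates the idea. `add_view_embeddings`, the token layout `(view_x_batch, num_tokens, embedding_dim)`, and the batch-major view ordering are all assumptions; this does not describe how `forward_vit` is actually implemented.

```python
import numpy as np


def add_view_embeddings(tokens: np.ndarray, view_embed: np.ndarray) -> np.ndarray:
    """Add one embedding vector per camera view to every token of that view.

    tokens: (view_x_batch, num_tokens, embedding_dim), assumed batch-major so
        that row i comes from view i % num_views (hypothetical layout).
    view_embed: (num_views, embedding_dim), one (learned) vector per view.
    """
    num_views, dim = view_embed.shape
    vb, _, tdim = tokens.shape
    if tdim != dim or vb % num_views != 0:
        raise ValueError("shape mismatch between tokens and view embeddings")
    # view index of each row: [0, 1, ..., V-1, 0, 1, ..., V-1, ...]
    view_ids = np.tile(np.arange(num_views), vb // num_views)
    # (view_x_batch, 1, dim) broadcasts over the token axis
    return tokens + view_embed[view_ids][:, None, :]


# usage: 2 frames x 3 views, 4 tokens per image, embedding_dim 5
tokens = np.zeros((6, 4, 5))
view_embed = np.arange(3, dtype=float)[:, None] * np.ones((3, 5))
out = add_view_embeddings(tokens, view_embed)
```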
- get_loss_inputs_labeled(batch_dict: MultiviewHeatmapLabeledBatchDict) → dict
Return predicted heatmaps and their softmaxes (estimated keypoints).
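Keypoint estimates are commonly read out of heatmaps with a spatial softmax followed by a soft-argmax (the expected pixel location under the normalized heatmap), which also yields subpixel coordinates. The NumPy sketch below shows that generic technique. The function names are hypothetical, and it is not necessarily the exact normalization or scaling that `get_loss_inputs_labeled` uses.

```python
import numpy as np


def spatial_softmax(heatmaps: np.ndarray) -> np.ndarray:
    """Normalize each (H, W) heatmap into a probability map. Input: (N, K, H, W)."""
    n, k, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, k, h * w)
    flat = flat - flat.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    probs = np.exp(flat)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs.reshape(n, k, h, w)


def soft_argmax(probs: np.ndarray) -> np.ndarray:
    """Expected (x, y) location per keypoint, in heatmap pixel units. Output: (N, K, 2)."""
    _, _, h, w = probs.shape
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    x = (probs * xs).sum(axis=(-2, -1))
    y = (probs * ys).sum(axis=(-2, -1))
    return np.stack([x, y], axis=-1)


# usage: one keypoint whose heatmap peaks sharply at row 3, column 5
heatmaps = np.zeros((1, 1, 8, 8))
heatmaps[0, 0, 3, 5] = 100.0
probs = spatial_softmax(heatmaps)
coords = soft_argmax(probs)
```

Because the expectation is a weighted average over all pixels, a broad or bimodal heatmap pulls the estimate between peaks; sharper heatmaps behave more like a hard argmax.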
- get_parameters() → list[dict]
Return per-parameter-group optimizer configuration for backbone, head, and embeddings.
- Returns:
List of dicts with "params" and "name" keys; the backbone starts with learning rate 0 (frozen until unfreezing), and view embeddings are trained normally.
- predict_step(batch_dict: MultiviewHeatmapLabeledBatchDict | UnlabeledBatchDict, batch_idx: int, return_heatmaps: bool = False) → tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]
Predict heatmaps and keypoints for a batch of video frames.
Example, assuming a DALI video loader is passed in:
>>> from lightning.pytorch import Trainer
>>> trainer = Trainer(devices=8, accelerator="gpu")
>>> predictions = trainer.predict(model, data_loader)
- __init__(num_keypoints: int, num_views: int, loss_factory: LossFactory | None = None, backbone: Literal['vits_dino', 'vits_dinov2', 'vits_dinov3', 'vitb_dino', 'vitb_dinov2', 'vitb_dinov3', 'vitb_imagenet', 'vitb_sam'] = 'vits_dino', pretrained: bool = True, head: Literal['heatmap_cnn'] = 'heatmap_cnn', downsample_factor: Literal[1, 2, 3] = 2, torch_seed: int = 123, optimizer: str = 'Adam', optimizer_params: dict | DictConfig | ListConfig | None = None, lr_scheduler: str = 'multisteplr', lr_scheduler_params: dict | DictConfig | ListConfig | None = None, image_size: int = 256, **kwargs: Any) → None
Initialize a multi-view model with a transformer backbone.
- Parameters:
num_keypoints: number of body parts
num_views: number of camera views
loss_factory: object to orchestrate loss computation
backbone: transformer variant to be used; convnet backbones cannot be used with this model
pretrained: True to load pretrained ImageNet weights
head: architecture used to project per-view information to 2D heatmaps; options: heatmap_cnn
downsample_factor: make heatmaps smaller than the original frames to save memory; subpixel operations are performed for increased precision
torch_seed: make weight initialization reproducible
lr_scheduler: how to schedule the learning rate
lr_scheduler_params: params for specific learning rate schedulers; multisteplr: milestones, gamma
image_size: size of input images (height = width for ViT models)
**kwargs: additional arguments
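To make the memory trade-off of `downsample_factor` concrete, the sketch below assumes that each unit of the factor halves the spatial resolution, so the heatmap side is `image_size / 2**downsample_factor` (e.g. 64 x 64 heatmaps for the defaults `image_size=256`, `downsample_factor=2`). A heatmap coordinate then maps back to image pixels by multiplying by `2**downsample_factor`. Both helpers are hypothetical and ignore half-pixel offset conventions.

```python
def heatmap_side(image_size: int, downsample_factor: int) -> int:
    """Heatmap height/width, assuming each downsample step halves the resolution."""
    return image_size // (2 ** downsample_factor)


def heatmap_to_image_coords(x_hm: float, y_hm: float, downsample_factor: int) -> tuple[float, float]:
    """Scale (possibly subpixel) heatmap coordinates back to input-image pixels."""
    scale = 2 ** downsample_factor
    return x_hm * scale, y_hm * scale


# usage: heatmap size and per-keypoint pixel count for each allowed factor
for d in (1, 2, 3):
    side = heatmap_side(256, d)
    print(d, side, side * side)
```

Each extra unit of the factor cuts heatmap memory by 4x, which is why subpixel estimation matters more at larger factors: one heatmap pixel then spans several image pixels.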
- __new__(**kwargs)