SemiSupervisedHeatmapTrackerMultiviewTransformer

class lightning_pose.models.SemiSupervisedHeatmapTrackerMultiviewTransformer

Bases: SemiSupervisedTrackerMixin, HeatmapTrackerMultiviewTransformer

Semi-supervised HeatmapTrackerMultiviewTransformer that supports unsupervised losses.

Methods Summary

get_loss_inputs_unlabeled(batch_dict)

Return predicted heatmaps and keypoints for unlabeled data (required by SemiSupervisedTrackerMixin).

Methods Documentation

get_loss_inputs_unlabeled(batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict) → dict

Return predicted heatmaps and keypoints for unlabeled data (required by SemiSupervisedTrackerMixin).
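
The mixin relationship described above can be illustrated with a minimal, library-free sketch. The classes below are toy stand-ins, not the lightning_pose implementations, and the dictionary keys ("frames", "heatmaps", "keypoints") and the jump-based loss are assumptions chosen only for illustration. The point is the pattern: the mixin calls a hook, and the concrete model supplies it.

```python
# Toy stand-ins only: these are NOT the lightning_pose classes.


class ToySemiSupervisedMixin:
    """Stand-in for a mixin that computes a loss on unlabeled data."""

    def get_loss_inputs_unlabeled(self, batch_dict: dict) -> dict:
        # The mixin declares the hook; concrete models must override it.
        raise NotImplementedError

    def unlabeled_step(self, batch_dict: dict) -> float:
        # The mixin relies only on the dict returned by the hook.
        inputs = self.get_loss_inputs_unlabeled(batch_dict)
        keypoints = inputs["keypoints"]  # hypothetical key
        # Hypothetical unsupervised loss: mean frame-to-frame jump.
        jumps = [abs(b - a) for a, b in zip(keypoints, keypoints[1:])]
        return sum(jumps) / len(jumps)


class ToyHeatmapTracker:
    """Stand-in for a supervised heatmap model."""

    def predict(self, frames):
        # Treat each frame as a 1D "heatmap"; the keypoint is its argmax.
        heatmaps = [list(frame) for frame in frames]
        keypoints = [float(h.index(max(h))) for h in heatmaps]
        return heatmaps, keypoints


class ToySemiSupervisedTracker(ToySemiSupervisedMixin, ToyHeatmapTracker):
    """Combines both bases and implements the hook the mixin requires."""

    def get_loss_inputs_unlabeled(self, batch_dict: dict) -> dict:
        heatmaps, keypoints = self.predict(batch_dict["frames"])
        return {"heatmaps": heatmaps, "keypoints": keypoints}


model = ToySemiSupervisedTracker()
batch = {"frames": [[0.1, 0.9, 0.0], [0.0, 0.2, 0.8], [0.0, 0.1, 0.9]]}
loss = model.unlabeled_step(batch)
```

Listing the mixin first in the bases mirrors the order shown in "Bases" above, so the mixin's methods take precedence in the method resolution order while the subclass override of the hook still wins.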

__init__(num_keypoints: int, num_views: int, loss_factory: LossFactory | None = None, loss_factory_unsupervised: LossFactory | None = None, backbone: Literal['vits_dino', 'vits_dinov2', 'vits_dinov3', 'vitb_dino', 'vitb_dinov2', 'vitb_dinov3', 'vitb_imagenet', 'vitb_sam'] = 'vits_dino', pretrained: bool = True, head: Literal['heatmap_cnn'] = 'heatmap_cnn', downsample_factor: Literal[1, 2, 3] = 2, torch_seed: int = 123, optimizer: str = 'Adam', optimizer_params: DictConfig | dict | None = None, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, image_size: int = 256, **kwargs: Any)

Initialize a semi-supervised multi-view model with transformer backbone.

Parameters:
  • num_keypoints – number of body parts

  • num_views – number of camera views

  • loss_factory – object to orchestrate supervised loss computation

  • loss_factory_unsupervised – object to orchestrate unsupervised loss computation

  • backbone – transformer variant to be used; cannot use convnets with this model

  • pretrained – True to load the pretrained weights for the selected backbone

  • head – architecture used to project per-view information to 2D heatmaps

  • downsample_factor – make heatmap smaller than original frames to save memory

  • torch_seed – make weight initialization reproducible

  • optimizer – name of the optimizer to use (default 'Adam')

  • optimizer_params – params for the chosen optimizer

  • lr_scheduler – how to schedule learning rate

  • lr_scheduler_params – params for specific learning rate schedulers

  • image_size – size of input images (height=width for ViT models)
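
As a rough sketch of how the constructor arguments fit together, the snippet below collects keyword arguments and checks them against the allowed values listed in the signature. The helper build_model_kwargs is hypothetical (it is not part of lightning_pose), and the example values for num_keypoints and num_views are arbitrary. The final constructor call is left as a comment because it needs lightning_pose installed, may download pretrained weights, and may need loss factories that are not shown here.

```python
from typing import Any

# Allowed values copied from the __init__ signature above.
ALLOWED_BACKBONES = (
    "vits_dino", "vits_dinov2", "vits_dinov3",
    "vitb_dino", "vitb_dinov2", "vitb_dinov3",
    "vitb_imagenet", "vitb_sam",
)
ALLOWED_DOWNSAMPLE_FACTORS = (1, 2, 3)


def build_model_kwargs(
    num_keypoints: int,
    num_views: int,
    backbone: str = "vits_dino",
    downsample_factor: int = 2,
    image_size: int = 256,
) -> dict[str, Any]:
    """Hypothetical helper: validate arguments before building the model."""
    if backbone not in ALLOWED_BACKBONES:
        # Convnet backbones are not accepted by this model.
        raise ValueError(f"unsupported backbone: {backbone!r}")
    if downsample_factor not in ALLOWED_DOWNSAMPLE_FACTORS:
        raise ValueError(f"unsupported downsample_factor: {downsample_factor!r}")
    if num_keypoints < 1 or num_views < 1:
        raise ValueError("num_keypoints and num_views must be positive")
    return {
        "num_keypoints": num_keypoints,
        "num_views": num_views,
        "backbone": backbone,
        "downsample_factor": downsample_factor,
        "image_size": image_size,  # height == width for ViT models
    }


kwargs = build_model_kwargs(num_keypoints=17, num_views=2)

# With lightning_pose installed, the model could then be built roughly as:
# from lightning_pose.models import SemiSupervisedHeatmapTrackerMultiviewTransformer
# model = SemiSupervisedHeatmapTrackerMultiviewTransformer(**kwargs)
```

Validating up front gives a clear error for an unsupported backbone name, such as a convnet, before any weights are loaded.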

__new__(**kwargs)