HeatmapTrackerMHCRNN

class lightning_pose.models.HeatmapTrackerMHCRNN

Bases: BaseSupervisedTracker

Multi-headed convolutional RNN that handles context frames.

Methods Summary

forward(images[, is_multiview])

Forward pass through the network.

get_loss_inputs_labeled(batch_dict)

Return predicted heatmaps and their softmaxes (estimated keypoints).

get_parameters()

predict_step(batch_dict, batch_idx[, ...])

Predict heatmaps and keypoints for a batch of video frames.

Methods Documentation

forward(images: TensorType["batch", "frames", "channels":3, "image_height", "image_width"] | TensorType["batch", "channels":3, "image_height", "image_width"] | TensorType["batch", "view", "frames", "channels":3, "image_height", "image_width"] | TensorType["batch", "view", "channels":3, "image_height", "image_width"], is_multiview: bool = False) -> tuple[TensorType["num_valid_outputs", "num_keypoints", "heatmap_height", "heatmap_width"], TensorType["num_valid_outputs", "num_keypoints", "heatmap_height", "heatmap_width"]]

Forward pass through the network.

Batch options

  • TensorType["batch", "frames", "channels":3, "image_height", "image_width"] single view, labeled context batch

  • TensorType["batch", "channels":3, "image_height", "image_width"] single view, unlabeled batch from DALI

  • TensorType["batch", "view", "frames", "channels":3, "image_height", "image_width"] multiview, labeled context batch

  • TensorType["batch", "view", "channels":3, "image_height", "image_width"] multiview, unlabeled batch from DALI
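The four layouts above differ in rank and in whether a view axis is present. Two of them are 5-D, so rank alone cannot tell them apart; the sketch below assumes that this is what the is_multiview flag resolves. The helper describe_batch is hypothetical and not part of lightning_pose; it works on a plain shape tuple such as tuple(images.shape).

```python
def describe_batch(shape, is_multiview=False):
    """Map an input shape to one of the four documented batch layouts.

    Hypothetical helper for illustration; not part of lightning_pose.
    """
    rank = len(shape)
    if rank == 4 and not is_multiview:
        # (batch, channels, height, width)
        return "single view, unlabeled batch from DALI"
    if rank == 5:
        # Rank 5 is ambiguous: the second axis holds either frames or views.
        if is_multiview:
            return "multiview, unlabeled batch from DALI"
        return "single view, labeled context batch"
    if rank == 6 and is_multiview:
        # (batch, view, frames, channels, height, width)
        return "multiview, labeled context batch"
    raise ValueError(f"unsupported shape {shape} with is_multiview={is_multiview}")


# Example shapes: batch of 8, 3-channel 256x256 images.
print(describe_batch((8, 5, 3, 256, 256)))                     # 5 context frames
print(describe_batch((8, 2, 3, 256, 256), is_multiview=True))  # 2 views
```

The same 5-D shape yields different descriptions depending on the flag, which is why callers should pass is_multiview explicitly for multiview data.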

get_loss_inputs_labeled(batch_dict: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict) -> dict

Return predicted heatmaps and their softmaxes (estimated keypoints).
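Turning a heatmap into a keypoint estimate through a softmax is commonly done with a soft-argmax: normalize the heatmap into a probability map, then take the expected pixel coordinate. The following is a generic pure-Python sketch of that idea for a single heatmap, not the library's implementation; the function name soft_argmax_2d is an assumption made for illustration.

```python
import math


def soft_argmax_2d(heatmap):
    """Return the (x, y) expectation of a 2D heatmap under a spatial softmax."""
    width = len(heatmap[0])
    flat = [value for row in heatmap for value in row]
    peak = max(flat)  # subtract the max for numerical stability
    weights = [math.exp(value - peak) for value in flat]
    total = sum(weights)
    # Expected column (x) and row (y) index under the softmax distribution.
    x = sum(w * (i % width) for i, w in enumerate(weights)) / total
    y = sum(w * (i // width) for i, w in enumerate(weights)) / total
    return x, y


# A heatmap whose mass is split evenly between columns 1 and 2 of row 1.
heatmap = [
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 20.0, 20.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
]
x, y = soft_argmax_2d(heatmap)
print(round(x, 3), round(y, 3))  # 1.5 1.0
```

Because the result is an expectation rather than an argmax, the estimate can fall between pixels (x = 1.5 here), which is one way to obtain subpixel precision from a coarse heatmap.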

get_parameters()

predict_step(batch_dict: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict | UnlabeledBatchDict, batch_idx: int, return_heatmaps: bool | None = False) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]

Predict heatmaps and keypoints for a batch of video frames.

Assuming a DALI video loader is passed in:

    trainer = Trainer(devices=8, accelerator="gpu")
    predictions = trainer.predict(model, data_loader)
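Lightning's trainer.predict returns a list with one entry per batch, and per the signature above each entry is a 2-tuple or a 3-tuple of tensors. A small sketch of regrouping those per-batch tuples into per-field lists follows. The helper split_predictions is hypothetical, and the placeholder strings stand in for tensors; the assumption that the third field holds heatmaps when return_heatmaps=True is inferred from the parameter name only.

```python
def split_predictions(batches):
    """Regroup per-batch output tuples into per-field lists.

    Hypothetical helper; `batches` stands in for the list returned by
    trainer.predict, with one 2- or 3-tuple per batch.
    """
    if not batches:
        return ()
    lengths = {len(batch) for batch in batches}
    if lengths not in ({2}, {3}):
        raise ValueError(f"expected all 2-tuples or all 3-tuples, got lengths {lengths}")
    # zip(*batches) transposes [(a0, b0), (a1, b1)] into [(a0, a1), (b0, b1)].
    return tuple(list(field) for field in zip(*batches))


# Placeholder strings stand in for tensors from two batches.
fake_predictions = [("first_0", "second_0"), ("first_1", "second_1")]
first, second = split_predictions(fake_predictions)
print(first)
print(second)
```

With real tensors, each per-field list could then be concatenated along the batch axis, for example with torch.cat.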

__init__(num_keypoints: int, num_targets: int | None = None, loss_factory: LossFactory | None = None, backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vits_dino', 'vits_dinov2', 'vits_dinov3', 'vitb_dino', 'vitb_dinov2', 'vitb_dinov3', 'vitb_imagenet', 'vitb_sam'] = 'resnet50', pretrained: bool = True, downsample_factor: Literal[1, 2, 3] = 2, torch_seed: int = 123, optimizer: str = 'Adam', optimizer_params: DictConfig | dict | None = None, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, **kwargs: Any)

Initialize a DLC-like model with a configurable backbone (default: resnet50).

Parameters:
  • num_keypoints – number of body parts

  • loss_factory – object to orchestrate loss computation

  • backbone – ResNet, EfficientNet, or ViT variant to be used; see the signature for the accepted names

  • pretrained – True to load pretrained ImageNet weights

  • downsample_factor – make heatmap smaller than original frames to save memory; subpixel operations are performed for increased precision

  • torch_seed – make weight initialization reproducible

  • lr_scheduler – how to schedule learning rate

  • lr_scheduler_params – parameters for the chosen learning rate scheduler; for multisteplr: milestones, gamma
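Two of the parameters above reduce to simple arithmetic, sketched here under stated assumptions. First, multisteplr is assumed to follow the semantics of torch.optim.lr_scheduler.MultiStepLR, which multiplies the learning rate by gamma at each milestone epoch. Second, each unit of downsample_factor is assumed to halve the heatmap height and width relative to the input frame. Both helper functions and all numeric values are illustrative, not taken from the library.

```python
from bisect import bisect_right


def multistep_lr(base_lr, epoch, milestones, gamma):
    """Learning rate at `epoch` under a multi-step decay schedule."""
    # Count the milestones already reached, then apply gamma once per milestone.
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


def heatmap_size(image_height, image_width, downsample_factor):
    """Heatmap (height, width) for a given input size.

    Assumption: each unit of downsample_factor halves both dimensions.
    """
    scale = 2 ** downsample_factor
    return image_height // scale, image_width // scale


# Illustrative values only; not defaults taken from the library.
lr_scheduler_params = {"milestones": [100, 200, 300], "gamma": 0.5}

for epoch in (0, 100, 250):
    print(epoch, multistep_lr(1e-3, epoch, **lr_scheduler_params))

print(heatmap_size(256, 256, downsample_factor=2))
```

Under these assumptions the learning rate halves at epochs 100, 200, and 300, and a 256x256 frame with downsample_factor=2 yields a 64x64 heatmap.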

__new__(**kwargs)