HeatmapTrackerMHCRNN
- class lightning_pose.models.HeatmapTrackerMHCRNN
Bases: BaseSupervisedTracker
Multi-headed Convolutional RNN network that handles context frames.
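The forward signature below documents the labeled single-view input as batch × frames × channels × height × width, which is where the context frames come in. The hypothetical helper make_context_batch (not part of lightning_pose) is a minimal sketch of how such a tensor could be assembled from a decoded video. It gathers a window of frames around each target frame. Two choices are assumptions, not confirmed library behavior: the window size of two frames on each side, and repeating the first or last frame at the video edges.

```python
import torch


def make_context_batch(
    video: torch.Tensor, centers: list[int], num_context: int = 2
) -> torch.Tensor:
    """Stack a window of frames around each center index (hypothetical helper).

    video:   (num_frames, channels, height, width)
    returns: (batch, 2 * num_context + 1, channels, height, width)
    """
    num_frames = video.shape[0]
    # Offsets such as [-2, -1, 0, 1, 2]; the default window size is an assumption.
    offsets = torch.arange(-num_context, num_context + 1)
    # One row of frame indices per center; clamping repeats boundary frames (assumption).
    idx = (torch.tensor(centers).unsqueeze(1) + offsets).clamp(0, num_frames - 1)
    # Advanced indexing with a (batch, frames) index tensor keeps the trailing C, H, W dims.
    return video[idx]
```

For example, make_context_batch(video, [0, 5, 9]) on a 10-frame clip returns a (3, 5, C, H, W) tensor whose first window is frames 0, 0, 0, 1, 2.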
Methods Summary
forward(images[, is_multiview]) – Forward pass through the network.
get_loss_inputs_labeled(batch_dict) – Return predicted heatmaps and their softmaxes (estimated keypoints).
get_parameters() – Return per-parameter-group optimizer configuration for backbone and head.
predict_step(batch_dict, batch_idx[, ...]) – Predict heatmaps and keypoints for a batch of video frames.
Methods Documentation
- forward(images: Float[Tensor, 'batch frames channels image_height image_width'] | Float[Tensor, 'batch channels image_height image_width'] | Float[Tensor, 'batch view frames channels image_height image_width'] | Float[Tensor, 'batch view channels image_height image_width'], is_multiview: bool = False) tuple[Float[Tensor, 'num_valid_outputs num_keypoints heatmap_height heatmap_width'], Float[Tensor, 'num_valid_outputs num_keypoints heatmap_height heatmap_width']]
Forward pass through the network.
Batch options:
Float[torch.Tensor, "batch frames channels image_height image_width"]: single view, labeled context batch
Float[torch.Tensor, "batch channels image_height image_width"]: single view, unlabeled batch from DALI
Float[torch.Tensor, "batch view frames channels image_height image_width"]: multiview, labeled context batch
Float[torch.Tensor, "batch view channels image_height image_width"]: multiview, unlabeled batch from DALI
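Two of the four batch options above are 5-D: a single-view context batch and a multiview unlabeled batch. Tensor rank alone therefore cannot tell them apart. The hypothetical function describe_batch (not part of lightning_pose) is a minimal sketch that uses the is_multiview flag from the forward signature to resolve that ambiguity. How the model itself dispatches internally is not shown in this reference.

```python
import torch


def describe_batch(images: torch.Tensor, is_multiview: bool = False) -> str:
    """Map an input tensor to one of the four documented batch layouts (sketch)."""
    if is_multiview:
        # batch view frames C H W  vs.  batch view C H W
        layouts = {
            6: "multiview, labeled context batch",
            5: "multiview, unlabeled batch from DALI",
        }
    else:
        # batch frames C H W  vs.  batch C H W
        layouts = {
            5: "single view, labeled context batch",
            4: "single view, unlabeled batch from DALI",
        }
    try:
        return layouts[images.dim()]
    except KeyError:
        raise ValueError(
            f"unexpected shape {tuple(images.shape)} for is_multiview={is_multiview}"
        ) from None
```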
- get_loss_inputs_labeled(batch_dict: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict) dict
Return predicted heatmaps and their softmaxes (estimated keypoints).
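Turning heatmaps into "softmaxes" and then into keypoints is commonly done with a spatial softmax followed by a soft-argmax. The soft-argmax returns the expected (x, y) location under each normalized heatmap, which also yields sub-pixel estimates. The sketch below is a generic version of that technique. It is not lightning_pose's implementation. Temperature scaling, the coordinate convention, and the rescaling from heatmap pixels back to image pixels are assumptions left out here.

```python
import torch


def spatial_softmax(heatmaps: torch.Tensor) -> torch.Tensor:
    """Normalize each heatmap so its pixels sum to 1. heatmaps: (N, K, H, W)."""
    n, k, h, w = heatmaps.shape
    flat = torch.softmax(heatmaps.reshape(n, k, h * w), dim=-1)
    return flat.reshape(n, k, h, w)


def soft_argmax(probs: torch.Tensor) -> torch.Tensor:
    """Expected (x, y) heatmap-pixel location per keypoint -> (N, K, 2)."""
    _, _, h, w = probs.shape
    ys = torch.arange(h, dtype=probs.dtype, device=probs.device)
    xs = torch.arange(w, dtype=probs.dtype, device=probs.device)
    x = (probs.sum(dim=2) * xs).sum(dim=-1)  # marginalize rows, then E[x]
    y = (probs.sum(dim=3) * ys).sum(dim=-1)  # marginalize columns, then E[y]
    return torch.stack([x, y], dim=-1)
```

With two equally strong peaks at x = 1 and x = 3 on the same row, the soft-argmax lands at x = 2. That averaging behavior is what makes it sub-pixel, and it is also why multi-modal heatmaps need care.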
- get_parameters() list[dict]
Return per-parameter-group optimizer configuration for backbone and head.
- Returns:
List of dicts with "params", "name", and optionally "lr" keys; the backbone starts with a learning rate of 0 (frozen until it is unfrozen).
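PyTorch optimizers accept a list of such dicts as per-parameter groups. A group without an "lr" key falls back to the optimizer's default, and extra keys such as "name" are kept on the group. The sketch below illustrates the "frozen via learning rate 0" idea with a hypothetical TinyTracker module. That module is only a stand-in with a backbone and a head, not the real model.

```python
import torch
from torch import nn


class TinyTracker(nn.Module):
    """Hypothetical stand-in with a backbone and a head, for illustration only."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Conv2d(8, 4, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(torch.relu(self.backbone(x)))

    def get_parameters(self) -> list[dict]:
        # Backbone starts "frozen" with lr=0; the head has no "lr" key,
        # so it uses the optimizer's default learning rate.
        return [
            {"params": self.backbone.parameters(), "name": "backbone", "lr": 0.0},
            {"params": self.head.parameters(), "name": "head"},
        ]


torch.manual_seed(0)
model = TinyTracker()
optimizer = torch.optim.Adam(model.get_parameters(), lr=1e-3)
```

A later unfreezing step could assign a nonzero value to optimizer.param_groups[0]["lr"]. How and when lightning_pose does this is not described here.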
- predict_step(batch_dict: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict | UnlabeledBatchDict, batch_idx: int, return_heatmaps: bool | None = False) tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]
Predict heatmaps and keypoints for a batch of video frames.
Assuming a DALI video loader is passed in:
>>> trainer = Trainer(devices=8, accelerator="gpu")
>>> predictions = trainer.predict(model, data_loader)
- __init__(num_keypoints: int, num_targets: int | None = None, loss_factory: LossFactory | None = None, backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vits_dino', 'vits_dinov2', 'vits_dinov3', 'vitb_dino', 'vitb_dinov2', 'vitb_dinov3', 'vitb_imagenet', 'vitb_sam'] = 'resnet50', pretrained: bool = True, downsample_factor: Literal[1, 2, 3] = 2, torch_seed: int = 123, optimizer: str = 'Adam', optimizer_params: dict | DictConfig | ListConfig | None = None, lr_scheduler: str = 'multisteplr', lr_scheduler_params: dict | DictConfig | ListConfig | None = None, **kwargs: Any) None
Initialize a DLC-like model with a configurable backbone (ResNet-50 by default).
- Parameters:
num_keypoints – number of body parts
loss_factory – object to orchestrate loss computation
backbone – ResNet, EfficientNet, or ViT variant to be used
pretrained – True to load pretrained weights for the selected backbone (e.g. ImageNet weights for the standard ResNets)
downsample_factor – make heatmap smaller than original frames to save memory; subpixel operations are performed for increased precision
torch_seed – make weight initialization reproducible
lr_scheduler – how to schedule learning rate
lr_scheduler_params – params for specific learning rate schedulers; for multisteplr: milestones, gamma
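The multisteplr parameters milestones and gamma match the arguments of PyTorch's torch.optim.lr_scheduler.MultiStepLR. That scheduler multiplies the learning rate by gamma each time the epoch counter reaches a milestone. The sketch below assumes the model forwards these values to that scheduler. The milestone and gamma values are hypothetical, chosen only to make the decay easy to see, and are not the library's defaults.

```python
import torch
from torch import nn

# Hypothetical scheduler settings; not lightning_pose defaults.
milestones = [2, 4]
gamma = 0.1

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones, gamma=gamma
)

lrs = []
for epoch in range(6):
    lrs.append(optimizer.param_groups[0]["lr"])  # lr used during this epoch
    optimizer.step()  # a real training loop would compute a loss and backprop first
    scheduler.step()  # advance the epoch counter; decays lr at each milestone
```

Recorded per epoch, lrs steps down from 1e-3 to 1e-4 at epoch 2 and to 1e-5 at epoch 4.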
- __new__(**kwargs)