RegressionTracker
- class lightning_pose.models.RegressionTracker
Bases: BaseSupervisedTracker
Base model that produces (x, y) predictions of keypoints from images.
Methods Summary
forward(images): Forward pass through the network.
get_loss_inputs_labeled(batch_dict): Return predicted coordinates for a batch of data.
predict_step(batch_dict, batch_idx, **kwargs): Predict keypoints for a batch of video frames.
Methods Documentation
- forward(images: TensorType['batch', 'channels': 3, 'image_height', 'image_width']) → TensorType['batch', 'two_x_num_keypoints']
Forward pass through the network.
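The output shape ('batch', 'two_x_num_keypoints') packs both coordinates of every keypoint into one flat vector per image. The sketch below splits such a vector into (x, y) pairs, using plain Python lists in place of tensors. It assumes the coordinates are interleaved as x0, y0, x1, y1, ...; this page does not state the ordering, so treat that as an assumption. The helpers `split_xy` and `split_batch` are hypothetical.

```python
from typing import List, Tuple


def split_xy(flat: List[float]) -> List[Tuple[float, float]]:
    """Split a flat [x0, y0, x1, y1, ...] vector into (x, y) pairs.

    Assumes interleaved ordering (an assumption, not documented here).
    """
    if len(flat) % 2 != 0:
        raise ValueError("expected an even number of values (2 * num_keypoints)")
    # Under the interleaving assumption, even indices hold x and odd indices hold y.
    return list(zip(flat[0::2], flat[1::2]))


def split_batch(batch: List[List[float]]) -> List[List[Tuple[float, float]]]:
    """Apply split_xy to every row of a (batch, 2 * num_keypoints) output."""
    return [split_xy(row) for row in batch]


# Two images, three keypoints each.
outputs = [
    [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
    [11.0, 21.0, 31.0, 41.0, 51.0, 61.0],
]
pairs = split_batch(outputs)
print(pairs[0])  # [(10.0, 20.0), (30.0, 40.0), (50.0, 60.0)]
```

With a real tensor and the same interleaved layout, the equivalent operation is `outputs.reshape(batch_size, -1, 2)`.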
- get_loss_inputs_labeled(batch_dict: BaseLabeledBatchDict) → dict
Return predicted coordinates for a batch of data.
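A supervised loss compares predicted coordinates like these against labeled targets. As a hedged illustration of that comparison, the sketch below computes the per-keypoint Euclidean error and the mean squared coordinate error between predicted and target (x, y) pairs. The dictionary keys `keypoints_pred` and `keypoints_targ` and both helper functions are hypothetical; the losses actually used are whatever `loss_factory` configures.

```python
import math
from typing import Dict, List, Tuple

Point = Tuple[float, float]


def per_keypoint_error(loss_inputs: Dict[str, List[Point]]) -> List[float]:
    """Euclidean distance between each predicted and target keypoint.

    The keys 'keypoints_pred' and 'keypoints_targ' are hypothetical.
    """
    pred = loss_inputs["keypoints_pred"]
    targ = loss_inputs["keypoints_targ"]
    if len(pred) != len(targ):
        raise ValueError("prediction and target must have the same number of keypoints")
    return [math.hypot(px - tx, py - ty) for (px, py), (tx, ty) in zip(pred, targ)]


def mean_squared_error(loss_inputs: Dict[str, List[Point]]) -> float:
    """Mean of squared differences over all x and y coordinates."""
    errors = per_keypoint_error(loss_inputs)
    # Each squared distance already sums the x and y terms, so divide by 2 * K.
    return sum(e * e for e in errors) / (2 * len(errors))


example = {
    "keypoints_pred": [(0.0, 0.0), (3.0, 4.0)],
    "keypoints_targ": [(0.0, 0.0), (0.0, 0.0)],
}
print(per_keypoint_error(example))  # [0.0, 5.0]
print(mean_squared_error(example))  # 6.25
```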
- predict_step(batch_dict: BaseLabeledBatchDict | UnlabeledBatchDict, batch_idx: int, **kwargs: Any) → Tuple[Tensor, Tensor]
Predict keypoints for a batch of video frames.
Assuming a DALI video loader is passed in:
>>> trainer = Trainer(devices=8, accelerator="gpu")
>>> predictions = trainer.predict(model, data_loader)
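With a single data loader, Lightning's `trainer.predict` typically returns a list holding one `predict_step` result per batch, so `predictions` is a list of `(Tensor, Tensor)` tuples. The sketch below concatenates those per-batch results along the frame axis, using nested lists in place of tensors. Reading the second tensor as per-keypoint confidences is an assumption (the signature only says `Tuple[Tensor, Tensor]`), and `collate_predictions` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

Row = List[float]
# (keypoint rows, second-output rows), one row per frame in the batch.
BatchOutput = Tuple[List[Row], List[Row]]


def collate_predictions(batches: Sequence[BatchOutput]) -> Tuple[List[Row], List[Row]]:
    """Concatenate per-batch (keypoints, confidences) outputs along the frame axis.

    Mirrors torch.cat(..., dim=0) on real tensors. Treating the second
    element as confidences is an assumption.
    """
    keypoints: List[Row] = []
    confidences: List[Row] = []
    for batch_keypoints, batch_confidences in batches:
        if len(batch_keypoints) != len(batch_confidences):
            raise ValueError("both outputs must have one row per frame")
        keypoints.extend(batch_keypoints)
        confidences.extend(batch_confidences)
    return keypoints, confidences


# Two batches of frames, shaped like a list returned by trainer.predict.
predictions = [
    ([[1.0, 2.0], [3.0, 4.0]], [[0.9], [0.8]]),
    ([[5.0, 6.0]], [[0.7]]),
]
kps, confs = collate_predictions(predictions)
print(len(kps))  # 3
```

On real outputs, `torch.cat([p[0] for p in predictions], dim=0)` performs the same concatenation for the first tensor.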
- __init__(num_keypoints: int, loss_factory: LossFactory | None = None, backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_contrastive', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vits_dino', 'vits_dinov2', 'vits_dinov3', 'vitb_dino', 'vitb_dinov2', 'vitb_dinov3', 'vitb_imagenet', 'vitb_sam'] = 'resnet50', pretrained: bool = True, torch_seed: int = 123, optimizer: str = 'Adam', optimizer_params: DictConfig | dict | None = None, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, **kwargs: Any) → None
Base model that produces (x, y) coordinates of keypoints from images.
- Parameters:
num_keypoints – number of body parts
loss_factory – object to orchestrate loss computation
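backbone – ResNet, EfficientNet, or ViT variant to be used
pretrained – True to load pretrained weights for the chosen backbone (e.g. ImageNet weights)
torch_seed – random seed that makes weight initialization reproducible
optimizer – name of the optimizer (default 'Adam')
optimizer_params – optional optimizer parameters, as a DictConfig or dict
lr_scheduler – name of the learning rate scheduler (default 'multisteplr')
lr_scheduler_params – parameters for the chosen scheduler; for multisteplr: milestones, gamma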
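The multisteplr schedule is configured by milestones and gamma. Assuming it follows the semantics of PyTorch's `torch.optim.lr_scheduler.MultiStepLR` (the learning rate is multiplied by gamma each time the epoch count reaches a milestone), the rate at any epoch has a closed form. The function `multistep_lr` and the milestone and gamma values below are hypothetical illustrations, not lightning_pose defaults.

```python
from bisect import bisect_right
from typing import Sequence


def multistep_lr(base_lr: float, milestones: Sequence[int], gamma: float, epoch: int) -> float:
    """Learning rate at `epoch` under MultiStepLR-style step decay.

    The rate is base_lr * gamma ** k, where k counts milestones <= epoch,
    so the decay takes effect at the milestone epoch itself.
    """
    # Sort so bisect_right can count how many milestones have been reached.
    k = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** k


# Hypothetical schedule: halve the rate at epochs 150, 200, and 250.
for epoch in (0, 149, 150, 199, 200, 250):
    print(epoch, multistep_lr(1.0, [150, 200, 250], 0.5, epoch))
# 0 1.0
# 149 1.0
# 150 0.5
# 199 0.5
# 200 0.25
# 250 0.125
```

Counting reached milestones with `bisect_right` avoids simulating the schedule epoch by epoch.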
- __new__(**kwargs)