SemiSupervisedHeatmapTracker
- class lightning_pose.models.heatmap_tracker.SemiSupervisedHeatmapTracker(num_keypoints: int, loss_factory: LossFactory | None = None, loss_factory_unsupervised: LossFactory | None = None, backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_contrastive', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vit_b_sam'] = 'resnet50', downsample_factor: Literal[1, 2, 3] = 2, pretrained: bool = True, output_shape: tuple | None = None, torch_seed: int = 123, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, **kwargs: Any)
Bases: SemiSupervisedTrackerMixin, HeatmapTracker
Model that produces heatmaps of keypoints from labeled and unlabeled images.
Methods Summary
get_loss_inputs_unlabeled(batch_dict) – Return predicted heatmaps and their softmaxes (estimated keypoints).
Methods Documentation
- get_loss_inputs_unlabeled(batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict) → dict
Return predicted heatmaps and their softmaxes (estimated keypoints).
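To illustrate how keypoints can be estimated from heatmap softmaxes, the plain-Python sketch below applies a spatial softmax to a single heatmap and takes the soft argmax (the expected pixel location under the normalized weights). It is not lightning_pose's implementation: the helper names `soft_argmax_2d` and `heatmap_to_image_coords`, the `temperature` parameter, and the assumption that heatmaps are `2 ** downsample_factor` times smaller than the input frame are all illustrative assumptions.

```python
import math


def soft_argmax_2d(heatmap, temperature=1.0):
    """Return the (x, y) expected location under a spatial softmax of `heatmap`.

    `heatmap` is a list of rows (row index = y, column index = x).
    Hypothetical helper for illustration only.
    """
    # Subtract the global max before exponentiating for numerical stability.
    peak = max(v for row in heatmap for v in row)
    weights = [[math.exp((v - peak) / temperature) for v in row] for row in heatmap]
    total = sum(sum(row) for row in weights)
    # Expected column (x) and row (y) index under the normalized weights.
    x = sum(w * c for row in weights for c, w in enumerate(row)) / total
    y = sum(w * r for r, row in enumerate(weights) for w in row) / total
    return x, y


def heatmap_to_image_coords(x, y, downsample_factor):
    """Scale heatmap coordinates back to frame pixels.

    Assumption: the heatmap is 2 ** downsample_factor times smaller than the
    frame; this simple scaling ignores pixel-center offsets.
    """
    scale = 2 ** downsample_factor
    return x * scale, y * scale


if __name__ == "__main__":
    # A 5x5 heatmap with a sharp peak at row 3, column 1.
    hm = [[0.0] * 5 for _ in range(5)]
    hm[3][1] = 50.0
    x, y = soft_argmax_2d(hm)
    print(heatmap_to_image_coords(x, y, downsample_factor=2))
```

Because the soft argmax is a weighted average rather than a hard `argmax`, it is differentiable and can land between heatmap pixels, which is one common way to obtain subpixel keypoint estimates from a downsampled heatmap.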
- __init__(num_keypoints: int, loss_factory: LossFactory | None = None, loss_factory_unsupervised: LossFactory | None = None, backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_contrastive', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vit_b_sam'] = 'resnet50', downsample_factor: Literal[1, 2, 3] = 2, pretrained: bool = True, output_shape: tuple | None = None, torch_seed: int = 123, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, **kwargs: Any) → None
- Parameters:
num_keypoints – number of body parts
loss_factory – object to orchestrate supervised loss computation
loss_factory_unsupervised – object to orchestrate unsupervised loss computation
backbone – ResNet, EfficientNet, or ViT (vit_b_sam) variant to use as the feature extractor
downsample_factor – make heatmaps smaller than the original frames to save memory; subpixel operations are performed to retain precision
pretrained – True to load pretrained ImageNet weights
output_shape – hard-coded image size to avoid dynamic shape computations
torch_seed – make weight initialization reproducible
lr_scheduler – how to schedule the learning rate; options: multisteplr
lr_scheduler_params – parameters for the selected learning rate scheduler; multisteplr: milestones, gamma
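The name multisteplr, together with its milestones and gamma parameters, suggests a schedule equivalent to PyTorch's `torch.optim.lr_scheduler.MultiStepLR`, which multiplies the learning rate by `gamma` each time the epoch count reaches a milestone. The pure-Python sketch below computes that schedule directly. The function `multistep_lr`, the nesting of `lr_scheduler_params` by scheduler name, and the milestone and gamma values are illustrative assumptions, not lightning_pose defaults.

```python
def multistep_lr(base_lr, epoch, milestones, gamma):
    """Learning rate at `epoch` under a MultiStepLR-style schedule.

    Hypothetical helper mirroring the documented behavior of
    torch.optim.lr_scheduler.MultiStepLR.
    """
    # Count how many milestones have already been reached at this epoch.
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** num_decays


# Illustrative config shape (assumed nesting; values are arbitrary examples).
lr_scheduler = "multisteplr"
lr_scheduler_params = {"multisteplr": {"milestones": [150, 200, 250], "gamma": 0.5}}

if __name__ == "__main__":
    params = lr_scheduler_params[lr_scheduler]
    for epoch in (0, 149, 150, 200, 300):
        lr = multistep_lr(1e-3, epoch, params["milestones"], params["gamma"])
        print(epoch, lr)
```

With these example values the learning rate stays at its base value until epoch 150, then halves at each of the three milestones.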