SemiSupervisedRegressionTracker

class lightning_pose.models.SemiSupervisedRegressionTracker

Bases: SemiSupervisedTrackerMixin, RegressionTracker
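The base classes list the mixin before RegressionTracker. In Python's method resolution order (MRO), that placement means the mixin's methods are found first, and the mixin can extend RegressionTracker by calling super(). The classes below are hypothetical stand-ins, not lightning_pose code. This section does not document which methods the real mixin overrides.

```python
class BaseTracker:
    """Hypothetical stand-in for a supervised base model."""

    def loss_terms(self) -> list:
        return ["supervised"]


class RegressionTrackerStub(BaseTracker):
    """Hypothetical stand-in for RegressionTracker."""

    def loss_terms(self) -> list:
        return super().loss_terms() + ["regression"]


class SemiSupervisedMixinStub:
    """Hypothetical stand-in for SemiSupervisedTrackerMixin.

    It calls super(), so it extends the behavior of whichever class
    follows it in the MRO instead of replacing it.
    """

    def loss_terms(self) -> list:
        return super().loss_terms() + ["unsupervised"]


class SemiSupervisedRegressionStub(SemiSupervisedMixinStub, RegressionTrackerStub):
    """Mixin listed first, mirroring the documented base-class order."""


model = SemiSupervisedRegressionStub()
print([cls.__name__ for cls in type(model).__mro__])
print(model.loss_terms())  # supervised terms first, unsupervised term appended by the mixin
```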

Semi-supervised regression model that predicts a vector of keypoint coordinates directly from each labeled or unlabeled image.
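A regression tracker outputs coordinates directly instead of heatmaps. The sketch below shows one way to turn a flat prediction vector back into per-keypoint (x, y) pairs. It assumes the vector has length 2 * num_keypoints with interleaved ordering x0, y0, x1, y1, ...; that layout is an assumption, so check it against the model's actual output. The helper name unflatten_keypoints is hypothetical.

```python
from typing import List, Sequence, Tuple


def unflatten_keypoints(vector: Sequence[float], num_keypoints: int) -> List[Tuple[float, float]]:
    """Split a flat prediction vector into (x, y) pairs.

    ASSUMPTION: the vector holds 2 * num_keypoints values ordered
    x0, y0, x1, y1, ...
    """
    if len(vector) != 2 * num_keypoints:
        raise ValueError(f"expected {2 * num_keypoints} values, got {len(vector)}")
    # Even indices hold x coordinates, odd indices hold y coordinates.
    return list(zip(vector[0::2], vector[1::2]))


pred = [10.0, 20.0, 30.5, 40.5, 55.0, 60.0]  # one frame, three keypoints
print(unflatten_keypoints(pred, num_keypoints=3))
```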

Methods Summary

get_loss_inputs_unlabeled(batch_dict)

Return the predicted keypoints for a batch of unlabeled frames, for use as inputs to the unsupervised losses.

Methods Documentation

get_loss_inputs_unlabeled(batch_dict: UnlabeledBatchDict) → dict

Return the predicted keypoints for a batch of unlabeled frames, for use as inputs to the unsupervised losses.
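Unlabeled frames have no ground truth, so the losses that consume these predictions can only compare predictions with one another. A temporal-smoothness penalty is one example of such a loss. The sketch below is a hypothetical illustration, not lightning_pose's loss implementation. The function name temporal_penalty and the hinge threshold epsilon are assumptions. Jumps smaller than epsilon pixels between consecutive frames cost nothing, and larger jumps are penalized linearly.

```python
import math
from typing import List, Tuple

Frame = List[Tuple[float, float]]  # (x, y) for each keypoint in one frame


def temporal_penalty(frames: List[Frame], epsilon: float = 0.0) -> float:
    """Mean hinge penalty on keypoint jumps between consecutive frames.

    Hypothetical illustration: only predictions are needed, no labels.
    """
    penalties = []
    for prev, curr in zip(frames[:-1], frames[1:]):
        for (x0, y0), (x1, y1) in zip(prev, curr):
            jump = math.hypot(x1 - x0, y1 - y0)  # Euclidean displacement in pixels
            penalties.append(max(0.0, jump - epsilon))
    return sum(penalties) / len(penalties) if penalties else 0.0


# Three unlabeled frames with two keypoints each; keypoint 1 jumps 5 px at the end.
frames = [
    [(0.0, 0.0), (10.0, 10.0)],
    [(0.0, 0.0), (10.0, 10.0)],
    [(0.0, 0.0), (13.0, 14.0)],
]
print(temporal_penalty(frames, epsilon=1.0))
```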

__init__(num_keypoints: int, loss_factory: LossFactory | None = None, loss_factory_unsupervised: LossFactory | None = None, backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_contrastive', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vits_dino', 'vitb_dino', 'vitb_imagenet', 'vitb_sam'] = 'resnet50', pretrained: bool = True, torch_seed: int = 123, optimizer: str = 'Adam', optimizer_params: DictConfig | dict | None = None, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, **kwargs: Any) → None
Parameters:
  • num_keypoints – number of body parts (keypoints) to predict

  • loss_factory – object that orchestrates supervised loss computation

  • loss_factory_unsupervised – object that orchestrates unsupervised loss computation

  • backbone – ResNet, EfficientNet, or ViT variant used as the feature extractor

  • pretrained – True to load pretrained backbone weights (e.g. ImageNet weights)

  • torch_seed – random seed that makes weight initialization reproducible

  • optimizer – name of the optimizer (default 'Adam')

  • optimizer_params – parameters for the optimizer

  • lr_scheduler – learning rate scheduler to use, e.g. 'multisteplr' (the default)

  • lr_scheduler_params – parameters for the chosen scheduler; for 'multisteplr': milestones, gamma

  • do_context – use temporal context frames to improve predictions (not in the explicit signature, so it is accepted through **kwargs)
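The 'multisteplr' parameters, milestones and gamma, can be read as a step-decay rule. The sketch below assumes 'multisteplr' follows the semantics of PyTorch's torch.optim.lr_scheduler.MultiStepLR. Under that rule the learning rate is multiplied by gamma once for every milestone the epoch has reached, so lr = base_lr * gamma ** (number of milestones <= epoch). The function name multistep_lr and the milestone and gamma values in the usage lines are illustrative only, not library defaults.

```python
from bisect import bisect_right
from typing import Sequence


def multistep_lr(base_lr: float, epoch: int, milestones: Sequence[int], gamma: float) -> float:
    """Learning rate at `epoch` under multi-step decay (hypothetical helper)."""
    # bisect_right counts milestones <= epoch, so the decay applies
    # starting at the milestone epoch itself.
    passed = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** passed


# Illustrative values only: lr_scheduler_params = {"milestones": [...], "gamma": ...}
params = {"milestones": [150, 200, 250], "gamma": 0.5}
for epoch in (0, 149, 150, 200, 300):
    print(epoch, multistep_lr(1e-3, epoch, **params))
```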
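With do_context enabled, each prediction draws on neighboring frames as well as the target frame. The sketch below shows one way to select those neighbors. It assumes a window of `radius` frames on each side (two by default) and clamps indices at the start and end of the video by repeating the edge frame. The window size, the edge handling, and the helper name context_indices are all assumptions, not documented behavior of this class.

```python
from typing import List


def context_indices(t: int, num_frames: int, radius: int = 2) -> List[int]:
    """Frame indices forming a temporal context window around frame t.

    ASSUMPTIONS: `radius` frames on each side; indices past either end of
    the video are clamped to the first or last frame.
    """
    if not 0 <= t < num_frames:
        raise IndexError(f"frame {t} outside video of length {num_frames}")
    return [min(max(i, 0), num_frames - 1) for i in range(t - radius, t + radius + 1)]


print(context_indices(10, num_frames=100))  # interior frame
print(context_indices(0, num_frames=100))   # start of video: edge frame repeated
```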

__new__(**kwargs)