BaseSupervisedTracker

class lightning_pose.models.base.BaseSupervisedTracker

Bases: BaseFeatureExtractor

Base class for supervised trackers.

Methods Summary

evaluate_labeled(batch_dict[, stage])

Compute and log the losses on a batch of labeled data.

get_loss_inputs_labeled(batch_dict)

Return predicted coordinates for a batch of data.

test_step(batch_dict, batch_idx)

Base test step, a wrapper around the evaluate_labeled method.

training_step(batch_dict, batch_idx)

Base training step, a wrapper around the evaluate_labeled method.

validation_step(batch_dict, batch_idx)

Base validation step, a wrapper around the evaluate_labeled method.

Methods Documentation

evaluate_labeled(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict, stage: Literal['train', 'val', 'test'] | None = None) → float32 TensorType

Compute and log the losses on a batch of labeled data.

get_loss_inputs_labeled(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict) → dict

Return predicted coordinates for a batch of data.

test_step(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict, batch_idx: int) → None

Base test step, a wrapper around the evaluate_labeled method.

training_step(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict, batch_idx: int) → dict[str, TensorType[(), torch.float32]]

Base training step, a wrapper around the evaluate_labeled method.

validation_step(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict, batch_idx: int) → None

Base validation step, a wrapper around the evaluate_labeled method.
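The three step methods above are each described as a wrapper around evaluate_labeled. A minimal sketch of that wrapper pattern in plain Python might look like the following. It is not the lightning_pose implementation: ToyTracker, the "targets" and "predictions" keys, the squared-error loss, the logged list, and the "loss" key in the returned dictionary are all assumptions made for illustration.

```python
from typing import Literal, Optional

Stage = Optional[Literal["train", "val", "test"]]


class ToyTracker:
    """Hypothetical stand-in that mirrors the documented method names only."""

    def __init__(self):
        # Stand-in for a real logger: collects (name, value) pairs.
        self.logged = []

    def get_loss_inputs_labeled(self, batch_dict: dict) -> dict:
        # Toy "model": predict zero for every target value (assumption).
        targets = batch_dict["targets"]
        return {"targets": targets, "predictions": [0.0] * len(targets)}

    def evaluate_labeled(self, batch_dict: dict, stage: Stage = None) -> float:
        # Compute a mean squared error and log it when a stage is given.
        data = self.get_loss_inputs_labeled(batch_dict)
        errors = [
            (p - t) ** 2 for p, t in zip(data["predictions"], data["targets"])
        ]
        loss = sum(errors) / len(errors)
        if stage is not None:
            self.logged.append((f"{stage}_loss", loss))
        return loss

    def training_step(self, batch_dict: dict, batch_idx: int) -> dict:
        # The training wrapper returns the loss so an optimizer could use it.
        return {"loss": self.evaluate_labeled(batch_dict, "train")}

    def validation_step(self, batch_dict: dict, batch_idx: int) -> None:
        # Validation only logs; nothing is returned.
        self.evaluate_labeled(batch_dict, "val")

    def test_step(self, batch_dict: dict, batch_idx: int) -> None:
        # Test only logs; nothing is returned.
        self.evaluate_labeled(batch_dict, "test")
```

In this sketch only training_step returns a value, which matches the documented return types: a dictionary for training and None for validation and test.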

__init__(backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_contrastive', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vits_dino', 'vitb_dino', 'vitb_imagenet', 'vitb_sam'] = 'resnet50', pretrained: bool = True, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, optimizer: str = 'Adam', optimizer_params: DictConfig | dict | None = None, do_context: bool = False, image_size: int = 256, model_type: Literal['heatmap', 'regression'] = 'heatmap', **kwargs: Any) → None

A backbone network (CNN or ViT) that takes in images and generates features.

ResNets are loaded from torchvision and can be either pretrained on ImageNet or randomly initialized. These networks were originally designed for classification, so their final fully connected layer is removed.

Parameters:
  • backbone – which backbone version to use; defaults to resnet50

  • pretrained – True to load weights pretrained on ImageNet (torchvision models only); defaults to True

  • optimizer – name of the optimizer class to instantiate (Adam or AdamW; more may be added in the future); defaults to Adam

  • optimizer_params – arguments to pass to the optimizer

  • lr_scheduler – name of the learning rate scheduler to use; defaults to multisteplr

  • lr_scheduler_params – parameters for the chosen learning rate scheduler

  • do_context – include temporal context when processing each frame

  • image_size – height/width of frames, used for ViT models only; defaults to 256

  • model_type – type of model, either heatmap or regression; defaults to heatmap
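
The parameter list above can be easier to read next to a concrete set of arguments. The sketch below collects the documented defaults in a dictionary and checks the two Literal-typed arguments against the values listed in the signature. build_init_kwargs is a hypothetical helper written for this example and is not part of lightning_pose; the two params arguments are left as None because their expected keys are not documented here.

```python
from typing import Any

# Allowed values copied from the __init__ signature above.
BACKBONES = (
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnet50_contrastive", "resnet50_animal_apose", "resnet50_animal_ap10k",
    "resnet50_human_jhmdb", "resnet50_human_res_rle", "resnet50_human_top_res",
    "resnet50_human_hand", "efficientnet_b0", "efficientnet_b1",
    "efficientnet_b2", "vits_dino", "vitb_dino", "vitb_imagenet", "vitb_sam",
)
MODEL_TYPES = ("heatmap", "regression")

# Defaults copied from the __init__ signature above.
DEFAULTS = {
    "backbone": "resnet50",
    "pretrained": True,
    "lr_scheduler": "multisteplr",
    "lr_scheduler_params": None,
    "optimizer": "Adam",
    "optimizer_params": None,
    "do_context": False,
    "image_size": 256,
    "model_type": "heatmap",
}


def build_init_kwargs(**overrides: Any) -> dict:
    """Hypothetical helper: merge overrides into the documented defaults."""
    kwargs = {**DEFAULTS, **overrides}
    if kwargs["backbone"] not in BACKBONES:
        raise ValueError(f"unknown backbone: {kwargs['backbone']!r}")
    if kwargs["model_type"] not in MODEL_TYPES:
        raise ValueError(f"unknown model_type: {kwargs['model_type']!r}")
    return kwargs


# Example: a ViT backbone, where image_size is documented as relevant.
vit_kwargs = build_init_kwargs(backbone="vitb_dino", image_size=224)
```

A dictionary built this way could then be unpacked into the constructor of a concrete subclass, for example SomeTracker(**vit_kwargs), where SomeTracker is a placeholder name.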

__new__(**kwargs)