BaseSupervisedTracker
- class lightning_pose.models.base.BaseSupervisedTracker
Bases: BaseFeatureExtractor
Base class for supervised trackers.
Methods Summary
evaluate_labeled(batch_dict[, stage, ...]): Compute and log the losses on a batch of labeled data.
get_loss_inputs_labeled(batch_dict): Return predicted coordinates for a batch of data.
test_step(batch_dict, batch_idx): Base test step, a wrapper around the evaluate_labeled method.
training_step(batch_dict, batch_idx): Base training step, a wrapper around the evaluate_labeled method.
validation_step(batch_dict, batch_idx): Base validation step, a wrapper around the evaluate_labeled method.
Methods Documentation
- evaluate_labeled(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict, stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: Tensor | None = None) -> Float[Tensor, '']
Compute and log the losses on a batch of labeled data.
- get_loss_inputs_labeled(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict) -> dict
Return predicted coordinates for a batch of data.
- test_step(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict, batch_idx: int) -> None
Base test step, a wrapper around the evaluate_labeled method.
- training_step(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict, batch_idx: int) -> dict[str, Float[Tensor, '']]
Base training step, a wrapper around the evaluate_labeled method.
- validation_step(batch_dict: BaseLabeledBatchDict | HeatmapLabeledBatchDict | MultiviewLabeledBatchDict | MultiviewHeatmapLabeledBatchDict, batch_idx: int) -> None
Base validation step, a wrapper around the evaluate_labeled method.
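The three step methods above are all described as wrappers around evaluate_labeled. The following is a minimal, dependency-free sketch of that pattern, not the lightning_pose implementation: the class name SketchTracker, the "keypoints" batch key, the "loss" key in the returned dict, the mean-squared-error loss, and the list-based logger are all assumptions made for illustration.

```python
from typing import Optional


class SketchTracker:
    """Hypothetical stand-in for a supervised tracker (not the lightning_pose class)."""

    def __init__(self) -> None:
        # Records (name, value) pairs in place of a real logger.
        self.logged = []

    def get_loss_inputs_labeled(self, batch_dict: dict) -> dict:
        # Placeholder "model": predict zero for every target coordinate.
        # The "keypoints" key is an assumed name for the labeled targets.
        targets = batch_dict["keypoints"]
        return {"targets": targets, "preds": [0.0 for _ in targets]}

    def evaluate_labeled(self, batch_dict: dict, stage: Optional[str] = None) -> float:
        data = self.get_loss_inputs_labeled(batch_dict)
        # Mean squared error, chosen only as an illustrative loss.
        squared = [(p - t) ** 2 for p, t in zip(data["preds"], data["targets"])]
        loss = sum(squared) / len(squared)
        # Log only when a stage is given.
        if stage is not None:
            self.logged.append((f"{stage}_loss", loss))
        return loss

    def training_step(self, batch_dict: dict, batch_idx: int) -> dict:
        # Returns the loss so that a training loop could backpropagate it.
        return {"loss": self.evaluate_labeled(batch_dict, "train")}

    def validation_step(self, batch_dict: dict, batch_idx: int) -> None:
        # Evaluates and logs; nothing is returned.
        self.evaluate_labeled(batch_dict, "val")

    def test_step(self, batch_dict: dict, batch_idx: int) -> None:
        # Evaluates and logs; nothing is returned.
        self.evaluate_labeled(batch_dict, "test")
```

In this sketch only training_step returns a value, mirroring the documented return types (a dict for training, None for validation and test); the stage argument is what distinguishes the logged quantities.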
- __init__(backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vits_dino', 'vits_dinov2', 'vits_dinov3', 'vitb_dino', 'vitb_dinov2', 'vitb_dinov3', 'vitb_imagenet', 'vitb_sam'] = 'resnet50', pretrained: bool = True, lr_scheduler: str = 'multisteplr', lr_scheduler_params: dict | DictConfig | ListConfig | None = None, optimizer: str = 'Adam', optimizer_params: dict | DictConfig | ListConfig | None = None, do_context: bool = False, image_size: int = 256, model_type: Literal['heatmap', 'regression'] = 'heatmap', **kwargs: Any) -> None
A CNN model that takes in images and generates features.
ResNets are loaded from torchvision and can be either pretrained on ImageNet or randomly initialized. Because these networks were originally designed for classification, their final fully connected layer is removed.
- Parameters:
backbone – which backbone version to use; defaults to resnet50
pretrained – True to load weights pretrained on ImageNet (torchvision models only)
optimizer – optimizer class to instantiate (Adam, AdamW, more to be added in future)
optimizer_params – arguments to pass to optimizer
lr_scheduler – how to schedule learning rate
lr_scheduler_params – params for specific learning rate schedulers
do_context – include temporal context when processing each frame
image_size – height/width of frames, for ViT models only
model_type – type of model
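The parameter list above can be summarized as a set of keyword arguments with the documented defaults. The sketch below assumes nothing beyond the signature shown on this page: build_init_kwargs is a hypothetical helper (it is not part of lightning_pose) that starts from those defaults, applies overrides, and checks backbone and model_type against the documented literal values. It does not construct a model.

```python
# Backbone names copied from the documented __init__ signature.
DOCUMENTED_BACKBONES = (
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnet50_animal_apose", "resnet50_animal_ap10k", "resnet50_human_jhmdb",
    "resnet50_human_res_rle", "resnet50_human_top_res", "resnet50_human_hand",
    "efficientnet_b0", "efficientnet_b1", "efficientnet_b2",
    "vits_dino", "vits_dinov2", "vits_dinov3",
    "vitb_dino", "vitb_dinov2", "vitb_dinov3", "vitb_imagenet", "vitb_sam",
)
DOCUMENTED_MODEL_TYPES = ("heatmap", "regression")


def build_init_kwargs(**overrides):
    """Hypothetical helper: documented defaults plus validated overrides."""
    kwargs = {
        "backbone": "resnet50",
        "pretrained": True,
        "lr_scheduler": "multisteplr",
        "lr_scheduler_params": None,
        "optimizer": "Adam",
        "optimizer_params": None,
        "do_context": False,
        "image_size": 256,
        "model_type": "heatmap",
    }
    # The documented signature also accepts **kwargs, so extra keys pass through.
    kwargs.update(overrides)
    if kwargs["backbone"] not in DOCUMENTED_BACKBONES:
        raise ValueError(f"unknown backbone: {kwargs['backbone']!r}")
    if kwargs["model_type"] not in DOCUMENTED_MODEL_TYPES:
        raise ValueError(f"unknown model_type: {kwargs['model_type']!r}")
    return kwargs


# Example: a ViT backbone, where image_size is documented to apply.
vit_kwargs = build_init_kwargs(backbone="vitb_dinov2", image_size=224)
```

Checking names up front is a design choice of this sketch; the library itself may validate these arguments differently or at a different point.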
- __new__(**kwargs)