BaseFeatureExtractor
- class lightning_pose.models.base.BaseFeatureExtractor
Bases: LightningModule
Object that contains the base resnet feature extractor.
Methods Summary
configure_optimizers() Select optimizer, lr scheduler, and metric for monitoring.
forward(images) Forward pass from images to representations.
get_parameters() Return an iterator over trainable (requires_grad) model parameters.
get_representations(images[, is_multiview]) Forward pass from images to feature maps.
get_scheduler(optimizer) Build and return the learning rate scheduler.
Methods Documentation
- forward(images: Float[Tensor, 'batch RGB image_height image_width'] | Float[Tensor, 'batch seq_length RGB image_height image_width'] | Float[Tensor, 'seq_length RGB image_height image_width']) → Float[Tensor, 'new_batch features rep_height rep_width'] | Float[Tensor, 'new_batch features rep_height rep_width frames']
Forward pass from images to representations.
Wrapper around self.get_representations(). More complex child models call get_representations() from their own forward methods.
- Parameters:
images – a batch of images.
- Returns:
a representation of the images.
- get_parameters() → Iterator[Parameter]
Return an iterator over trainable (requires_grad) model parameters.
- Returns:
Iterator of torch.nn.Parameter objects that require gradients.
- get_representations(images: Float[Tensor, 'batch channels image_height image_width'] | Float[Tensor, 'batch frames channels image_height image_width'] | Float[Tensor, 'seq_len channels image_height image_width'] | Float[Tensor, 'batch views frames channels image_height image_width'] | Float[Tensor, 'seq_len view frames channels image_height image_width'], is_multiview: bool = False) → Float[Tensor, 'new_batch features rep_height rep_width'] | Float[Tensor, 'new_batch features rep_height rep_width frames']
Forward pass from images to feature maps.
Wrapper around the backbone’s feature_extractor() method for typechecking purposes. See tests/models/test_base.py for example shapes.
Batch options:
Float[torch.Tensor, "batch channels image_height image_width"]: single view, labeled batch
Float[torch.Tensor, "batch frames channels image_height image_width"]: single view, labeled context batch
Float[torch.Tensor, "seq_len channels image_height image_width"]: single view, unlabeled batch from DALI
Float[torch.Tensor, "batch views frames channels image_height image_width"]: multiview, labeled context batch
Float[torch.Tensor, "seq_len views channels image_height image_width"]: multiview, unlabeled batch from DALI
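- Parameters:
images – a batch of images
is_multiview – flag to distinguish between batch layouts with the same number of dimensions (for example, a single-view labeled context batch and a multiview unlabeled batch are both 5-D)
- Returns:
a representation of the images. The number of features depends on the backbone (for example, the ResNet version). Representation height and width depend on the image dimensions and are not necessarily equal.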
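The batch options above can only be told apart partly by rank: the single-view labeled context batch and the multiview unlabeled batch are both 5-D, which is why is_multiview exists. The sketch below illustrates that ambiguity using plain shape tuples. The helper name describe_batch is hypothetical and is not part of lightning_pose; it is not the library's own dispatch logic.

```python
from typing import Tuple


def describe_batch(shape: Tuple[int, ...], is_multiview: bool = False) -> str:
    """Map a shape tuple (plus the multiview flag) to a documented batch layout.

    Hypothetical helper for illustration only; not a lightning_pose function.
    """
    ndim = len(shape)
    if ndim == 4:
        # Labeled batches and unlabeled DALI batches share this 4-D layout.
        return "single view: batch|seq_len channels image_height image_width"
    if ndim == 5:
        # Two documented layouts are 5-D, so rank alone cannot separate them.
        if is_multiview:
            return "multiview, unlabeled: seq_len views channels image_height image_width"
        return "single view, labeled context: batch frames channels image_height image_width"
    if ndim == 6:
        return "multiview, labeled context: batch views frames channels image_height image_width"
    raise ValueError(f"unsupported number of dimensions: {ndim}")


# The same 5-D shape is read differently depending on the flag.
shape = (8, 5, 3, 256, 256)
print(describe_batch(shape))
print(describe_batch(shape, is_multiview=True))
```

With is_multiview left at its default, the second axis is read as context frames; with is_multiview=True, it is read as camera views.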
- get_scheduler(optimizer: Optimizer) → MultiStepLR
Build and return the learning rate scheduler.
- Parameters:
optimizer – the optimizer whose learning rate will be scheduled.
- Returns:
MultiStepLR scheduler configured from self.lr_scheduler_params.
- Raises:
LrNotImplementedError – if self.lr_scheduler is not supported.
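To make the effect of a MultiStepLR schedule concrete, here is a minimal pure-Python sketch of the rule that torch.optim.lr_scheduler.MultiStepLR applies: the learning rate is multiplied by gamma each time a milestone epoch is reached. The function name multistep_lr and the example values are hypothetical, and the assumption that lr_scheduler_params carries "milestones" and "gamma" keys is based on the MultiStepLR constructor arguments rather than on this page.

```python
from bisect import bisect_right
from typing import Sequence


def multistep_lr(base_lr: float, milestones: Sequence[int], gamma: float, epoch: int) -> float:
    """Learning rate at `epoch` under a multi-step decay schedule.

    Illustrative re-implementation of the decay rule; not library code.
    """
    # Count the milestones that have been reached (milestone <= epoch).
    n_decays = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** n_decays


# Hypothetical parameters; the key names are assumed to mirror MultiStepLR's arguments.
lr_scheduler_params = {"milestones": [100, 200, 300], "gamma": 0.5}

for epoch in (0, 99, 100, 250, 300):
    lr = multistep_lr(1e-3, lr_scheduler_params["milestones"], lr_scheduler_params["gamma"], epoch)
    print(epoch, lr)
```

The learning rate stays at its base value until the first milestone and is halved at each milestone thereafter in this example.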
- __init__(backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vits_dino', 'vits_dinov2', 'vits_dinov3', 'vitb_dino', 'vitb_dinov2', 'vitb_dinov3', 'vitb_imagenet', 'vitb_sam'] = 'resnet50', pretrained: bool = True, lr_scheduler: str = 'multisteplr', lr_scheduler_params: dict | DictConfig | ListConfig | None = None, optimizer: str = 'Adam', optimizer_params: dict | DictConfig | ListConfig | None = None, do_context: bool = False, image_size: int = 256, model_type: Literal['heatmap', 'regression'] = 'heatmap', **kwargs: Any) → None
A CNN model that takes in images and generates features.
ResNets will be loaded from torchvision and can be either pre-trained on ImageNet or randomly initialized. These were originally used for classification tasks, so we truncate their final fully connected layer.
- Parameters:
backbone – which backbone version to use; defaults to resnet50
pretrained – True to load weights pretrained on ImageNet (torchvision models only)
optimizer – optimizer class to instantiate (Adam or AdamW; more to be added in the future)
optimizer_params – arguments to pass to the optimizer
lr_scheduler – how to schedule the learning rate; defaults to multisteplr
lr_scheduler_params – parameters for the chosen learning rate scheduler
do_context – include temporal context when processing each frame
image_size – height/width of frames (ViT models only)
model_type – type of model, either heatmap or regression; defaults to heatmap
- __new__(**kwargs)