BaseFeatureExtractor
- class lightning_pose.models.base.BaseFeatureExtractor(backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_contrastive', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vit_b_sam'] = 'resnet50', pretrained: bool = True, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, do_context: bool = False, image_size: int = 256, model_type: Literal['heatmap', 'regression'] = 'heatmap', **kwargs: Any)
Bases: LightningModule
Object that contains the base resnet feature extractor.
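As a rough illustration of the constructor arguments in the signature above, the sketch below assembles a keyword-argument dictionary and checks the backbone name against the allowed values. The `Backbone` alias and the `build_extractor_kwargs` helper are hypothetical names introduced here; only the parameter names, defaults, and allowed strings are taken from the signature.

```python
from typing import Any, Literal, get_args

# Allowed backbone names, copied from the class signature.
Backbone = Literal[
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnet50_contrastive", "resnet50_animal_apose", "resnet50_animal_ap10k",
    "resnet50_human_jhmdb", "resnet50_human_res_rle", "resnet50_human_top_res",
    "resnet50_human_hand", "efficientnet_b0", "efficientnet_b1",
    "efficientnet_b2", "vit_b_sam",
]


def build_extractor_kwargs(backbone: str = "resnet50", **overrides: Any) -> dict:
    """Hypothetical helper: constructor kwargs with the documented defaults."""
    if backbone not in get_args(Backbone):
        raise ValueError(f"unsupported backbone: {backbone!r}")
    kwargs = {
        "backbone": backbone,
        "pretrained": True,
        "lr_scheduler": "multisteplr",
        "lr_scheduler_params": None,
        "do_context": False,
        "image_size": 256,
        "model_type": "heatmap",
    }
    # Caller-supplied values replace the documented defaults.
    kwargs.update(overrides)
    return kwargs


kwargs = build_extractor_kwargs("resnet18", pretrained=False)
# With lightning_pose installed, these would be passed on as
# BaseFeatureExtractor(**kwargs).
```

Validating the backbone string up front surfaces a typo before any model weights are loaded.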
Methods Summary
configure_optimizers() – Select optimizer, lr scheduler, and metric for monitoring.
forward(images) – Forward pass from images to representations.
get_representations(images[, is_multiview]) – Forward pass from images to feature maps.
get_scheduler(optimizer)
Methods Documentation
- forward(images: Tensor) → Tensor
Forward pass from images to representations.
Wrapper around self.get_representations(). Fancier child models will use get_representations() in their forward methods.
- Parameters:
images – a batch of images.
- Returns:
a representation of the images.
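The delegation pattern described for forward() can be sketched with toy classes. `ToyExtractor` and `ToyRegressor` are hypothetical stand-ins that operate on plain Python lists; they are not the library's implementation and only show how a child model might reuse get_representations() inside its own forward().

```python
class ToyExtractor:
    """Stand-in for a base extractor; not the library's implementation."""

    def get_representations(self, images):
        # Placeholder "backbone": one feature (the mean) per image.
        return [sum(img) / len(img) for img in images]

    def forward(self, images):
        # The base forward only delegates to get_representations().
        return self.get_representations(images)


class ToyRegressor(ToyExtractor):
    """Hypothetical child model that adds a head on top of the features."""

    def forward(self, images):
        representations = self.get_representations(images)
        # Placeholder "head": scale each feature.
        return [2.0 * r for r in representations]


batch = [[0.0, 1.0], [2.0, 4.0]]
base_out = ToyExtractor().forward(batch)    # [0.5, 3.0]
child_out = ToyRegressor().forward(batch)   # [1.0, 6.0]
```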
- get_representations(images: Tensor, is_multiview: bool = False) → Tensor
Forward pass from images to feature maps.
Wrapper around the backbone’s feature_extractor() method for typechecking purposes. See tests/models/test_base.py for example shapes.
Batch options
TensorType[“batch”, “channels”:3, “image_height”, “image_width”] single view, labeled batch
TensorType[“batch”, “frames”, “channels”:3, “image_height”, “image_width”] single view, labeled context batch
TensorType[“seq_len”, “channels”:3, “image_height”, “image_width”] single view, unlabeled batch from DALI
TensorType[“batch”, “views”, “frames”, “channels”:3, “image_height”, “image_width”] multiview, labeled context batch
TensorType[“seq_len”, “views”, “channels”:3, “image_height”, “image_width”] multiview, unlabeled batch from DALI
- Parameters:
images – a batch of images.
is_multiview – flag to distinguish batches with the same number of dimensions (for example, a single-view labeled context batch and a multiview unlabeled batch are both 5-D).
- Returns:
a representation of the images. The number of features depends on the resnet version; the representation's height and width depend on the image dimensions and are not necessarily equal.
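To make the role of is_multiview concrete, the sketch below maps a batch shape to the batch options listed above. `describe_batch` is a hypothetical helper written for this page; whether the library dispatches on shapes in exactly this way is an assumption.

```python
def describe_batch(shape: tuple, is_multiview: bool = False) -> str:
    """Hypothetical helper: name the batch option that matches a shape."""
    if len(shape) < 4 or shape[-3] != 3:
        raise ValueError("expected (..., 3, image_height, image_width)")
    ndim = len(shape)
    if ndim == 4:
        # (batch, 3, H, W) and (seq_len, 3, H, W) share the same layout.
        return "single view, labeled batch or unlabeled batch from DALI"
    if ndim == 5:
        # (batch, frames, 3, H, W) and (seq_len, views, 3, H, W) are both
        # 5-D, so the shape alone is ambiguous and the flag decides.
        if is_multiview:
            return "multiview, unlabeled batch from DALI"
        return "single view, labeled context batch"
    if ndim == 6:
        # (batch, views, frames, 3, H, W)
        return "multiview, labeled context batch"
    raise ValueError(f"unsupported number of dimensions: {ndim}")


context = describe_batch((8, 5, 3, 256, 256))
multiview = describe_batch((16, 2, 3, 256, 256), is_multiview=True)
```

Only the 5-D case needs the flag: both 5-D layouts place an extra axis between the leading dimension and the channels.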