BaseFeatureExtractor

class lightning_pose.models.base.BaseFeatureExtractor(backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_contrastive', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vit_b_sam'] = 'resnet50', pretrained: bool = True, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, do_context: bool = False, image_size: int = 256, model_type: Literal['heatmap', 'regression'] = 'heatmap', **kwargs: Any)

Bases: LightningModule

Object that contains the base feature extractor (a ResNet, EfficientNet, or ViT backbone).

Methods Summary

configure_optimizers()

Select optimizer, lr scheduler, and metric for monitoring.

forward(images)

Forward pass from images to representations.

get_parameters()

get_representations(images[, is_multiview])

Forward pass from images to feature maps.

get_scheduler(optimizer)

Methods Documentation

configure_optimizers() → dict

Select optimizer, lr scheduler, and metric for monitoring.
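
As a rough illustration of what such a method typically returns, the sketch below builds the optimizer/scheduler/monitor dictionary that Lightning accepts from configure_optimizers(). It is not the library's implementation: the function name configure_optimizers_sketch, the Adam optimizer, the learning rate, the milestones, gamma, and the metric name "val_loss" are all assumptions chosen for the example.

```python
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR


def configure_optimizers_sketch(model: nn.Module) -> dict:
    """Hypothetical helper: build an optimizer/scheduler/monitor dict."""
    # Assumed hyperparameters; the real values come from the model config.
    optimizer = Adam(model.parameters(), lr=1e-3)
    # MultiStepLR multiplies the learning rate by gamma at each milestone epoch.
    scheduler = MultiStepLR(optimizer, milestones=[100, 200], gamma=0.5)
    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler,
        # Assumed metric name; it must match a value logged during validation.
        "monitor": "val_loss",
    }


config = configure_optimizers_sketch(nn.Linear(4, 2))
print(sorted(config))
```

The default lr_scheduler value 'multisteplr' suggests a step-wise decay of this kind, but the exact parameters are set through lr_scheduler_params.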

forward(images: TensorType["batch", "RGB":3, "image_height", "image_width"] | TensorType["batch", "seq_length", "RGB":3, "image_height", "image_width"] | TensorType["seq_length", "RGB":3, "image_height", "image_width"]) → TensorType["new_batch", "features", "rep_height", "rep_width", "frames"]

Forward pass from images to representations.

Wrapper around self.get_representations(). More complex child models will call get_representations() in their own forward methods.

Parameters:

images – a batch of images.

Returns:

a representation of the images.

get_parameters()

get_representations(images: TensorType["batch", "channels":3, "image_height", "image_width"] | TensorType["batch", "frames", "channels":3, "image_height", "image_width"] | TensorType["seq_len", "channels":3, "image_height", "image_width"] | TensorType["batch", "views", "frames", "channels":3, "image_height", "image_width"] | TensorType["seq_len", "view", "frames", "channels":3, "image_height", "image_width"], is_multiview: bool = False) → TensorType["new_batch", "features", "rep_height", "rep_width", "frames"]

Forward pass from images to feature maps.

Wrapper around the backbone’s feature_extractor() method for typechecking purposes. See tests/models/test_base.py for example shapes.

Batch options

  • TensorType["batch", "channels":3, "image_height", "image_width"] single view, labeled batch

  • TensorType["batch", "frames", "channels":3, "image_height", "image_width"] single view, labeled context batch

  • TensorType["seq_len", "channels":3, "image_height", "image_width"] single view, unlabeled batch from DALI

  • TensorType["batch", "views", "frames", "channels":3, "image_height", "image_width"] multiview, labeled context batch

  • TensorType["seq_len", "views", "channels":3, "image_height", "image_width"] multiview, unlabeled batch from DALI

Parameters:

  • images – a batch of images

  • is_multiview – flag to distinguish batches of the same size

Returns:

a representation of the images. The number of features depends on the backbone version; the representation height and width depend on the image dimensions and are not necessarily equal.
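
To illustrate why is_multiview is needed, the sketch below classifies a batch by its shape alone. Two of the listed layouts are 5-dimensional, so the number of dimensions cannot tell them apart without the flag. The helper describe_batch is hypothetical and not part of lightning_pose; it only mirrors the batch options listed above and assumes the flag is what separates the two 5-dimensional cases.

```python
def describe_batch(shape: tuple, is_multiview: bool = False) -> str:
    """Hypothetical helper: name the batch layout implied by a tensor shape."""
    ndim = len(shape)
    if ndim == 4:
        # (batch, 3, H, W) and (seq_len, 3, H, W) share the same layout.
        return "single view, no context"
    if ndim == 5:
        # Ambiguous by shape alone: the second axis is either frames or views.
        if is_multiview:
            return "multiview, unlabeled batch"
        return "single view, labeled context batch"
    if ndim == 6:
        # (batch, views, frames, 3, H, W)
        return "multiview, labeled context batch"
    raise ValueError(f"unsupported number of dimensions: {ndim}")


# The same shape is interpreted differently depending on the flag.
print(describe_batch((8, 5, 3, 256, 256)))
print(describe_batch((8, 5, 3, 256, 256), is_multiview=True))
```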

get_scheduler(optimizer)

__init__(backbone: Literal['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_contrastive', 'resnet50_animal_apose', 'resnet50_animal_ap10k', 'resnet50_human_jhmdb', 'resnet50_human_res_rle', 'resnet50_human_top_res', 'resnet50_human_hand', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'vit_b_sam'] = 'resnet50', pretrained: bool = True, lr_scheduler: str = 'multisteplr', lr_scheduler_params: DictConfig | dict | None = None, do_context: bool = False, image_size: int = 256, model_type: Literal['heatmap', 'regression'] = 'heatmap', **kwargs: Any) → None

A CNN model that takes in images and generates features.

ResNets will be loaded from torchvision and can be either pre-trained on ImageNet or randomly initialized. These were originally used for classification tasks, so we truncate their final fully connected layer.

Parameters:
  • backbone – which backbone version to use; defaults to resnet50

  • pretrained – True to load weights pretrained on ImageNet (torchvision models only)

  • lr_scheduler – how to schedule learning rate

  • lr_scheduler_params – params for specific learning rate schedulers

  • do_context – include temporal context when processing each frame

  • image_size – height/width of frames, for ViT models only

  • model_type – type of model, either 'heatmap' or 'regression'