Source code for lightning_pose.models.backbones.vit_img_encoder

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from segment_anything.modeling import ImageEncoderViT
except ImportError as err:
    # NotImplementedError (rather than ImportError) is kept for backward
    # compatibility with callers that already catch it; the original
    # ImportError is chained so the real cause stays visible in the traceback.
    raise NotImplementedError(
        "If you have pip installed lightning pose, there is no access to segment-anything "
        "models due to dependency/installation issues. "
        "For more information please contact the package maintainers."
    ) from err

from typing import Tuple, Type

# Restrict the public API to the names below so Sphinx auto-API docs skip imported names
__all__ = ["ImageEncoderViT_FT", "resample_abs_pos_embed_nhwc"]


# This class and its supporting functions are lightly adapted from the ViTDet backbone available at:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
class ImageEncoderViT_FT(ImageEncoderViT):
    """ViT image encoder that can be fine-tuned at a resolution other than the pretrained one.

    ``img_size`` describes the pretrained checkpoint; ``finetune_img_size`` is the
    resolution used for lightning-pose training. In ``forward``, when the absolute
    positional embedding's spatial grid does not match the patch-embedded input, it is
    resampled to the ``finetune_img_size // patch_size`` grid and stored back on the
    module so later calls reuse it.
    """

    def __init__(
        self,
        img_size: int = 1024,
        finetune_img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size of pretrained ViT backbone checkpoint.
            finetune_img_size (int): Input image size for lightning-pose training.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Output channel count, forwarded to ``ImageEncoderViT``.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (tuple of int): Indexes for blocks using global attention.
        """
        # Initialise the parent (and therefore nn.Module) first; assigning
        # attributes on a Module before Module.__init__ is fragile.
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            out_chans=out_chans,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
            use_abs_pos=use_abs_pos,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            window_size=window_size,
            global_attn_indexes=global_attn_indexes,
        )
        self.img_size = img_size
        self.finetune_img_size = finetune_img_size
        # forward() needs the patch size to compute the fine-tuning patch grid.
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            x: input passed to ``self.patch_embed``; presumably (B, in_chans, H, W)
                images of size ``finetune_img_size`` -- TODO confirm against caller.

        Returns:
            Output of ``self.neck`` applied to the transformer features after
            permuting them from channels-last to channels-first.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # Compare only the spatial grid (dims 1 and 2). Comparing the full
            # shape also compared the batch dim, which made the check true for
            # every batch size > 1 and re-registered pos_embed on every call.
            if x.shape[1:3] != self.pos_embed.shape[1:3]:
                grid = self.finetune_img_size // self.patch_size
                # NOTE(review): resample_abs_pos_embed_nhwc returns a *new*
                # nn.Parameter; an optimizer built before the first forward pass
                # still references the old tensor -- confirm the caller handles this.
                self.pos_embed = resample_abs_pos_embed_nhwc(
                    posemb=self.pos_embed,
                    new_size=[grid, grid],
                )
            x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        # The blocks work channels-last; the neck takes channels-first input.
        return self.neck(x.permute(0, 3, 1, 2))
def resample_abs_pos_embed_nhwc(
    posemb,
    new_size: list[int],
    interpolation: str = 'bicubic',
    antialias: bool = True,
):
    """Resample an absolute positional embedding stored channels-last (NHWC).

    Args:
        posemb: positional embedding whose last three dims are (H, W, C). It is
            reshaped to (1, H, W, C), so any leading dims must multiply to 1.
        new_size: target spatial size ``[new_H, new_W]``.
        interpolation: ``mode`` passed to ``F.interpolate``.
        antialias: ``antialias`` flag passed to ``F.interpolate``.

    Returns:
        ``posemb`` itself when its spatial size already equals ``new_size``;
        otherwise a new contiguous ``nn.Parameter`` of shape (1, new_H, new_W, C).
        A frozen (``requires_grad=False``) Parameter input stays frozen; any other
        input yields a trainable Parameter.
    """
    if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
        return posemb

    # Preserve a frozen Parameter's state; plain tensors keep the previous
    # behaviour of producing a trainable Parameter.
    requires_grad = posemb.requires_grad if isinstance(posemb, nn.Parameter) else True

    # The returned nn.Parameter is a fresh leaf tensor, so no gradient flows
    # back through the interpolation; don't build an autograd graph for it.
    with torch.no_grad():
        height, width, channels = posemb.shape[-3:]
        nchw = posemb.reshape(1, height, width, channels).permute(0, 3, 1, 2)
        nchw = F.interpolate(nchw, size=new_size, mode=interpolation, antialias=antialias)
        # Make the stored parameter contiguous instead of a permuted view.
        resized = nchw.permute(0, 2, 3, 1).contiguous()

    return nn.Parameter(resized, requires_grad=requires_grad)