import math
import torch
from transformers import ViTModel
from typeguard import typechecked
from lightning_pose.models.backbones.vit_sam import SamVisionEncoder
# to ignore imports for sphix-autoapidoc
__all__ = [
"build_backbone",
]
[docs]
@typechecked
def build_backbone(backbone_arch: str, image_size: int = 256, **kwargs):
"""Load backbone weights for resnet models.
Args:
backbone_arch: which backbone version/weights to use
image_size: height/width in pixels of images (must be square)
Returns:
tuple
- backbone: pytorch model
- num_fc_input_features (int): number of input features to fully connected layer
"""
# deprecation warnings
if "vit_h_sam" in backbone_arch:
backbone_arch = "vitb_sam"
raise DeprecationWarning('vit_h_sam is now deprecated; reverting to "vitb_sam"')
elif "vit_b_sam" in backbone_arch:
backbone_arch = "vitb_sam"
raise DeprecationWarning('vit_b_sam is now deprecated; reverting to "vitb_sam"')
# load backbone weights
if "vits_dino" in backbone_arch:
base = VisionEncoder(model_name="facebook/dino-vits16")
encoder_embed_dim = base.vision_encoder.config.hidden_size
elif "vitb_dino" in backbone_arch:
base = VisionEncoder(model_name="facebook/dino-vitb16")
encoder_embed_dim = base.vision_encoder.config.hidden_size
elif "vitb_imagenet" in backbone_arch:
base = VisionEncoder(model_name="facebook/vit-mae-base")
encoder_embed_dim = base.vision_encoder.config.hidden_size
if kwargs.get("backbone_checkpoint"):
load_vit_backbone_checkpoint(base, kwargs["backbone_checkpoint"])
elif "vitb_sam" in backbone_arch:
base = SamVisionEncoder(
model_name="facebook/sam-vit-base",
finetune_img_size=image_size,
)
encoder_embed_dim = base.vision_encoder.config.hidden_size
else:
raise NotImplementedError(f"{backbone_arch} is not a valid backbone")
num_fc_input_features = encoder_embed_dim
return base, num_fc_input_features
def load_vit_backbone_checkpoint(base, checkpoint: str):
print(f"Loading VIT-MAE weights from {checkpoint}")
ckpt_vit_pretrain = torch.load(checkpoint, map_location="cpu")
# extract state dict if checkpoint contains additional info
if "state_dict" in ckpt_vit_pretrain:
ckpt_vit_pretrain = ckpt_vit_pretrain["state_dict"]
# Create a filtered state dict for the VIT-MAE part only
vit_mae_state_dict = {}
for key, value in ckpt_vit_pretrain.items():
if key.startswith("vit_mae."):
model_key = key.replace("vit_mae.vit.", "")
# Skip known problematic layers with size mismatches
if any(prob in model_key for prob in [
"position_embeddings",
"patch_embeddings.projection", # in case backbone was trained with grayscale imgs
"decoder_pos_embed",
"decoder_pred",
]):
continue
# Check if shapes match before including in state dict
if model_key in base.vision_encoder.state_dict():
if base.vision_encoder.state_dict()[model_key].shape == value.shape:
vit_mae_state_dict[model_key] = value
# Load the filtered weights
base.vision_encoder.load_state_dict(vit_mae_state_dict, strict=False)
class VisionEncoder(torch.nn.Module):
"""Wrapper around ViT Encoder."""
def __init__(self, model_name):
super().__init__()
self.vision_encoder = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the vision encoder.
Args:
x: Input tensor of shape (B, C, H, W)
Returns:
Encoded features
"""
outputs = self.vision_encoder(
x,
return_dict=True,
output_hidden_states=False,
output_attentions=False,
interpolate_pos_encoding=True,
).last_hidden_state
# skip the cls token
outputs = outputs[:, 1:, ...] # [N, S, D]
# change the shape to [N, H, W, D] -> [N, D, H, W]
N = x.shape[0]
S = outputs.shape[1]
H, W = math.isqrt(S), math.isqrt(S)
outputs = outputs.reshape(N, H, W, -1).permute(0, 3, 1, 2)
return outputs