Source code for lightning_pose.models.base
"""Base class for backbone that acts as a feature extractor."""
from typing import Any, Dict, Literal, Optional, Union
import torch
from lightning.pytorch import LightningModule
from omegaconf import DictConfig
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from torchtyping import TensorType
from typeguard import typechecked
from lightning_pose.data.utils import (
BaseLabeledBatchDict,
HeatmapLabeledBatchDict,
MultiviewHeatmapLabeledBatchDict,
MultiviewLabeledBatchDict,
MultiviewUnlabeledBatchDict,
SemiSupervisedBatchDict,
SemiSupervisedHeatmapBatchDict,
UnlabeledBatchDict,
)
# Restrict the public API to the names defined in this module so that
# sphinx-autoapidoc ignores the imported symbols.
__all__ = [
"normalized_to_bbox",
"convert_bbox_coords",
"get_context_from_sequence",
"BaseFeatureExtractor",
"BaseSupervisedTracker",
"SemiSupervisedTrackerMixin",
]
# Fallback MultiStepLR settings; `BaseFeatureExtractor.get_scheduler` uses them when
# `lr_scheduler_params` is None or does not contain the corresponding key.
# milestones: scheduler steps at which the learning rate is multiplied by gamma
MULTISTEPLR_MILESTONES_DEFAULT = [100, 200, 300]
MULTISTEPLR_GAMMA_DEFAULT = 0.5
# list of all allowed backbone options
# (names containing "sam" are built by the ViT backbone builder, all others by the
# torchvision backbone builder; see `BaseFeatureExtractor.__init__`)
ALLOWED_BACKBONES = Literal[
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"resnet50_contrastive", # needs extra install: pip install -e .[extra_models]
"resnet50_animal_apose",
"resnet50_animal_ap10k",
"resnet50_human_jhmdb",
"resnet50_human_res_rle",
"resnet50_human_top_res",
"resnet50_human_hand",
"efficientnet_b0",
"efficientnet_b1",
"efficientnet_b2",
# "vit_h_sam",
"vit_b_sam",
]
def normalized_to_bbox(
    keypoints: TensorType["batch", "num_keypoints", "xy":2],
    bbox: TensorType["batch", "xyhw":4]
) -> TensorType["batch", "num_keypoints", "xy":2]:
    """Scale and shift normalized keypoints by their bounding boxes.

    Every bounding box is laid out as (x, y, height, width). The x coordinates are
    multiplied by the box width and offset by the box x; the y coordinates are
    multiplied by the box height and offset by the box y.

    NOTE: `keypoints` is modified in place, and the very same tensor is returned.

    Args:
        keypoints: keypoints normalized by the image dimensions
        bbox: one bounding box per batch element. When its batch dimension differs
            from that of `keypoints` the batch is treated as a context batch and the
            first and last two bounding boxes are dropped.

    Returns:
        the (in-place updated) `keypoints` tensor

    """
    if keypoints.shape[0] != bbox.shape[0]:
        # context batch; we don't have predictions for first/last two frames
        bbox = bbox[2:-2]

    # (batch, 1) columns broadcast across the keypoint dimension
    x_offset = bbox[:, 0].unsqueeze(1)
    y_offset = bbox[:, 1].unsqueeze(1)
    box_height = bbox[:, 2].unsqueeze(1)
    box_width = bbox[:, 3].unsqueeze(1)

    keypoints[:, :, 0] *= box_width  # scale x by box width
    keypoints[:, :, 0] += x_offset  # add bbox x offset
    keypoints[:, :, 1] *= box_height  # scale y by box height
    keypoints[:, :, 1] += y_offset  # add bbox y offset
    return keypoints
def convert_bbox_coords(
    batch_dict: Union[
        HeatmapLabeledBatchDict,
        MultiviewHeatmapLabeledBatchDict,
        MultiviewUnlabeledBatchDict,
        UnlabeledBatchDict,
    ],
    predicted_keypoints: TensorType["batch", "num_targets"],
) -> TensorType["batch", "num_targets"]:
    """Transform keypoints from bbox coordinates to absolute frame coordinates.

    The keypoints are first divided by the image width/height (taken from
    `batch_dict["images"]`, or `batch_dict["frames"]` when "images" is absent) and
    then scaled/shifted by `batch_dict["bbox"]` via `normalized_to_bbox`.

    The input tensor is left untouched; a new tensor is returned.

    Args:
        batch_dict: batch containing "bbox", either "images" or "frames", and
            optionally "num_views"
        predicted_keypoints: flattened (x, y) coordinates, one row per batch element

    Returns:
        transformed keypoints, reshaped to (batch, num_targets)

    Raises:
        ValueError: if the batch elements do not all contain the same number of views

    """
    num_targets = predicted_keypoints.shape[1]
    num_keypoints = num_targets // 2

    # reshape from (batch, n_targets) back to (batch, n_key, 2), in x,y order.
    # `reshape` returns a view whenever it can, so clone before the in-place updates
    # below; otherwise the caller's tensor would be rescaled as a side effect.
    keypoints = predicted_keypoints.reshape((-1, num_keypoints, 2)).clone()

    # divide by image dims to get 0-1 normalized coordinates
    # (unlabeled dicts store 'frames' instead of 'images')
    image_key = "images" if "images" in batch_dict else "frames"
    image_height, image_width = batch_dict[image_key].shape[-2:]
    keypoints[:, :, 0] /= image_width  # -1 dim is width "x"
    keypoints[:, :, 1] /= image_height  # -2 dim is height "y"

    # multiply and add by bbox dims (x,y,h,w)
    if "num_views" in batch_dict and int(batch_dict["num_views"].max()) > 1:
        unique = batch_dict["num_views"].unique()
        if len(unique) != 1:
            raise ValueError(
                f"each batch element must contain the same number of views; "
                f"found elements with {unique} views"
            )
        num_views = int(unique)
        num_keypoints_per_view = num_keypoints // num_views
        for v in range(num_views):
            # keypoints are grouped by view; each view owns 4 consecutive bbox columns
            view_slice = slice(num_keypoints_per_view * v, num_keypoints_per_view * (v + 1))
            keypoints[:, view_slice, :] = normalized_to_bbox(
                keypoints[:, view_slice, :],
                batch_dict["bbox"][:, 4 * v:4 * (v + 1)],
            )
    else:
        keypoints = normalized_to_bbox(keypoints, batch_dict["bbox"])

    # return new keypoints, reshaped to (batch, num_targets)
    return keypoints.reshape((-1, num_targets))
def get_context_from_sequence(
    img_seq: Union[
        TensorType["seq_len", "RGB":3, "image_height", "image_width"],
        TensorType["seq_len", "n_features", "rep_height", "rep_width"],
    ],
    context_length: int,
) -> Union[
    TensorType["seq_len", "context_length", "RGB": 3, "image_height", "image_width"],
    TensorType["seq_len", "context_length", "n_features", "rep_height", "rep_width"],
]:
    """Build one window of `context_length` consecutive frames per input frame.

    The sequence is padded with two copies of its first frame at the start and two
    copies of its last frame at the end; window `i` then covers padded frames
    `i` through `i + context_length - 1`. Because the padding is fixed at two frames
    per side, the window is centered on frame `i` only when `context_length` is 5.

    The result has the same dtype and device as `img_seq`.

    Args:
        img_seq: sequence of images or feature maps; must contain at least one frame
        context_length: number of frames in each window

    Returns:
        tensor of shape (seq_len, context_length, *img_seq.shape[1:])

    """
    seq_len = img_seq.shape[0]  # how many images in batch

    # define pads: start pad repeats the zeroth image twice. end pad repeats the last image twice.
    # this is to give padding for the first and last frames of the sequence
    pad_start = torch.tile(img_seq[0].unsqueeze(0), (2, 1, 1, 1))
    pad_end = torch.tile(img_seq[-1].unsqueeze(0), (2, 1, 1, 1))

    # pad the sequence
    padded_seq = torch.cat((pad_start, img_seq, pad_end), dim=0)

    # one sliding window per original frame; stacking (rather than copying into a
    # pre-allocated float32 buffer) keeps the dtype of the input
    windows = [padded_seq[i:i + context_length] for i in range(seq_len)]
    return torch.stack(windows, dim=0)
class BaseFeatureExtractor(LightningModule):
    """Object that contains the base resnet feature extractor."""

    def __init__(
        self,
        backbone: ALLOWED_BACKBONES = "resnet50",
        pretrained: bool = True,
        lr_scheduler: str = "multisteplr",
        lr_scheduler_params: Optional[Union[DictConfig, dict]] = None,
        do_context: bool = False,
        image_size: int = 256,
        model_type: Literal["heatmap", "regression"] = "heatmap",
        **kwargs: Any,
    ) -> None:
        """A CNN model that takes in images and generates features.

        ResNets will be loaded from torchvision and can be either pre-trained
        on ImageNet or randomly initialized. These were originally used for
        classification tasks, so we truncate their final fully connected layer.

        Args:
            backbone: which backbone version to use; defaults to resnet50
            pretrained: True to load weights pretrained on imagenet (torchvision models only)
            lr_scheduler: how to schedule learning rate
            lr_scheduler_params: params for specific learning rate schedulers
            do_context: include temporal context when processing each frame
            image_size: height/width of frames, for ViT models only
            model_type: type of model

        """
        super().__init__()

        if self.local_rank == 0:
            print(f"\n Initializing a {self._get_name()} instance.")

        self.backbone_arch = backbone

        # the builders are imported lazily so that only the selected family is loaded
        if "sam" in self.backbone_arch:
            from lightning_pose.models.backbones.vits import build_backbone
        else:
            from lightning_pose.models.backbones.torchvision import build_backbone

        self.backbone, self.num_fc_input_features = build_backbone(
            backbone_arch=self.backbone_arch,
            pretrained=pretrained,
            model_type=model_type,  # for torchvision only
            image_size=image_size,  # for ViTs only
        )

        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_params = lr_scheduler_params
        self.do_context = do_context

    @staticmethod
    def _tile_valid_context(representations):
        """Tile per-frame representations into 5-frame windows and keep the valid frames.

        valid frame := a frame that has two frames before it and two frames after it.

        Args:
            representations: shape (sequence_length, features, rep_height, rep_width)

        Returns:
            shape (num_valid_frames, num_context_frames, features, rep_height, rep_width)

        Raises:
            RuntimeError: if the sequence contains fewer than 5 frames

        """
        # TODO: context frames should be configurable
        tiled_representations = get_context_from_sequence(
            img_seq=representations, context_length=5,
        )
        if tiled_representations.shape[0] < 5:
            raise RuntimeError("Not enough valid frames to make a context representation.")
        # get rid of first and last two frames
        return tiled_representations[2:-2, :, :, :, :]

    def get_representations(
        self,
        images: Union[
            TensorType["batch", "channels":3, "image_height", "image_width"],
            TensorType["batch", "frames", "channels":3, "image_height", "image_width"],
            TensorType["seq_len", "channels":3, "image_height", "image_width"],
            TensorType["batch", "views", "frames", "channels":3, "image_height", "image_width"],
            TensorType["seq_len", "view", "frames", "channels":3, "image_height", "image_width"],
        ],
        is_multiview: bool = False,
    ) -> Union[
        TensorType["new_batch", "features", "rep_height", "rep_width"],
        TensorType["new_batch", "features", "rep_height", "rep_width", "frames"],
    ]:
        """Forward pass from images to feature maps.

        Wrapper around the backbone's feature_extractor() method for typechecking purposes.
        See tests/models/test_base.py for example shapes.

        Batch options
        -------------
        - TensorType["batch", "channels":3, "image_height", "image_width"]
          single view, labeled batch
        - TensorType["batch", "frames", "channels":3, "image_height", "image_width"]
          single view, labeled context batch
        - TensorType["seq_len", "channels":3, "image_height", "image_width"]
          single view, unlabeled batch from DALI
        - TensorType["batch", "views", "frames", "channels":3, "image_height", "image_width"]
          multiview, labeled context batch
        - TensorType["seq_len", "views", "channels":3, "image_height", "image_width"]
          multiview, unlabeled batch from DALI

        Args:
            images: a batch of images
            is_multiview: flag to distinguish batches of the same size

        Returns:
            a representation of the images; features differ as a function of resnet version.
            Representation height and width differ as a function of image dimensions, and are not
            necessarily equal.

        Raises:
            RuntimeError: context model received an unlabeled sequence with fewer than 5 frames
            ValueError: context model received a batch that is not 4-, 5- or 6-dimensional

        """
        if not self.do_context:
            # incoming batch: singleview labeled/unlabeled, multiview labeled/unlabeled reshaped
            # incoming shape: (batch, channels, height, width)
            return self.backbone(images)

        num_dims = len(images.shape)

        if (num_dims == 5 and not is_multiview) or num_dims == 6:
            # len = 5
            # incoming batch: singleview labeled batch
            # incoming shape: (batch, frames, channels, height, width)
            #
            # len = 6
            # incoming batch: multiview labeled batch
            # incoming shape: (batch, num_views, frames, channels, height, width)
            if num_dims == 6:
                # stacking all the views in batch dimension
                images = images.reshape(-1, *images.shape[-4:])
            batch, frames, channels, image_height, image_width = images.shape
            # push every frame through the backbone independently
            # -> (batch * frames, features, rep_height, rep_width)
            outputs = self.backbone(
                images.reshape(batch * frames, channels, image_height, image_width)
            )
            # -> (batch, frames, features, rep_height, rep_width)
            outputs = outputs.reshape(batch, frames, *outputs.shape[1:4])

        elif num_dims == 5:
            # incoming batch: multiview unlabeled batch (is_multiview is True here)
            # incoming shape: (seq, num_views, channels, height, width)
            batch, num_views, channels, image_height, image_width = images.shape
            # -> (batch * views, features, rep_height, rep_width)
            outputs = self.backbone(
                images.reshape(batch * num_views, channels, image_height, image_width)
            )
            # -> (batch, views, features, rep_height, rep_width)
            outputs = outputs.reshape(batch, num_views, *outputs.shape[1:4])
            # stack views across feature dimension
            # -> (batch, views * features, rep_height, rep_width)
            outputs = outputs.reshape(batch, -1, outputs.shape[-2], outputs.shape[-1])
            # tile into context windows and discard the first and last two frames
            outputs = self._tile_valid_context(outputs)

        elif num_dims == 4:
            # we have a single sequence of frames from DALI (not a batch of sequences)
            # we push it as is through the backbone, and then use tiling to build the
            # context windows; for now we discard the padded frames (first and last two),
            # so the output will be one representation per valid frame
            outputs = self._tile_valid_context(self.backbone(images))

        else:
            # previously this fell through to an unbound `outputs` variable
            raise ValueError(
                f"context models expect 4-, 5- or 6-dimensional image batches; "
                f"got shape {tuple(images.shape)}"
            )

        # for all types of batches, we reshape in the same way
        # context is in the last dimension for the linear layer.
        # -> (batch, features, rep_height, rep_width, frames)
        return torch.permute(outputs, (0, 2, 3, 4, 1))

    def forward(
        self,
        images: Union[
            TensorType["batch", "RGB":3, "image_height", "image_width"],
            TensorType["batch", "seq_length", "RGB":3, "image_height", "image_width"],
            TensorType["seq_length", "RGB":3, "image_height", "image_width"],
        ],
    ) -> Union[
        TensorType["new_batch", "features", "rep_height", "rep_width"],
        TensorType["new_batch", "features", "rep_height", "rep_width", "frames"],
    ]:
        """Forward pass from images to representations.

        Wrapper around self.get_representations().
        Fancier children models will use get_representations() in their forward methods.

        Args:
            images: a batch of images.

        Returns:
            a representation of the images.

        """
        return self.get_representations(images)

    def get_scheduler(self, optimizer):
        """Build the learning rate scheduler selected by `self.lr_scheduler`.

        Only "multisteplr" (alias "multistep_lr") is supported. Milestones and gamma are
        read from `self.lr_scheduler_params`, falling back to the module defaults.

        Args:
            optimizer: optimizer whose learning rate is scheduled

        Returns:
            a MultiStepLR scheduler

        Raises:
            NotImplementedError: for any other scheduler name

        """
        if self.lr_scheduler not in ("multisteplr", "multistep_lr"):
            raise NotImplementedError("'%s' is an invalid LR scheduler" % self.lr_scheduler)

        # define a scheduler that reduces the base learning rate
        scheduler_params = self.lr_scheduler_params
        if scheduler_params is None:
            milestones = MULTISTEPLR_MILESTONES_DEFAULT
            gamma = MULTISTEPLR_GAMMA_DEFAULT
        else:
            milestones = scheduler_params.get("milestones", MULTISTEPLR_MILESTONES_DEFAULT)
            gamma = scheduler_params.get("gamma", MULTISTEPLR_GAMMA_DEFAULT)

        return MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

    def get_parameters(self):
        """Return the parameters (or parameter groups) handed to the optimizer.

        Models that define `upsampling_layers` get two named parameter groups, with the
        backbone group's learning rate set to 0. All other models return every parameter
        that requires a gradient.
        """
        if getattr(self, "upsampling_layers", None) is None:
            return filter(lambda p: p.requires_grad, self.parameters())

        return [
            {"params": self.backbone.parameters(), "lr": 0, "name": "backbone"},
            {"params": self.upsampling_layers.parameters(), "name": "upsampling"},
        ]

    def configure_optimizers(self) -> dict:
        """Select optimizer, lr scheduler, and metric for monitoring."""
        # init standard adam optimizer on the trainable params
        optimizer = Adam(self.get_parameters(), lr=1e-3)

        return {
            "optimizer": optimizer,
            "lr_scheduler": self.get_scheduler(optimizer),
            "monitor": "val_supervised_loss",
        }
class BaseSupervisedTracker(BaseFeatureExtractor):
    """Base class for supervised trackers.

    Subclasses must implement `get_loss_inputs_labeled`. `evaluate_labeled` also uses
    `self.loss_factory` and `self.rmse_loss`, neither of which is defined in this class.
    """

    def get_loss_inputs_labeled(
        self,
        batch_dict: Union[
            BaseLabeledBatchDict,
            HeatmapLabeledBatchDict,
            MultiviewLabeledBatchDict,
            MultiviewHeatmapLabeledBatchDict,
        ],
    ) -> dict:
        """Return predicted coordinates for a batch of data."""
        raise NotImplementedError

    def evaluate_labeled(
        self,
        batch_dict: Union[
            BaseLabeledBatchDict,
            HeatmapLabeledBatchDict,
            MultiviewLabeledBatchDict,
            MultiviewHeatmapLabeledBatchDict,
        ],
        stage: Optional[Literal["train", "val", "test"]] = None,
    ) -> TensorType[(), float]:
        """Compute and log the losses on a batch of labeled data.

        Args:
            batch_dict: batch of labeled data
            stage: prefix for the logged metric names; nothing is logged when falsy

        Returns:
            the supervised loss returned by `self.loss_factory`

        """
        # forward pass; collected true and predicted heatmaps, keypoints
        data_dict = self.get_loss_inputs_labeled(batch_dict=batch_dict)

        # compute loss on labeled data
        loss, log_list = self.loss_factory(stage=stage, **data_dict)

        # compute pixel_error loss on labeled data
        loss_rmse, _ = self.rmse_loss(stage=stage, **data_dict)

        if not stage:
            return loss

        # logging with sync_dist=True will average the metric across GPUs in
        # multi-GPU training. Performance overhead was found negligible.

        # log overall supervised loss
        self.log(f"{stage}_supervised_loss", loss, prog_bar=True, sync_dist=True)
        # log supervised pixel_error
        self.log(f"{stage}_supervised_rmse", loss_rmse, sync_dist=True)
        # log individual supervised losses
        for log_dict in log_list:
            self.log(
                log_dict['name'],
                log_dict['value'].to(self.device),
                prog_bar=log_dict.get('prog_bar', False),
                sync_dist=True,
            )

        return loss

    def training_step(
        self,
        batch_dict: Union[
            BaseLabeledBatchDict,
            HeatmapLabeledBatchDict,
            MultiviewLabeledBatchDict,
            MultiviewHeatmapLabeledBatchDict,
        ],
        batch_idx: int,
    ) -> Dict[str, TensorType[(), float]]:
        """Base training step, a wrapper around the `evaluate_labeled` method."""
        return {"loss": self.evaluate_labeled(batch_dict, "train")}

    def validation_step(
        self,
        batch_dict: Union[
            BaseLabeledBatchDict,
            HeatmapLabeledBatchDict,
            MultiviewLabeledBatchDict,
            MultiviewHeatmapLabeledBatchDict,
        ],
        batch_idx: int,
    ) -> None:
        """Base validation step, a wrapper around the `evaluate_labeled` method."""
        self.evaluate_labeled(batch_dict, "val")

    def test_step(
        self,
        batch_dict: Union[
            BaseLabeledBatchDict,
            HeatmapLabeledBatchDict,
            MultiviewLabeledBatchDict,
            MultiviewHeatmapLabeledBatchDict,
        ],
        batch_idx: int,
    ) -> None:
        """Base test step, a wrapper around the `evaluate_labeled` method."""
        self.evaluate_labeled(batch_dict, "test")
@typechecked
class SemiSupervisedTrackerMixin(object):
    """Mixin class providing training step function for semi-supervised models.

    The host class is expected to provide `evaluate_labeled`, `loss_factory_unsup`,
    `total_unsupervised_importance`, `log` and `device`; none of them is defined here.
    """

    def get_loss_inputs_unlabeled(
        self,
        batch_dict: Union[UnlabeledBatchDict, MultiviewUnlabeledBatchDict],
    ) -> dict:
        """Return predicted heatmaps and their softmaxes (estimated keypoints)."""
        # the annotation mirrors `evaluate_unlabeled`, which forwards its batch here
        raise NotImplementedError

    def evaluate_unlabeled(
        self,
        batch_dict: Union[UnlabeledBatchDict, MultiviewUnlabeledBatchDict],
        stage: Optional[Literal["train", "val", "test"]] = None,
        anneal_weight: Union[float, torch.Tensor] = 1.0,
    ) -> TensorType[(), float]:
        """Compute and log the losses on a batch of unlabeled data (frames only).

        Args:
            batch_dict: batch of unlabeled data
            stage: forwarded to the loss factory; nothing is logged when falsy
            anneal_weight: forwarded to the unsupervised loss factory

        Returns:
            the unsupervised loss returned by `self.loss_factory_unsup`

        """
        # forward pass: collect predicted heatmaps and keypoints
        data_dict = self.get_loss_inputs_unlabeled(batch_dict=batch_dict)

        # compute loss on unlabeled data
        loss, log_list = self.loss_factory_unsup(
            stage=stage,
            anneal_weight=anneal_weight,
            **data_dict,
        )

        if not stage:
            return loss

        # log individual unsupervised losses
        for log_dict in log_list:
            self.log(
                log_dict['name'],
                log_dict['value'].to(self.device),
                prog_bar=log_dict.get('prog_bar', False),
                sync_dist=True,
            )

        return loss

    def training_step(
        self,
        batch_dict: Union[SemiSupervisedBatchDict, SemiSupervisedHeatmapBatchDict],
        batch_idx: int,
    ) -> Dict[str, TensorType[(), float]]:
        """Training step computes and logs both supervised and unsupervised losses."""
        # on each epoch, self.total_unsupervised_importance is modified by the
        # AnnealWeight callback
        unsupervised_importance = self.total_unsupervised_importance
        self.log(
            "total_unsupervised_importance",
            unsupervised_importance,
            prog_bar=True,
            # don't need to sync_dist because this is always the same across processes.
        )

        # computes and logs supervised losses
        # batch_dict["labeled"] contains:
        # - images
        # - keypoints
        # - heatmaps
        loss_super = self.evaluate_labeled(
            batch_dict=batch_dict["labeled"],
            stage="train",
        )

        # computes and logs unsupervised losses
        # batch_dict["unlabeled"] contains:
        # - images
        loss_unsuper = self.evaluate_unlabeled(
            batch_dict=batch_dict["unlabeled"],
            stage="train",
            anneal_weight=unsupervised_importance,
        )

        # log total loss
        total_loss = loss_super + loss_unsuper
        self.log("total_loss", total_loss, prog_bar=True, sync_dist=True)

        return {"loss": total_loss}