Source code for lightning_pose.models.base

"""Base class for backbone that acts as a feature extractor."""

from __future__ import annotations

from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Literal, cast

if TYPE_CHECKING:
    from lightning_pose.losses.factory import LossFactory
    from lightning_pose.losses.losses import RegressionRMSELoss

import torch
from jaxtyping import Float
from lightning.pytorch import LightningModule
from omegaconf import DictConfig, ListConfig, OmegaConf
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

from lightning_pose.data.datatypes import (
    BaseLabeledBatchDict,
    HeatmapLabeledBatchDict,
    MultiviewHeatmapLabeledBatchDict,
    MultiviewLabeledBatchDict,
    MultiviewUnlabeledBatchDict,
    SemiSupervisedBatchDict,
    SemiSupervisedHeatmapBatchDict,
    UnlabeledBatchDict,
)
from lightning_pose.models.backbones import ALLOWED_BACKBONES

# Explicit public API: listing only these names keeps sphinx autoapi from also
# documenting the names this module merely imports.
__all__ = [
    'check_if_semi_supervised',
    'get_context_from_sequence',
    'BaseFeatureExtractor',
    'BaseSupervisedTracker',
    'SemiSupervisedTrackerMixin',
]


def check_if_semi_supervised(losses_to_use: ListConfig | list | None = None) -> bool:
    """Report whether a losses config requests any unsupervised loss.

    The model is treated as fully supervised when the config entry is missing,
    empty, or consists of a single empty-string placeholder; any other content
    marks the model as semi-supervised.

    Args:
        losses_to_use: the cfg entry specifying unsupervised losses to use.

    Returns:
        True if the model is semi-supervised, False otherwise.
    """
    if losses_to_use is None or len(losses_to_use) == 0:
        return False
    # a lone '' is how an "empty" list is commonly written in config files
    only_placeholder = len(losses_to_use) == 1 and losses_to_use[0] == ''
    return not only_placeholder
DEFAULT_LR_SCHEDULER_PARAMS = OmegaConf.create(
    {
        "milestones": [150, 200, 250],
        "gamma": 0.5,
    }
)

DEFAULT_OPTIMIZER_PARAMS = OmegaConf.create(
    {
        "learning_rate": 1e-3,
    }
)


class LrNotImplementedError(NotImplementedError):
    """Raised when an unsupported learning rate scheduler name is requested."""

    def __init__(self, lr_scheduler: str) -> None:
        """Initialize LrNotImplementedError.

        Args:
            lr_scheduler: the unsupported scheduler name that caused the error.
        """
        super().__init__(
            f"'{lr_scheduler}' is an invalid LR scheduler. Must be multisteplr."
        )
        self.lr_scheduler = lr_scheduler


class OptimizerNotImplementedError(NotImplementedError):
    """Raised when an unsupported optimizer name is requested."""

    def __init__(self, optimizer: str) -> None:
        """Initialize OptimizerNotImplementedError.

        Args:
            optimizer: the unsupported optimizer name that caused the error.
        """
        super().__init__(
            f"'{optimizer}' is an invalid optimizer. Must be Adam or AdamW."
        )
        self.optimizer = optimizer


def _apply_defaults_for_lr_scheduler_params(
    lr_scheduler: str, lr_scheduler_params: DictConfig | ListConfig | dict | None
) -> DictConfig | ListConfig:
    """Merge user-supplied LR scheduler params with defaults.

    Args:
        lr_scheduler: name of the learning rate scheduler (currently only
            ``"multisteplr"`` / ``"multistep_lr"``).
        lr_scheduler_params: user-supplied parameter overrides, or ``None`` to use
            defaults.

    Returns:
        A freshly merged ``DictConfig`` / ``ListConfig`` with all required scheduler
        parameters. It never aliases ``DEFAULT_LR_SCHEDULER_PARAMS``.

    Raises:
        LrNotImplementedError: if ``lr_scheduler`` is not a supported scheduler name.
    """
    if lr_scheduler not in ("multistep_lr", "multisteplr"):
        raise LrNotImplementedError(lr_scheduler)
    # Always merge, even with no overrides, so the caller gets its own config
    # object. Returning the module-level default directly would let a later
    # mutation of the returned config silently change the defaults for every
    # model built afterwards in this process.
    overrides = {} if lr_scheduler_params is None else lr_scheduler_params
    return OmegaConf.merge(DEFAULT_LR_SCHEDULER_PARAMS, overrides)


def _apply_defaults_for_optimizer_params(
    optimizer: str, optimizer_params: DictConfig | ListConfig | dict | None
) -> DictConfig | ListConfig:
    """Merge user-supplied optimizer params with defaults.

    Args:
        optimizer: optimizer name; currently ``"Adam"`` or ``"AdamW"``.
        optimizer_params: user-supplied parameter overrides, or ``None`` to use
            defaults.

    Returns:
        A freshly merged ``DictConfig`` / ``ListConfig`` with all required optimizer
        parameters. It never aliases ``DEFAULT_OPTIMIZER_PARAMS``.

    Raises:
        OptimizerNotImplementedError: if ``optimizer`` is not a supported optimizer
            name.
    """
    if optimizer not in ("Adam", "AdamW"):
        raise OptimizerNotImplementedError(optimizer)
    # Merge into a new object (see _apply_defaults_for_lr_scheduler_params) so the
    # module-level defaults cannot be mutated through the returned config.
    overrides = {} if optimizer_params is None else optimizer_params
    return OmegaConf.merge(DEFAULT_OPTIMIZER_PARAMS, overrides)
def get_context_from_sequence(
    img_seq: (
        Float[torch.Tensor, "seq_len RGB image_height image_width"]
        | Float[torch.Tensor, "seq_len n_features rep_height rep_width"]
    ),
    context_length: int,
) -> (
    Float[torch.Tensor, "seq_len context_length RGB image_height image_width"]
    | Float[torch.Tensor, "seq_len context_length n_features rep_height rep_width"]
):
    """Build overlapping context windows from a sequence of frames or feature maps.

    Each window contains ``context_length`` consecutive elements and is centred on
    its frame. There are ``context_length // 2`` elements before the frame. For an
    even ``context_length``, the window has one fewer element after the frame than
    before it. Windows that run past either end of the sequence are padded by
    repeating the first/last element.

    Args:
        img_seq: sequence tensor of shape ``(seq_len, ...)``.
        context_length: number of frames in each context window (e.g., 5); must be
            at least 1.

    Returns:
        Tensor of shape ``(seq_len, context_length, ...)`` with the same dtype and
        device as ``img_seq``. Row ``i`` is the context window for input frame
        ``i``.

    Raises:
        ValueError: if ``context_length`` is smaller than 1.
    """
    if context_length < 1:
        raise ValueError(f"context_length must be >= 1, got {context_length}")

    seq_len = img_seq.shape[0]
    device = img_seq.device
    # position j of every window looks (j - context_length // 2) frames away from
    # the window's own frame
    offsets = torch.arange(context_length, device=device) - context_length // 2
    frame_idx = torch.arange(seq_len, device=device)
    # window_idx[i, j] is the source frame shown at slot j of frame i's window.
    # Clamping to the valid range replicates the first/last frame near the ends,
    # which is the padding behaviour. The max(..., 0) keeps the bounds ordered
    # for an empty sequence.
    window_idx = (frame_idx[:, None] + offsets[None, :]).clamp(0, max(seq_len - 1, 0))
    # a single gather; advanced indexing keeps dtype/device and is differentiable
    return img_seq[window_idx]
class BaseFeatureExtractor(LightningModule):
    """Object that contains the base resnet feature extractor."""

    def __init__(
        self,
        backbone: ALLOWED_BACKBONES = "resnet50",
        pretrained: bool = True,
        lr_scheduler: str = "multisteplr",
        lr_scheduler_params: DictConfig | ListConfig | dict | None = None,
        optimizer: str = "Adam",
        optimizer_params: DictConfig | ListConfig | dict | None = None,
        do_context: bool = False,
        image_size: int = 256,
        model_type: Literal["heatmap", "regression"] = "heatmap",
        **kwargs: Any,
    ) -> None:
        """A CNN model that takes in images and generates features.

        ResNets will be loaded from torchvision and can be either pre-trained on
        ImageNet or randomly initialized. These were originally used for
        classification tasks, so we truncate their final fully connected layer.

        Args:
            backbone: which backbone version to use; defaults to resnet50
            pretrained: True to load weights pretrained on imagenet (torchvision
                models only)
            lr_scheduler: how to schedule learning rate
            lr_scheduler_params: params for specific learning rate schedulers
            optimizer: optimizer class to instantiate (Adam, AdamW, more to be
                added in future)
            optimizer_params: arguments to pass to optimizer
            do_context: include temporal context when processing each frame
            image_size: height/width of frames, for ViT models only
            model_type: type of model
            **kwargs: extra options; only ``backbone_checkpoint`` is read here and
                forwarded to the backbone builder

        Raises:
            LrNotImplementedError: if ``lr_scheduler`` is not supported.
            OptimizerNotImplementedError: if ``optimizer`` is not supported.
        """
        super().__init__()
        if self.local_rank == 0:
            print(f"\nInitializing a {self._get_name()} instance with {backbone} backbone.")

        self.backbone_arch = backbone
        # ViT architectures and everything else come from different builder modules
        if self.backbone_arch.startswith("vit"):
            from lightning_pose.models.backbones.vits import build_backbone
        else:
            from lightning_pose.models.backbones.torchvision import build_backbone

        self.backbone, self.num_fc_input_features = build_backbone(
            backbone_arch=self.backbone_arch,
            pretrained=pretrained,
            model_type=model_type,  # for torchvision only
            image_size=image_size,  # for ViTs only
            backbone_checkpoint=kwargs.get('backbone_checkpoint'),  # for ViTMAE's only
        )

        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_params = _apply_defaults_for_lr_scheduler_params(
            lr_scheduler, lr_scheduler_params
        )
        self.optimizer = optimizer
        self.optimizer_params = _apply_defaults_for_optimizer_params(
            optimizer, optimizer_params
        )
        self.do_context = do_context

    def get_representations(
        self,
        images: (
            Float[torch.Tensor, "batch channels image_height image_width"]
            | Float[torch.Tensor, "batch frames channels image_height image_width"]
            | Float[torch.Tensor, "seq_len channels image_height image_width"]
            | Float[torch.Tensor, "batch views frames channels image_height image_width"]
            | Float[torch.Tensor, "seq_len view frames channels image_height image_width"]
        ),
        is_multiview: bool = False,
    ) -> (
        Float[torch.Tensor, "new_batch features rep_height rep_width"]
        | Float[torch.Tensor, "new_batch features rep_height rep_width frames"]
    ):
        """Forward pass from images to feature maps.

        Wrapper around calls to ``self.backbone`` for typechecking purposes. See
        tests/models/test_base.py for example shapes.

        Batch options
        -------------
        - Float[torch.Tensor, "batch channels image_height image_width"]
          single view, labeled batch
        - Float[torch.Tensor, "batch frames channels image_height image_width"]
          single view, labeled context batch
        - Float[torch.Tensor, "seq_len channels image_height image_width"]
          single view, unlabeled batch from DALI
        - Float[torch.Tensor, "batch views frames channels image_height image_width"]
          multiview, labeled context batch
        - Float[torch.Tensor, "seq_len views channels image_height image_width"]
          multiview, unlabeled batch from DALI

        Args:
            images: a batch of images
            is_multiview: flag to distinguish batches of the same size; only
                consulted for 5-D inputs when ``self.do_context`` is True

        Returns:
            a representation of the images; features differ as a function of resnet
            version. Representation height and width differ as a function of image
            dimensions, and are not necessarily equal. In context mode the context
            frames are moved to the last dimension.

        Raises:
            RuntimeError: in context mode, for 4-D or 5-D multiview inputs with
                fewer than 5 frames (too few to build a context window).
            ValueError: in context mode, if ``images`` is not 4-, 5- or 6-D.
        """
        if self.do_context:
            if (len(images.shape) == 5 and not is_multiview) or len(images.shape) == 6:
                # len = 5
                # incoming batch: singleview labeled batch
                # incoming shape: (batch, frames, channels, height, width)
                #
                # len = 6
                # incoming batch: multiview labeled batch
                # incoming shape: (batch, num_views, frames, channels, height, width)
                if len(images.shape) == 6:
                    # stacking all the views in batch dimension
                    shape = images.shape
                    images = images.reshape(-1, shape[-4], shape[-3], shape[-2], shape[-1])

                batch, frames, channels, image_height, image_width = images.shape
                frames_batch_shape = batch * frames
                # fold frames into the batch dimension so the 2D backbone sees 4-D input
                images_batch_frames = images.reshape(
                    frames_batch_shape,
                    channels,
                    image_height,
                    image_width,
                )
                outputs = self.backbone(images_batch_frames)
                # unfold back to (batch, frames, features, rep_height, rep_width)
                outputs = outputs.reshape(
                    images.shape[0],
                    images.shape[1],
                    outputs.shape[1],
                    outputs.shape[2],
                    outputs.shape[3],
                )
            elif len(images.shape) == 5 and is_multiview:
                # incoming batch: multiview unlabeled batch
                # incoming shape: (seq, num_views, channels, height, width)
                batch, num_views, channels, image_height, image_width = images.shape
                batch_views_shape = batch * num_views
                images_batch_views = images.reshape(
                    batch_views_shape,
                    channels,
                    image_height,
                    image_width,
                )
                outputs = self.backbone(images_batch_views)
                outputs = outputs.reshape(
                    batch,
                    num_views,
                    outputs.shape[1],
                    outputs.shape[2],
                    outputs.shape[3],
                )
                # stack views across feature dimension
                outputs = outputs.reshape(batch, -1, outputs.shape[-2], outputs.shape[-1])

                # we need to tile the representations to make it into
                # (num_valid_frames, features, rep_height, rep_width, num_context_frames)
                tiled_representations = get_context_from_sequence(
                    img_seq=outputs,
                    context_length=5,
                )
                # get rid of first and last two frames
                if tiled_representations.shape[0] < 5:
                    raise RuntimeError("Not enough valid frames to make a context representation.")
                outputs = tiled_representations[2:-2, :, :, :, :]
            elif len(images.shape) == 4:
                # we have a single sequence of frames from DALI (not a batch of sequences)
                # valid frame := a frame that has two frames before it and two frames after it
                # we push it as is through the backbone, and then use tiling to make it into
                # (sequence_length, features, rep_height, rep_width, num_context_frames)
                # for now we discard the padded frames (first and last two)
                # the output will be one representation per valid frame
                sequence_length, channels, image_height, image_width = images.shape
                representations = self.backbone(images)
                # we need to tile the representations to make it into
                # (num_valid_frames, features, rep_height, rep_width, num_context_frames)
                # TODO: context frames should be configurable
                tiled_representations = get_context_from_sequence(
                    img_seq=representations,
                    context_length=5,
                )
                # get rid of first and last two frames
                if tiled_representations.shape[0] < 5:
                    raise RuntimeError("Not enough valid frames to make a context representation.")
                outputs = tiled_representations[2:-2, :, :, :, :]
            else:
                # previously this fell through and failed with an UnboundLocalError
                # on `outputs`; fail with an explicit, descriptive error instead
                raise ValueError(
                    f"Unsupported input shape {tuple(images.shape)} for a context model; "
                    "expected a 4-, 5- or 6-dimensional tensor."
                )

            # for both types of batches, we reshape in the same way
            # context is in the last dimension for the linear layer.
            representations = torch.permute(outputs, (0, 2, 3, 4, 1))
        else:
            # incoming batch: singleview labeled/unlabeled, multiview labeled/unlabeled reshaped
            # incoming shape: (batch, channels, height, width)
            representations = self.backbone(images)

        return representations

    def forward(
        self,
        images: (
            Float[torch.Tensor, "batch RGB image_height image_width"]
            | Float[torch.Tensor, "batch seq_length RGB image_height image_width"]
            | Float[torch.Tensor, "seq_length RGB image_height image_width"]
        ),
    ) -> (
        Float[torch.Tensor, "new_batch features rep_height rep_width"]
        | Float[torch.Tensor, "new_batch features rep_height rep_width frames"]
    ):
        """Forward pass from images to representations.

        Wrapper around self.get_representations(). Fancier children models will use
        get_representations() in their forward methods.

        Args:
            images: a batch of images.

        Returns:
            a representation of the images.
        """
        return self.get_representations(images)

    def get_scheduler(self, optimizer: torch.optim.Optimizer) -> MultiStepLR:
        """Build and return the learning rate scheduler.

        Args:
            optimizer: the optimizer whose learning rate will be scheduled.

        Returns:
            ``MultiStepLR`` scheduler configured from ``self.lr_scheduler_params``.

        Raises:
            LrNotImplementedError: if ``self.lr_scheduler`` is not supported.
        """
        if self.lr_scheduler not in ("multistep_lr", "multisteplr"):
            raise LrNotImplementedError(self.lr_scheduler)
        # define a scheduler that reduces the base learning rate
        milestones = self.lr_scheduler_params.milestones
        gamma = self.lr_scheduler_params.gamma
        scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
        return scheduler

    def get_parameters(self) -> Iterator[torch.nn.Parameter]:
        """Return an iterator over trainable (requires_grad) model parameters.

        Returns:
            Iterator of ``torch.nn.Parameter`` objects that require gradients.
        """
        params = filter(lambda p: p.requires_grad, self.parameters())
        return params

    def configure_optimizers(self) -> dict:
        """Select optimizer, lr scheduler, and metric for monitoring.

        Returns:
            Lightning optimizer config dict with ``optimizer``, ``lr_scheduler`` and
            ``monitor`` (``"val_supervised_loss"``) entries.

        Raises:
            OptimizerNotImplementedError: if ``self.optimizer`` is not supported.
            LrNotImplementedError: if ``self.lr_scheduler`` is not supported.
        """
        # get trainable params
        params = self.get_parameters()

        # init standard adam optimizer
        if self.optimizer == "Adam":
            optimizer = optim.Adam(params, lr=self.optimizer_params.learning_rate)
        elif self.optimizer == "AdamW":
            optimizer = optim.AdamW(params, lr=self.optimizer_params.learning_rate)
        else:
            raise OptimizerNotImplementedError(self.optimizer)

        # get learning rate scheduler
        scheduler = self.get_scheduler(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_supervised_loss",
        }
class BaseSupervisedTracker(BaseFeatureExtractor):
    """Base class for supervised trackers.

    Subclasses supply ``get_loss_inputs_labeled``. This class turns its output into
    supervised losses, logs them, and exposes the Lightning train/val/test steps.
    """

    # combines the configured supervised losses; must be set before evaluation
    loss_factory: LossFactory | None
    # reports the supervised pixel error (logged as ``<stage>_supervised_rmse``)
    rmse_loss: RegressionRMSELoss

    def get_loss_inputs_labeled(
        self,
        batch_dict: (
            BaseLabeledBatchDict
            | HeatmapLabeledBatchDict
            | MultiviewLabeledBatchDict
            | MultiviewHeatmapLabeledBatchDict
        ),
    ) -> dict:
        """Return predicted coordinates for a batch of data.

        Subclasses must override this; the base implementation always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def evaluate_labeled(
        self,
        batch_dict: (
            BaseLabeledBatchDict
            | HeatmapLabeledBatchDict
            | MultiviewLabeledBatchDict
            | MultiviewHeatmapLabeledBatchDict
        ),
        stage: Literal["train", "val", "test"] | None = None,
        anneal_weight: torch.Tensor | None = None,
    ) -> Float[torch.Tensor, ""]:
        """Compute and log the losses on a batch of labeled data.

        Args:
            batch_dict: labeled batch handed to ``get_loss_inputs_labeled``.
            stage: prefix for logged metric names; when falsy nothing is logged.
            anneal_weight: forwarded unchanged to ``self.loss_factory``.

        Returns:
            The combined supervised loss produced by ``self.loss_factory``.
        """
        # forward pass: gather whatever the loss functions consume
        loss_inputs = self.get_loss_inputs_labeled(batch_dict=batch_dict)

        assert self.loss_factory is not None
        supervised_loss, entries = self.loss_factory(
            stage=stage, anneal_weight=anneal_weight, **loss_inputs
        )
        rmse, _ = self.rmse_loss(stage=stage, **loss_inputs)

        if not stage:
            return supervised_loss

        # sync_dist=True averages each metric across GPUs during multi-GPU training;
        # the overhead was found to be negligible.
        self.log(f"{stage}_supervised_loss", supervised_loss, prog_bar=True, sync_dist=True)
        self.log(f"{stage}_supervised_rmse", rmse, sync_dist=True)
        for entry in entries:
            self.log(
                entry['name'],
                entry['value'].to(self.device),
                prog_bar=entry.get('prog_bar', False),
                sync_dist=True,
            )
        return supervised_loss

    def training_step(
        self,
        batch_dict: (
            BaseLabeledBatchDict
            | HeatmapLabeledBatchDict
            | MultiviewLabeledBatchDict
            | MultiviewHeatmapLabeledBatchDict
        ),
        batch_idx: int,
    ) -> dict[str, Float[torch.Tensor, ""]]:
        """Run `evaluate_labeled` on a training batch and return its loss.

        If the model carries ``total_unsupervised_importance`` (updated each epoch by
        the AnnealWeight callback), that value is logged and used as the anneal
        weight; otherwise no anneal weight is passed.
        """
        anneal_weight = None
        if hasattr(self, "total_unsupervised_importance"):
            anneal_weight = cast(torch.Tensor, self.total_unsupervised_importance)
            # identical on every process, so no sync_dist needed
            self.log("total_unsupervised_importance", anneal_weight, prog_bar=True)
        return {"loss": self.evaluate_labeled(batch_dict, "train", anneal_weight=anneal_weight)}

    def validation_step(
        self,
        batch_dict: (
            BaseLabeledBatchDict
            | HeatmapLabeledBatchDict
            | MultiviewLabeledBatchDict
            | MultiviewHeatmapLabeledBatchDict
        ),
        batch_idx: int,
    ) -> None:
        """Evaluate and log losses on a validation batch."""
        self.evaluate_labeled(batch_dict, "val")

    def test_step(
        self,
        batch_dict: (
            BaseLabeledBatchDict
            | HeatmapLabeledBatchDict
            | MultiviewLabeledBatchDict
            | MultiviewHeatmapLabeledBatchDict
        ),
        batch_idx: int,
    ) -> None:
        """Evaluate and log losses on a test batch."""
        self.evaluate_labeled(batch_dict, "test")
class SemiSupervisedTrackerMixin(BaseSupervisedTracker if TYPE_CHECKING else object):
    """Mixin adding a semi-supervised training step to a supervised tracker.

    Intended to be mixed with BaseSupervisedTracker, which provides the
    LightningModule machinery. The base class is only BaseSupervisedTracker while
    type checking (at runtime it is ``object``). This lets pyright see log(),
    device, evaluate_labeled(), loss_factory, etc.
    """

    # combines the configured unsupervised losses; must be set before evaluation
    loss_factory_unsup: LossFactory | None
    # anneal weight for the unsupervised losses, updated by the AnnealWeight callback
    total_unsupervised_importance: torch.Tensor

    def get_loss_inputs_unlabeled(
        self,
        batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict,
    ) -> dict:
        """Return predicted heatmaps and their softmaxes (estimated keypoints).

        Subclasses must override this; the base implementation always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def evaluate_unlabeled(
        self,
        batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict,
        stage: Literal["train", "val", "test"] | None = None,
        anneal_weight: float | torch.Tensor = 1.0,
    ) -> Float[torch.Tensor, ""]:
        """Compute and log the losses on a batch of unlabeled data (frames only).

        Args:
            batch_dict: unlabeled batch handed to ``get_loss_inputs_unlabeled``.
            stage: when falsy, the individual losses are not logged.
            anneal_weight: forwarded unchanged to ``self.loss_factory_unsup``.

        Returns:
            The combined unsupervised loss produced by ``self.loss_factory_unsup``.
        """
        loss_inputs = self.get_loss_inputs_unlabeled(batch_dict=batch_dict)

        assert self.loss_factory_unsup is not None
        unsupervised_loss, entries = self.loss_factory_unsup(
            stage=stage,
            anneal_weight=anneal_weight,
            **loss_inputs,
        )

        if stage:
            for entry in entries:
                self.log(
                    entry['name'],
                    entry['value'].to(self.device),
                    prog_bar=entry.get('prog_bar', False),
                    sync_dist=True,
                )
        return unsupervised_loss

    def training_step(
        self,
        batch_dict: SemiSupervisedBatchDict | SemiSupervisedHeatmapBatchDict,
        batch_idx: int,
    ) -> dict[str, Float[torch.Tensor, ""]]:
        """Compute, log and sum the supervised and unsupervised training losses.

        ``batch_dict["labeled"]`` goes to ``evaluate_labeled`` and
        ``batch_dict["unlabeled"]`` to ``evaluate_unlabeled``. Both calls use the
        current ``total_unsupervised_importance`` as their anneal weight.
        """
        weight = self.total_unsupervised_importance
        # identical on every process, so no sync_dist needed
        self.log("total_unsupervised_importance", weight, prog_bar=True)

        supervised = self.evaluate_labeled(
            batch_dict=batch_dict["labeled"],
            stage="train",
            anneal_weight=weight,
        )
        unsupervised = self.evaluate_unlabeled(
            batch_dict=batch_dict["unlabeled"],
            stage="train",
            anneal_weight=weight,
        )

        total_loss = supervised + unsupervised
        self.log("total_loss", total_loss, prog_bar=True, sync_dist=True)
        return {"loss": total_loss}