Source code for lightning_pose.models.heatmap_tracker_mhcrnn

"""Models that produce heatmaps of keypoints from images."""

from typing import Any, Literal

import torch
from jaxtyping import Float
from omegaconf import DictConfig, ListConfig

from lightning_pose.data.datatypes import (
    HeatmapLabeledBatchDict,
    MultiviewHeatmapLabeledBatchDict,
    MultiviewUnlabeledBatchDict,
    UnlabeledBatchDict,
)
from lightning_pose.data.utils import convert_bbox_coords, undo_affine_transform_batch
from lightning_pose.losses.factory import LossFactory
from lightning_pose.losses.losses import RegressionRMSELoss
from lightning_pose.models.backbones import ALLOWED_BACKBONES
from lightning_pose.models.base import (
    BaseSupervisedTracker,
    SemiSupervisedTrackerMixin,
)
from lightning_pose.models.heads import HeatmapMHCRNNHead

# to ignore imports for sphix-autoapidoc
__all__ = []


class HeatmapTrackerMHCRNN(BaseSupervisedTracker):
    """Multi-headed Convolutional RNN network that handles context frames."""

[docs] def __init__( self, num_keypoints: int, num_targets: int | None = None, loss_factory: LossFactory | None = None, backbone: ALLOWED_BACKBONES = "resnet50", pretrained: bool = True, downsample_factor: Literal[1, 2, 3] = 2, torch_seed: int = 123, optimizer: str = "Adam", optimizer_params: DictConfig | ListConfig | dict | None = None, lr_scheduler: str = "multisteplr", lr_scheduler_params: DictConfig | ListConfig | dict | None = None, **kwargs: Any, ) -> None: """Initialize a DLC-like model with resnet backbone. Args: num_keypoints: number of body parts loss_factory: object to orchestrate loss computation backbone: ResNet or EfficientNet variant to be used pretrained: True to load pretrained imagenet weights downsample_factor: make heatmap smaller than original frames to save memory; subpixel operations are performed for increased precision torch_seed: make weight initialization reproducible lr_scheduler: how to schedule learning rate lr_scheduler_params: params for specific learning rate schedulers - multisteplr: milestones, gamma """ if downsample_factor != 2: raise NotImplementedError("MHCRNN currently only implements downsample_factor=2") # for reproducible weight initialization self.torch_seed = torch_seed torch.manual_seed(torch_seed) # for backwards compatibility if "do_context" in kwargs.keys(): del kwargs["do_context"] super().__init__( backbone=backbone, pretrained=pretrained, optimizer=optimizer, optimizer_params=optimizer_params, lr_scheduler=lr_scheduler, lr_scheduler_params=lr_scheduler_params, do_context=True, **kwargs, ) self.num_keypoints = num_keypoints if num_targets is None: self.num_targets = num_keypoints * 2 else: self.num_targets = num_targets self.downsample_factor = downsample_factor self.head = HeatmapMHCRNNHead( backbone_arch=backbone, in_channels=self.num_fc_input_features, out_channels=self.num_keypoints, downsample_factor=self.downsample_factor, upsampling_factor=1 if "vit" in backbone else 2, ) self.loss_factory = loss_factory # use this to log auxiliary information: pixel_error on labeled data self.rmse_loss = RegressionRMSELoss() # necessary so we don't have to pass in model arguments when loading # also, "loss_factory" and "loss_factory_unsupervised" cannot be pickled # (loss_factory_unsupervised might come from SemiSupervisedHeatmapTracker.__super__(). # otherwise it's ignored, important so that it doesn't try to pickle the dali loaders) self.save_hyperparameters(ignore=["loss_factory", "loss_factory_unsupervised"])
[docs] def forward( self, images: ( Float[torch.Tensor, "batch frames channels image_height image_width"] | Float[torch.Tensor, "batch channels image_height image_width"] | Float[torch.Tensor, "batch view frames channels image_height image_width"] | Float[torch.Tensor, "batch view channels image_height image_width"] ), is_multiview: bool = False, ) -> tuple[ Float[torch.Tensor, "num_valid_outputs num_keypoints heatmap_height heatmap_width"], Float[torch.Tensor, "num_valid_outputs num_keypoints heatmap_height heatmap_width"], ]: """Forward pass through the network. Batch options ------------- - Float[torch.Tensor, "batch frames channels image_height image_width"] single view, labeled context batch - Float[torch.Tensor, "batch channels image_height image_width"] single view, unlabeled batch from DALI - Float[torch.Tensor, "batch view frames channels image_height image_width"] multivew, labeled context batch - Float[torch.Tensor, "batch view channels image_height image_width"] multiview, unlabeled batch from DALI """ shape = images.shape # get one representation for each frame representations = self.get_representations(images, is_multiview=is_multiview) # representations shape is (batch, features, height, width, frames) # get two heatmaps for each representation (single frame, multi-frame) heatmaps_sf, heatmaps_mf = self.head(representations, shape, is_multiview) return heatmaps_sf, heatmaps_mf
[docs] def get_loss_inputs_labeled( self, batch_dict: HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict, ) -> dict: """Return predicted heatmaps and their softmaxes (estimated keypoints).""" # images -> heatmaps pred_heatmaps_sf, pred_heatmaps_mf = self.forward(batch_dict["images"]) # heatmaps -> keypoints pred_keypoints_sf, confidence_sf = self.head.run_subpixelmaxima(pred_heatmaps_sf) pred_keypoints_mf, confidence_mf = self.head.run_subpixelmaxima(pred_heatmaps_mf) # bounding box coords -> original image coords target_keypoints = convert_bbox_coords(batch_dict, batch_dict["keypoints"]) pred_keypoints_sf = convert_bbox_coords(batch_dict, pred_keypoints_sf) pred_keypoints_mf = convert_bbox_coords(batch_dict, pred_keypoints_mf) return { "heatmaps_targ": torch.cat([batch_dict["heatmaps"], batch_dict["heatmaps"]], dim=0), "heatmaps_pred": torch.cat([pred_heatmaps_sf, pred_heatmaps_mf], dim=0), "keypoints_targ": torch.cat([target_keypoints, target_keypoints], dim=0), "keypoints_pred": torch.cat([pred_keypoints_sf, pred_keypoints_mf], dim=0), "confidences": torch.cat([confidence_sf, confidence_mf], dim=0), }
[docs] def predict_step( self, batch_dict: ( HeatmapLabeledBatchDict | MultiviewHeatmapLabeledBatchDict | UnlabeledBatchDict ), batch_idx: int, return_heatmaps: bool | None = False, ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Predict heatmaps and keypoints for a batch of video frames. Assuming a DALI video loader is passed in > trainer = Trainer(devices=8, accelerator="gpu") > predictions = trainer.predict(model, data_loader) """ if "images" in batch_dict.keys(): # can't do isinstance(o, c) on TypedDicts # labeled image dataloaders images = batch_dict["images"] # type: ignore[typeddict-item] else: # unlabeled dali video dataloaders images = batch_dict["frames"] # type: ignore[typeddict-item] # images -> heatmaps pred_heatmaps_sf, pred_heatmaps_mf = self.forward(images) # heatmaps -> keypoints pred_keypoints_sf, confidence_sf = self.head.run_subpixelmaxima(pred_heatmaps_sf) pred_keypoints_mf, confidence_mf = self.head.run_subpixelmaxima(pred_heatmaps_mf) # reshape keypoints to be (batch, n_keypoints, 2) pred_keypoints_sf = pred_keypoints_sf.reshape(pred_keypoints_sf.shape[0], -1, 2) pred_keypoints_mf = pred_keypoints_mf.reshape(pred_keypoints_mf.shape[0], -1, 2) # find higher confidence indices mf_conf_gt = torch.gt(confidence_mf, confidence_sf) # select higher confidence indices pred_keypoints_sf[mf_conf_gt] = pred_keypoints_mf[mf_conf_gt] pred_keypoints_sf = pred_keypoints_sf.reshape(pred_keypoints_sf.shape[0], -1) confidence_sf[mf_conf_gt] = confidence_mf[mf_conf_gt] # bounding box coords -> original image coords pred_keypoints_sf = convert_bbox_coords(batch_dict, pred_keypoints_sf) if return_heatmaps: pred_heatmaps_sf[mf_conf_gt] = pred_heatmaps_mf[mf_conf_gt] return pred_keypoints_sf, confidence_sf, pred_heatmaps_sf else: return pred_keypoints_sf, confidence_sf
[docs] def get_parameters(self) -> list[dict]: """Return per-parameter-group optimizer configuration for backbone and head. Returns: List of dicts with ``"params"``, ``"name"``, and optionally ``"lr"`` keys; the backbone starts with learning rate 0 (frozen until unfreezing). """ params = [ {"params": self.backbone.parameters(), "name": "backbone", "lr": 0.0}, {"params": self.head.parameters(), "name": "head"}, ] return params
class SemiSupervisedHeatmapTrackerMHCRNN(SemiSupervisedTrackerMixin, HeatmapTrackerMHCRNN): """Model produces heatmaps of keypoints from labeled/unlabeled images."""
[docs] def __init__( self, num_keypoints: int, loss_factory: LossFactory | None = None, loss_factory_unsupervised: LossFactory | None = None, backbone: ALLOWED_BACKBONES = "resnet50", downsample_factor: Literal[2, 3] = 2, pretrained: bool = True, torch_seed: int = 123, optimizer: str = "Adam", optimizer_params: DictConfig | ListConfig | dict | None = None, lr_scheduler: str = "multisteplr", lr_scheduler_params: DictConfig | ListConfig | dict | None = None, **kwargs: Any, ) -> None: """ Args: num_keypoints: number of body parts loss_factory: object to orchestrate supervised loss computation loss_factory_unsupervised: object to orchestrate unsupervised loss computation backbone: ResNet or EfficientNet variant to be used downsample_factor: make heatmap smaller than original frames to save memory; subpixel operations are performed for increased precision pretrained: True to load pretrained imagenet weights torch_seed: make weight initialization reproducible lr_scheduler: how to schedule learning rate multisteplr lr_scheduler_params: params for specific learning rate schedulers multisteplr: milestones, gamma """ super().__init__( num_keypoints=num_keypoints, loss_factory=loss_factory, backbone=backbone, downsample_factor=downsample_factor, pretrained=pretrained, torch_seed=torch_seed, optimizer=optimizer, optimizer_params=optimizer_params, lr_scheduler=lr_scheduler, lr_scheduler_params=lr_scheduler_params, **kwargs, ) self.loss_factory_unsup = loss_factory_unsupervised # this attribute will be modified by AnnealWeight callback during training # self.register_buffer("total_unsupervised_importance", torch.tensor(1.0)) self.total_unsupervised_importance = torch.tensor(1.0)
[docs] def get_loss_inputs_unlabeled( self, batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict ) -> dict: """Return predicted heatmaps and their softmaxes (estimated keypoints)""" # images -> heatmaps pred_heatmaps_crnn, pred_heatmaps_sf = self.forward( batch_dict["frames"], is_multiview=batch_dict["is_multiview"], ) # heatmaps -> keypoints pred_keypoints_crnn, confidence_crnn = self.head.run_subpixelmaxima(pred_heatmaps_crnn) pred_keypoints_sf, confidence_sf = self.head.run_subpixelmaxima(pred_heatmaps_sf) # undo augmentations if needed pred_keypoints_crnn = undo_affine_transform_batch( keypoints_augmented=pred_keypoints_crnn, transforms=batch_dict["transforms"], is_multiview=batch_dict["is_multiview"], ) pred_keypoints_sf = undo_affine_transform_batch( keypoints_augmented=pred_keypoints_sf, transforms=batch_dict["transforms"], is_multiview=batch_dict["is_multiview"], ) # keypoints -> original image coords keypoints pred_keypoints_crnn = convert_bbox_coords(batch_dict, pred_keypoints_crnn) pred_keypoints_sf = convert_bbox_coords(batch_dict, pred_keypoints_sf) return { "heatmaps_pred": torch.cat([pred_heatmaps_crnn, pred_heatmaps_sf], dim=0), "keypoints_pred": torch.cat([pred_keypoints_crnn, pred_keypoints_sf], dim=0), "confidences": torch.cat([confidence_crnn, confidence_sf], dim=0), }