Source code for lightning_pose.models.regression_tracker

"""Models that produce (x, y) coordinates of keypoints from images."""

from typing import Any

import torch
from jaxtyping import Float
from omegaconf import DictConfig, ListConfig

from lightning_pose.data.datatypes import BaseLabeledBatchDict, UnlabeledBatchDict
from lightning_pose.data.utils import undo_affine_transform
from lightning_pose.losses.factory import LossFactory
from lightning_pose.losses.losses import RegressionRMSELoss
from lightning_pose.models.backbones import ALLOWED_BACKBONES
from lightning_pose.models.base import BaseSupervisedTracker, SemiSupervisedTrackerMixin
from lightning_pose.models.heads import LinearRegressionHead

# to ignore imports for sphinx-autoapidoc
__all__ = []


class RegressionTracker(BaseSupervisedTracker):
    """Base model that produces (x, y) predictions of keypoints from images."""

    def __init__(
        self,
        num_keypoints: int,
        loss_factory: LossFactory | None = None,
        backbone: ALLOWED_BACKBONES = "resnet50",
        pretrained: bool = True,
        torch_seed: int = 123,
        optimizer: str = "Adam",
        optimizer_params: DictConfig | ListConfig | dict | None = None,
        lr_scheduler: str = "multisteplr",
        lr_scheduler_params: DictConfig | ListConfig | dict | None = None,
        **kwargs: Any,
    ) -> None:
        """Base model that produces (x, y) coordinates of keypoints from images.

        Args:
            num_keypoints: number of body parts
            loss_factory: object to orchestrate loss computation
            backbone: ResNet or EfficientNet variant to be used
            pretrained: True to load pretrained imagenet weights
            torch_seed: make weight initialization reproducible
            optimizer: optimizer name, forwarded to the base class
            optimizer_params: optimizer parameters, forwarded to the base class
            lr_scheduler: how to schedule learning rate
                multisteplr
            lr_scheduler_params: params for specific learning rate schedulers
                multisteplr: milestones, gamma
            **kwargs: forwarded to the base class; a "do_context" entry is accepted
                for backwards compatibility and discarded

        Raises:
            ValueError: if ``backbone`` contains "vit"

        """

        # for reproducible weight initialization
        self.torch_seed = torch_seed
        torch.manual_seed(torch_seed)

        if "vit" in backbone:
            raise ValueError("Regression trackers are not compatible with ViT backbones")

        # for backwards compatibility: silently drop the legacy "do_context" argument
        kwargs.pop("do_context", None)

        super().__init__(
            backbone=backbone,
            pretrained=pretrained,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            model_type="regression",
            **kwargs,
        )
        self.num_keypoints = num_keypoints
        # one x and one y target per keypoint
        self.num_targets = self.num_keypoints * 2
        self.loss_factory = loss_factory

        self.head = LinearRegressionHead(self.num_fc_input_features, self.num_targets)

        # use this to log auxiliary information: pixel_error on labeled data
        self.rmse_loss = RegressionRMSELoss()

        # necessary so we don't have to pass in model arguments when loading
        # also, "loss_factory" and "loss_factory_unsupervised" cannot be pickled
        # ("loss_factory_unsupervised" might come from the semi-supervised subclass;
        # otherwise it's ignored, important so that it doesn't try to pickle the dali
        # loaders)
        self.save_hyperparameters(ignore=["loss_factory", "loss_factory_unsupervised"])

    def forward(
        self,
        images: Float[torch.Tensor, "batch channels image_height image_width"],
    ) -> Float[torch.Tensor, "batch two_x_num_keypoints"]:
        """Forward pass through the network."""
        # see the signature annotation for the shape of "images"
        representations = self.get_representations(images)
        # "representations" is shape (batch, features, rep_height, rep_width)
        out = self.head(representations)
        # "out" is shape (batch, 2 * num_keypoints)
        return out

    def get_loss_inputs_labeled(self, batch_dict: BaseLabeledBatchDict) -> dict:
        """Return predicted coordinates for a batch of data.

        Returns:
            dict with "keypoints_targ" (the batch's "keypoints" entry) and
            "keypoints_pred" (the output of the forward pass on the batch's "images").

        """
        predicted_keypoints = self.forward(batch_dict["images"])
        return {
            "keypoints_targ": batch_dict["keypoints"],
            "keypoints_pred": predicted_keypoints,
        }

    def predict_step(
        self,
        batch_dict: BaseLabeledBatchDict | UnlabeledBatchDict,
        batch_idx: int,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Predict keypoints for a batch of video frames.

        Assuming a DALI video loader is passed in
        > trainer = Trainer(devices=8, accelerator="gpu")
        > predictions = trainer.predict(model, data_loader)

        Returns:
            tuple of (predicted keypoints, confidences); the confidences are all zeros
            with one entry per keypoint, on the same device as the predictions.

        """
        if "images" in batch_dict:  # can't do isinstance(o, c) on TypedDicts
            # labeled image dataloaders
            images = batch_dict["images"]  # type: ignore[typeddict-item]
        else:
            # unlabeled dali video dataloaders
            images = batch_dict["frames"]  # type: ignore[typeddict-item]
        # images -> keypoints
        predicted_keypoints = self.forward(images)
        # regression model does not include a notion of confidence, set to all zeros;
        # allocate on the predictions' device so both outputs live on the same device
        confidence = torch.zeros(
            (predicted_keypoints.shape[0], predicted_keypoints.shape[1] // 2),
            device=predicted_keypoints.device,
        )
        return predicted_keypoints, confidence

    def get_parameters(self) -> list[dict]:
        """Return per-parameter-group optimizer configuration for backbone and head.

        Returns:
            List of dicts with ``"params"``, ``"name"``, and optionally ``"lr"`` keys;
            the backbone starts with learning rate 0 (frozen until unfreezing).

        """
        params = [
            {"params": self.backbone.parameters(), "lr": 0, "name": "backbone"},
            {"params": self.head.parameters(), "name": "head"},
        ]
        return params
class SemiSupervisedRegressionTracker(SemiSupervisedTrackerMixin, RegressionTracker):
    """Model produces vectors of keypoints from labeled/unlabeled images."""

    def __init__(
        self,
        num_keypoints: int,
        loss_factory: LossFactory | None = None,
        loss_factory_unsupervised: LossFactory | None = None,
        backbone: ALLOWED_BACKBONES = "resnet50",
        pretrained: bool = True,
        torch_seed: int = 123,
        optimizer: str = "Adam",
        optimizer_params: DictConfig | ListConfig | dict | None = None,
        lr_scheduler: str = "multisteplr",
        lr_scheduler_params: DictConfig | ListConfig | dict | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the semi-supervised regression tracker.

        Args:
            num_keypoints: number of body parts
            loss_factory: object to orchestrate supervised loss computation
            loss_factory_unsupervised: object to orchestrate unsupervised loss
                computation; required (must not be None)
            backbone: ResNet or EfficientNet variant to be used
            pretrained: True to load pretrained imagenet weights
            torch_seed: make weight initialization reproducible
            optimizer: optimizer name, forwarded to the parent constructor
            optimizer_params: optimizer parameters, forwarded to the parent constructor
            lr_scheduler: how to schedule learning rate
                multisteplr
            lr_scheduler_params: params for specific learning rate schedulers
                multisteplr: milestones, gamma
            **kwargs: forwarded unchanged to the parent constructor

        Raises:
            ValueError: if ``loss_factory_unsupervised`` is None, or if it contains a
                "unimodal_mse" or "unimodal_wasserstein" loss

        """
        super().__init__(
            num_keypoints=num_keypoints,
            loss_factory=loss_factory,
            backbone=backbone,
            pretrained=pretrained,
            torch_seed=torch_seed,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            **kwargs,
        )
        self.loss_factory_unsup = loss_factory_unsupervised

        # explicit check rather than `assert`, which is stripped when running with -O
        if loss_factory_unsupervised is None:
            raise ValueError(
                "loss_factory_unsupervised must be provided for semi-supervised models"
            )
        loss_names = loss_factory_unsupervised.loss_instance_dict.keys()
        if "unimodal_mse" in loss_names or "unimodal_wasserstein" in loss_names:
            raise ValueError("cannot use unimodal loss in regression tracker")

        # this attribute will be modified by AnnealWeight callback during training
        self.total_unsupervised_importance = torch.tensor(1.0)

    def get_loss_inputs_unlabeled(self, batch_dict: UnlabeledBatchDict) -> dict:
        """Return predicted keypoints for a batch of unlabeled frames.

        Returns:
            dict with a single "keypoints_pred" entry holding the forward-pass
            predictions, mapped back through ``undo_affine_transform`` when the
            batch's "transforms" entry has a trailing dimension of 3.

        """
        predicted_keypoints = self.forward(batch_dict["frames"])
        # undo augmentation if needed
        # NOTE(review): a trailing dimension of 3 presumably marks a real affine
        # transform as opposed to a placeholder - confirm against the data loader
        if batch_dict["transforms"].shape[-1] == 3:
            # reshape to (seq_len, n_keypoints, 2)
            pred_kps = torch.reshape(predicted_keypoints, (predicted_keypoints.shape[0], -1, 2))
            # undo
            pred_kps = undo_affine_transform(pred_kps, batch_dict["transforms"])
            # reshape to (seq_len, n_keypoints * 2)
            predicted_keypoints = torch.reshape(pred_kps, (pred_kps.shape[0], -1))
        return {"keypoints_pred": predicted_keypoints}