"""Models that produce (x, y) coordinates of keypoints from images."""
from typing import Any
import torch
from jaxtyping import Float
from omegaconf import DictConfig, ListConfig
from lightning_pose.data.datatypes import BaseLabeledBatchDict, UnlabeledBatchDict
from lightning_pose.data.utils import undo_affine_transform
from lightning_pose.losses.factory import LossFactory
from lightning_pose.losses.losses import RegressionRMSELoss
from lightning_pose.models.backbones import ALLOWED_BACKBONES
from lightning_pose.models.base import BaseSupervisedTracker, SemiSupervisedTrackerMixin
from lightning_pose.models.heads import LinearRegressionHead
# to ignore imports for sphinx-autoapidoc
# (an empty __all__ declares that this module exports no public names via `import *`)
__all__ = []
class RegressionTracker(BaseSupervisedTracker):
    """Base model that produces (x, y) predictions of keypoints from images."""

    def __init__(
        self,
        num_keypoints: int,
        loss_factory: LossFactory | None = None,
        backbone: ALLOWED_BACKBONES = "resnet50",
        pretrained: bool = True,
        torch_seed: int = 123,
        optimizer: str = "Adam",
        optimizer_params: DictConfig | ListConfig | dict | None = None,
        lr_scheduler: str = "multisteplr",
        lr_scheduler_params: DictConfig | ListConfig | dict | None = None,
        **kwargs: Any,
    ) -> None:
        """Base model that produces (x, y) coordinates of keypoints from images.

        Args:
            num_keypoints: number of body parts
            loss_factory: object to orchestrate loss computation
            backbone: ResNet or EfficientNet variant to be used
            pretrained: True to load pretrained imagenet weights
            torch_seed: make weight initialization reproducible
            optimizer: optimizer name, forwarded unchanged to the base class
            optimizer_params: optimizer settings, forwarded unchanged to the base class
            lr_scheduler: how to schedule learning rate
                multisteplr
            lr_scheduler_params: params for specific learning rate schedulers
                multisteplr: milestones, gamma
            **kwargs: extra keyword arguments forwarded to the base class; a legacy
                ``do_context`` entry is accepted and silently discarded

        Raises:
            ValueError: if ``backbone`` contains the substring ``"vit"``

        """
        # for reproducible weight initialization
        self.torch_seed = torch_seed
        torch.manual_seed(torch_seed)

        if "vit" in backbone:
            raise ValueError("Regression trackers are not compatible with ViT backbones")

        # for backwards compatibility: older callers may still pass "do_context";
        # drop it so it is not forwarded to the base class
        kwargs.pop("do_context", None)

        super().__init__(
            backbone=backbone,
            pretrained=pretrained,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            model_type="regression",
            **kwargs,
        )

        self.num_keypoints = num_keypoints
        # one x and one y coordinate per keypoint
        self.num_targets = self.num_keypoints * 2
        self.loss_factory = loss_factory

        # NOTE(review): num_fc_input_features is not set in this class; presumably it is
        # provided by the base class constructor - confirm
        self.head = LinearRegressionHead(self.num_fc_input_features, self.num_targets)

        # use this to log auxiliary information: pixel_error on labeled data
        self.rmse_loss = RegressionRMSELoss()

        # necessary so we don't have to pass in model arguments when loading
        # also, "loss_factory" and "loss_factory_unsupervised" cannot be pickled
        # (loss_factory_unsupervised may be present when this constructor is reached through
        # SemiSupervisedRegressionTracker's super().__init__() call; otherwise it is ignored,
        # important so that it doesn't try to pickle the dali loaders)
        self.save_hyperparameters(ignore=["loss_factory", "loss_factory_unsupervised"])

    def forward(
        self,
        images: Float[torch.Tensor, "batch channels image_height image_width"]
    ) -> Float[torch.Tensor, "batch two_x_num_keypoints"]:
        """Forward pass through the network.

        Args:
            images: batch of input images; see the annotation for the expected shape

        Returns:
            output of the regression head, shape (batch, 2 * num_keypoints)

        """
        # "representations" is shape (batch, features, rep_height, rep_width)
        representations = self.get_representations(images)
        # head output is shape (batch, 2 * num_keypoints)
        return self.head(representations)

    def predict_step(
        self,
        batch_dict: BaseLabeledBatchDict | UnlabeledBatchDict,
        batch_idx: int,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Predict keypoints for a batch of video frames.

        Assuming a DALI video loader is passed in
        > trainer = Trainer(devices=8, accelerator="gpu")
        > predictions = trainer.predict(model, data_loader)

        Args:
            batch_dict: batch holding the input under key "images" (labeled image
                dataloaders) or, when that key is absent, under key "frames"
                (unlabeled dali video dataloaders)
            batch_idx: index of the batch; not used by this method
            **kwargs: not used by this method

        Returns:
            tuple of (predicted keypoints, confidences); the confidences are all zeros
            with shape (batch, predicted_keypoints.shape[1] // 2)

        """
        if "images" in batch_dict:  # can't do isinstance(o, c) on TypedDicts
            # labeled image dataloaders
            images = batch_dict["images"]  # type: ignore[typeddict-item]
        else:
            # unlabeled dali video dataloaders
            images = batch_dict["frames"]  # type: ignore[typeddict-item]

        # images -> keypoints
        predicted_keypoints = self.forward(images)

        # regression model does not include a notion of confidence, set to all zeros
        # NOTE(review): this tensor is created on the default device/dtype, which may differ
        # from that of predicted_keypoints - confirm downstream consumers expect this
        batch_size, num_outputs = predicted_keypoints.shape[0], predicted_keypoints.shape[1]
        confidence = torch.zeros((batch_size, num_outputs // 2))

        return predicted_keypoints, confidence

    def get_parameters(self) -> list[dict]:
        """Return per-parameter-group optimizer configuration for backbone and head.

        Returns:
            List of dicts with ``"params"``, ``"name"``, and optionally ``"lr"`` keys; the
            backbone starts with learning rate 0 (frozen until unfreezing).

        """
        return [
            {"params": self.backbone.parameters(), "lr": 0, "name": "backbone"},
            {"params": self.head.parameters(), "name": "head"},
        ]
class SemiSupervisedRegressionTracker(SemiSupervisedTrackerMixin, RegressionTracker):
    """Model produces vectors of keypoints from labeled/unlabeled images."""

    def __init__(
        self,
        num_keypoints: int,
        loss_factory: LossFactory | None = None,
        loss_factory_unsupervised: LossFactory | None = None,
        backbone: ALLOWED_BACKBONES = "resnet50",
        pretrained: bool = True,
        torch_seed: int = 123,
        optimizer: str = "Adam",
        optimizer_params: DictConfig | ListConfig | dict | None = None,
        lr_scheduler: str = "multisteplr",
        lr_scheduler_params: DictConfig | ListConfig | dict | None = None,
        **kwargs: Any,
    ) -> None:
        """Build the supervised regression model and attach the unsupervised loss factory.

        Args:
            num_keypoints: number of body parts
            loss_factory: object to orchestrate supervised loss computation
            loss_factory_unsupervised: object to orchestrate unsupervised loss
                computation; must not be None even though the signature allows it
            backbone: ResNet or EfficientNet variant to be used
            pretrained: True to load pretrained imagenet weights
            torch_seed: make weight initialization reproducible
            optimizer: optimizer name, forwarded unchanged to the parent constructor
            optimizer_params: optimizer settings, forwarded unchanged to the parent
                constructor
            lr_scheduler: how to schedule learning rate
                multisteplr
            lr_scheduler_params: params for specific learning rate schedulers
                multisteplr: milestones, gamma
            **kwargs: extra keyword arguments forwarded to the parent constructor; a
                legacy ``do_context`` entry is dropped by ``RegressionTracker.__init__``

        Raises:
            AssertionError: if ``loss_factory_unsupervised`` is None
            ValueError: if the unsupervised loss factory contains a ``"unimodal_mse"`` or
                ``"unimodal_wasserstein"`` loss

        """
        super().__init__(
            num_keypoints=num_keypoints,
            loss_factory=loss_factory,
            backbone=backbone,
            pretrained=pretrained,
            torch_seed=torch_seed,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            **kwargs,
        )

        self.loss_factory_unsup = loss_factory_unsupervised

        # kept as an assert (rather than a different exception type) so existing callers
        # observe the same failure; it also narrows the Optional type for the lines below
        assert loss_factory_unsupervised is not None, (
            "loss_factory_unsupervised must be provided for semi-supervised models"
        )

        # the unimodal losses are rejected for regression trackers
        loss_names = loss_factory_unsupervised.loss_instance_dict.keys()
        if "unimodal_mse" in loss_names or "unimodal_wasserstein" in loss_names:
            raise ValueError("cannot use unimodal loss in regression tracker")

        # this attribute will be modified by AnnealWeight callback during training
        self.total_unsupervised_importance = torch.tensor(1.0)