"""Models that produce heatmaps of keypoints from images."""
from typing import Any, Dict, Optional, Tuple, Union
import torch
from kornia.filters import filter2d
from kornia.geometry.subpix import spatial_expectation2d, spatial_softmax2d
from kornia.geometry.transform.pyramid import _get_pyramid_gaussian_kernel
from omegaconf import DictConfig
from torch import nn
from torchtyping import TensorType
from typeguard import typechecked
from typing_extensions import Literal
from lightning_pose.data.utils import (
HeatmapLabeledBatchDict,
MultiviewHeatmapLabeledBatchDict,
MultiviewUnlabeledBatchDict,
UnlabeledBatchDict,
evaluate_heatmaps_at_location,
undo_affine_transform_batch,
)
from lightning_pose.losses.factory import LossFactory
from lightning_pose.losses.losses import RegressionRMSELoss
from lightning_pose.models.base import (
ALLOWED_BACKBONES,
BaseSupervisedTracker,
SemiSupervisedTrackerMixin,
convert_bbox_coords,
)
# restrict the public names so that sphinx-autoapidoc ignores the imports above
__all__ = [
    "upsample",
    "HeatmapTracker",
    "SemiSupervisedHeatmapTracker",
]
def upsample(
    inputs: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"],
) -> TensorType["batch", "num_keypoints", "two_x_heatmap_height", "two_x_heatmap_width"]:
    """Double the spatial size of a batch of heatmaps without any learned weights.

    Follows the recipe of kornia's ``pyrup`` (interpolate to twice the size, then smooth with
    the pyramid gaussian kernel) but with better defaults.

    Args:
        inputs: heatmaps of shape (batch, num_keypoints, height, width)

    Returns:
        heatmaps of shape (batch, num_keypoints, 2 * height, 2 * width)

    """
    _batch, _channels, rows, cols = inputs.shape
    doubled_size = (rows * 2, cols * 2)
    # align_corners must stay False!! the grid offsets applied by callers rely on it
    enlarged = nn.functional.interpolate(
        inputs,
        size=doubled_size,
        mode="bicubic",
        align_corners=False,
    )
    smoothing_kernel = _get_pyramid_gaussian_kernel()
    return filter2d(enlarged, smoothing_kernel, border_type="constant")
class HeatmapTracker(BaseSupervisedTracker):
    """Base model that produces heatmaps of keypoints from images."""

    # value subtracted from predictions to fix grid offsets from upsampling,
    # keyed by downsample_factor
    _UPSAMPLE_GRID_OFFSETS = {1: 0.5, 2: 1.5, 3: 2.5}

    def __init__(
        self,
        num_keypoints: int,
        num_targets: Optional[int] = None,
        loss_factory: Optional[LossFactory] = None,
        backbone: ALLOWED_BACKBONES = "resnet50",
        downsample_factor: Literal[1, 2, 3] = 2,
        pretrained: bool = True,
        output_shape: Optional[tuple] = None,  # change
        torch_seed: int = 123,
        lr_scheduler: str = "multisteplr",
        lr_scheduler_params: Optional[Union[DictConfig, dict]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a DLC-like model with resnet backbone.

        Args:
            num_keypoints: number of body parts
            num_targets: number of predicted coordinates; defaults to ``2 * num_keypoints``
                when None
            loss_factory: object to orchestrate loss computation
            backbone: ResNet or EfficientNet variant to be used
            downsample_factor: make heatmap smaller than original frames to save memory; subpixel
                operations are performed for increased precision
            pretrained: True to load pretrained imagenet weights
            output_shape: hard-coded image size to avoid dynamic shape computations
            torch_seed: make weight initialization reproducible
            lr_scheduler: how to schedule learning rate
            lr_scheduler_params: params for specific learning rate schedulers
                multisteplr: milestones, gamma
            **kwargs: forwarded unchanged to the parent class constructor

        """
        # for reproducible weight initialization
        torch.manual_seed(torch_seed)

        super().__init__(
            backbone=backbone,
            pretrained=pretrained,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            **kwargs,
        )
        self.num_keypoints = num_keypoints
        # one (x, y) pair per keypoint unless the caller says otherwise
        self.num_targets = num_keypoints * 2 if num_targets is None else num_targets
        self.loss_factory = loss_factory
        # TODO: downsample_factor may be in mismatch between datamodule and model.
        self.downsample_factor = downsample_factor
        # build first, then re-initialize; both steps draw from the seeded RNG in this order
        self.upsampling_layers = self.make_upsampling_layers()
        self.initialize_upsampling_layers()
        self.output_shape = output_shape
        # TODO: temp=1000 works for 64x64 heatmaps, need to generalize to other shapes
        self.temperature = torch.tensor(1000.0, device=self.device)  # soft argmax temp
        self.torch_seed = torch_seed

        # use this to log auxiliary information: pixel_error on labeled data
        self.rmse_loss = RegressionRMSELoss()

        # necessary so we don't have to pass in model arguments when loading
        # also, "loss_factory" and "loss_factory_unsupervised" cannot be pickled
        # (loss_factory_unsupervised might come from SemiSupervisedHeatmapTracker.__super__().
        # otherwise it's ignored, important so that it doesn't try to pickle the dali loaders)
        self.save_hyperparameters(ignore=["loss_factory", "loss_factory_unsupervised"])

    @property
    def num_filters_for_upsampling(self) -> int:
        """Number of channels fed into the upsampling layers."""
        return self.num_fc_input_features

    def _upsample_heatmaps(self, heatmaps: torch.Tensor) -> torch.Tensor:
        """Apply `upsample` to the heatmaps `downsample_factor` times."""
        for _ in range(self.downsample_factor):
            heatmaps = upsample(heatmaps)
        return heatmaps

    def _fix_grid_offsets(self, preds: torch.Tensor) -> torch.Tensor:
        """Subtract, in place, the grid offset from upsampling for the current downsample_factor.

        Predictions are left untouched when downsample_factor has no entry in the offset table.
        """
        offset = self._UPSAMPLE_GRID_OFFSETS.get(self.downsample_factor)
        if offset is not None:
            preds -= offset
        return preds

    def run_subpixelmaxima(
        self,
        heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"],
    ) -> Tuple[TensorType["batch", "num_targets"], TensorType["batch", "num_keypoints"]]:
        """Use soft argmax on heatmaps.

        Args:
            heatmaps: output of upsampling layers

        Returns:
            tuple
                - soft argmax of shape (batch, num_targets)
                - confidences of shape (batch, num_keypoints)

        """
        # upsample heatmaps
        heatmaps = self._upsample_heatmaps(heatmaps)
        # find soft argmax
        softmaxes = spatial_softmax2d(heatmaps, temperature=self.temperature)
        preds = spatial_expectation2d(softmaxes, normalized_coordinates=False)
        # compute confidences as softmax value pooled around prediction
        # (done before the offset fix, i.e. in the coordinates of the upsampled heatmaps)
        confidences = evaluate_heatmaps_at_location(heatmaps=softmaxes, locs=preds)
        # fix grid offsets from upsampling
        preds = self._fix_grid_offsets(preds)
        # NOTE: we cannot use
        # `preds.reshape(-1, self.num_targets)`
        # This works fine for the non-multiview case
        # This works fine for multiview training
        # This fails during multiview inference when we might have an arbitrary number of views
        # that we are processing (self.num_targets is tied to the labeled data)
        return preds.reshape(-1, heatmaps.shape[1] * 2), confidences

    def run_hard_argmax(
        self,
        heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"],
    ) -> Tuple[TensorType["batch", "num_targets"], TensorType["batch", "num_keypoints"]]:
        """Use hard argmax on heatmaps.

        Args:
            heatmaps: output of upsampling layers

        Returns:
            tuple
                - hard argmax of shape (batch, num_targets)
                - confidences of shape (batch, num_keypoints)

        """
        # upsample heatmaps
        heatmaps = self._upsample_heatmaps(heatmaps)
        # find hard argmax
        softmaxes = spatial_softmax2d(heatmaps, temperature=self.temperature)
        preds = self._spatial_argmax2d(softmaxes)
        # compute confidences as softmax value pooled around prediction
        confidences = evaluate_heatmaps_at_location(heatmaps=softmaxes, locs=preds)
        # fix grid offsets from upsampling
        preds = self._fix_grid_offsets(preds)
        # size the output from the heatmaps themselves rather than self.num_targets, which is
        # tied to the labeled data and breaks multiview inference with an arbitrary number of
        # views (same reasoning as in run_subpixelmaxima)
        return preds.reshape(-1, heatmaps.shape[1] * 2), confidences

    @staticmethod
    def _spatial_argmax2d(heatmaps):
        """Return the (x, y) pixel index of the maximum of each heatmap.

        Args:
            heatmaps: tensor whose last two dimensions are (height, width)

        Returns:
            float32 tensor with the leading dimensions of `heatmaps` followed by a dimension of
            size 2 holding (x, y), i.e. (column, row), on the same device as `heatmaps`

        """
        width = heatmaps.shape[-1]
        flat_indexes = heatmaps.flatten(start_dim=-2).argmax(-1)
        # row-major layout: flat index = y * width + x
        x_coords = flat_indexes % width
        y_coords = torch.div(flat_indexes, width, rounding_mode="floor")
        return torch.stack([x_coords, y_coords], dim=-1).to(torch.float32)

    def initialize_upsampling_layers(self) -> None:
        """Initialize the Conv2DTranspose upsampling layers."""
        for index, layer in enumerate(self.upsampling_layers):
            if index == 0:
                # we ignore the PixelShuffle
                continue
            if isinstance(layer, nn.ConvTranspose2d):
                torch.nn.init.xavier_uniform_(layer.weight, gain=0.01)
                torch.nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm2d):
                torch.nn.init.constant_(layer.weight, 1.0)
                torch.nn.init.constant_(layer.bias, 0.0)

    def make_upsampling_layers(self) -> torch.nn.Sequential:
        """Build the PixelShuffle + ConvTranspose2d stack that maps features to heatmaps."""
        # Note:
        # https://github.com/jgraving/DeepPoseKit/blob/
        # cecdb0c8c364ea049a3b705275ae71a2f366d4da/deepposekit/models/DeepLabCut.py#L131
        # in their model, the pixel shuffle happens only for downsample_factor=2
        upsampling_layers = []
        upsampling_layers.append(nn.PixelShuffle(2))
        # upsampling_layers.append(nn.BatchNorm2d(self.num_filters_for_upsampling // 4))
        # upsampling_layers.append(nn.ReLU(inplace=True))
        upsampling_layers.append(
            self.create_double_upsampling_layer(
                # PixelShuffle(2) divides the channel count by 4
                in_channels=self.num_filters_for_upsampling // 4,
                out_channels=self.num_keypoints,
            )
        )  # up to here results in downsample_factor=3
        n_layers_to_build = 4 - self.downsample_factor - 1
        if self.backbone_arch in ["vit_h_sam", "vit_b_sam"]:
            # a negative count makes the loop below a no-op: no extra layers for these backbones
            n_layers_to_build = -1
        for _ in range(n_layers_to_build):
            # add upsampling layer to account for heatmap downsampling
            # upsampling_layers.append(nn.BatchNorm2d(self.num_keypoints))
            # upsampling_layers.append(nn.ReLU(inplace=True))
            upsampling_layers.append(
                self.create_double_upsampling_layer(
                    in_channels=self.num_keypoints,
                    out_channels=self.num_keypoints,
                )
            )
        return nn.Sequential(*upsampling_layers)

    @staticmethod
    def create_double_upsampling_layer(
        in_channels: int,
        out_channels: int,
    ) -> torch.nn.ConvTranspose2d:
        """Perform ConvTranspose2d to double the output shape."""
        return nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
            output_padding=(1, 1),
        )

    def heatmaps_from_representations(
        self,
        representations: TensorType["batch", "features", "rep_height", "rep_width"],
    ) -> TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"]:
        """Upsample representations to get final heatmaps."""
        return self.upsampling_layers(representations)

    def forward(
        self,
        images: Union[
            TensorType["batch", "channels":3, "image_height", "image_width"],
            TensorType["batch", "views", "channels":3, "image_height", "image_width"],
        ]
    ) -> TensorType["num_valid_outputs", "num_keypoints", "heatmap_height", "heatmap_width"]:
        """Forward pass through the network.

        Args:
            images: batch of images, optionally with an extra `views` dimension after batch

        Returns:
            heatmaps normalized with a spatial softmax (temperature 1); for multiview input the
            heatmaps of all views are stacked along the keypoint dimension

        """
        # we get one representation for each desired output.
        shape = images.shape
        # if len(shape) > 4 we assume we have multiple views and need to combine images across
        # batch/views before passing to network, then we reshape
        is_multiview = len(shape) > 4
        if is_multiview:
            images = images.reshape(-1, shape[-3], shape[-2], shape[-1])
        representations = self.get_representations(images)
        heatmaps = self.heatmaps_from_representations(representations)
        if is_multiview:
            # back to one row per original batch element
            heatmaps = heatmaps.reshape(shape[0], -1, heatmaps.shape[-2], heatmaps.shape[-1])
        # softmax temp stays 1 here; to modify for model predictions, see constructor
        return spatial_softmax2d(heatmaps, temperature=torch.tensor([1.0]))

    def predict_step(
        self,
        batch_dict: Union[
            HeatmapLabeledBatchDict,
            MultiviewHeatmapLabeledBatchDict,
            UnlabeledBatchDict,
        ],
        batch_idx: int,
        return_heatmaps: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        """Predict heatmaps and keypoints for a batch of video frames.

        Assuming a DALI video loader is passed in
        > trainer = Trainer(devices=8, accelerator="gpu")
        > predictions = trainer.predict(model, data_loader)

        Args:
            batch_dict: batch holding the input under "images" (labeled loaders) or "frames"
                (unlabeled loaders)
            batch_idx: index of the batch; not used here
            return_heatmaps: True to also return the predicted heatmaps

        Returns:
            (keypoints, confidences), plus the heatmaps as a third element if requested

        """
        # can't do isinstance(o, c) on TypedDicts, so dispatch on the key instead
        if "images" in batch_dict:
            # labeled image dataloaders
            images = batch_dict["images"]
        else:
            # unlabeled dali video dataloaders
            images = batch_dict["frames"]
        # images -> heatmaps
        predicted_heatmaps = self.forward(images)
        # heatmaps -> keypoints
        predicted_keypoints, confidence = self.run_subpixelmaxima(predicted_heatmaps)
        # bounding box coords -> original image coords
        predicted_keypoints = convert_bbox_coords(batch_dict, predicted_keypoints)
        if return_heatmaps:
            return predicted_keypoints, confidence, predicted_heatmaps
        return predicted_keypoints, confidence
@typechecked
class SemiSupervisedHeatmapTracker(SemiSupervisedTrackerMixin, HeatmapTracker):
    """Heatmap-producing keypoint model for labeled and unlabeled images."""

    def __init__(
        self,
        num_keypoints: int,
        loss_factory: Optional[LossFactory] = None,
        loss_factory_unsupervised: Optional[LossFactory] = None,
        backbone: ALLOWED_BACKBONES = "resnet50",
        downsample_factor: Literal[1, 2, 3] = 2,
        pretrained: bool = True,
        output_shape: Optional[tuple] = None,
        torch_seed: int = 123,
        lr_scheduler: str = "multisteplr",
        lr_scheduler_params: Optional[Union[DictConfig, dict]] = None,
        **kwargs: Any,
    ) -> None:
        """Set up the supervised tracker, then attach the unsupervised loss machinery.

        Args:
            num_keypoints: number of body parts
            loss_factory: object to orchestrate supervised loss computation
            loss_factory_unsupervised: object to orchestrate unsupervised loss computation
            backbone: ResNet or EfficientNet variant to be used
            downsample_factor: make heatmap smaller than original frames to save memory;
                subpixel operations are performed for increased precision
            pretrained: True to load pretrained imagenet weights
            output_shape: hard-coded image size to avoid dynamic shape computations
            torch_seed: make weight initialization reproducible
            lr_scheduler: how to schedule learning rate
                multisteplr
            lr_scheduler_params: params for specific learning rate schedulers
                multisteplr: milestones, gamma
            **kwargs: passed through untouched to the parent constructor

        """
        # everything except the unsupervised loss factory goes to the parent constructor
        super().__init__(
            # model definition
            num_keypoints=num_keypoints,
            backbone=backbone,
            pretrained=pretrained,
            downsample_factor=downsample_factor,
            output_shape=output_shape,
            # training setup
            loss_factory=loss_factory,
            torch_seed=torch_seed,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            **kwargs,
        )
        self.loss_factory_unsup = loss_factory_unsupervised

        # this attribute will be modified by AnnealWeight callback during training
        # kept as a plain tensor attribute; the register_buffer variant stays disabled:
        # self.register_buffer("total_unsupervised_importance", torch.tensor(1.0))
        self.total_unsupervised_importance = torch.tensor(1.0)