Source code for lightning_pose.models.heads.heatmap

"""Heads that produce heatmap predictions for heatmap regression."""

from typing import Tuple

import torch
from kornia.filters import filter2d
from kornia.geometry.subpix import spatial_expectation2d, spatial_softmax2d
from kornia.geometry.transform.pyramid import _get_pyramid_gaussian_kernel
from torch import nn
from torchtyping import TensorType

from lightning_pose.data.utils import evaluate_heatmaps_at_location

# explicit public API of this module; listing names here makes sphinx-autoapidoc
# ignore the imports above
__all__ = [
    "make_upsampling_layers",
    "initialize_upsampling_layers",
    "upsample",
    "run_subpixelmaxima",
    "HeatmapHead",
]


def make_upsampling_layers(
    in_channels: int,
    out_channels: int,
    int_channels: int,
    n_layers: int,
) -> torch.nn.Sequential:
    """Build the upsampling stack: a PixelShuffle followed by transposed convolutions.

    The stack starts with ``nn.PixelShuffle(2)`` (which divides the channel count by 4) and
    is followed by ``n_layers`` ``nn.ConvTranspose2d`` layers (3x3 kernel, stride 2,
    padding 1, output padding 1).

    Args:
        in_channels: number of channels in the input feature map, before the PixelShuffle
        out_channels: number of channels produced by the final transposed convolution
        int_channels: number of channels produced by every non-final transposed convolution
        n_layers: number of transposed convolutions appended after the PixelShuffle; if this
            is not positive the stack contains only the PixelShuffle

    Returns:
        the layers wrapped in an ``nn.Sequential``

    """
    # Note:
    # https://github.com/jgraving/DeepPoseKit/blob/
    # cecdb0c8c364ea049a3b705275ae71a2f366d4da/deepposekit/models/DeepLabCut.py#L131
    # in their model, the pixel shuffle happens only for downsample_factor=2
    upsampling_layers = [nn.PixelShuffle(2)]
    last_layer = n_layers - 1
    for layer in range(n_layers):
        # the first layer consumes the PixelShuffle output, hence the division by 4
        in_ = in_channels // 4 if layer == 0 else int_channels
        # the last layer always emits `out_channels`, including when it is also the first
        # layer (n_layers == 1); choosing both widths from a single if/elif chain made a
        # lone layer emit `int_channels` instead
        out_ = out_channels if layer == last_layer else int_channels
        upsampling_layers.append(
            nn.ConvTranspose2d(
                in_channels=in_,
                out_channels=out_,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
                output_padding=(1, 1),
            )
        )
    return nn.Sequential(*upsampling_layers)
def initialize_upsampling_layers(layers) -> None:
    """Initialize the ConvTranspose2d (and BatchNorm2d) upsampling layers in place.

    The first entry of ``layers`` is always skipped. Every later ``nn.ConvTranspose2d`` gets
    Xavier-uniform weights with gain 0.01 and a zero bias; every later ``nn.BatchNorm2d``
    gets unit weight and zero bias. Other layer types are left untouched.

    Args:
        layers: iterable of layers, e.g. the ``nn.Sequential`` of upsampling layers

    """
    for position, module in enumerate(layers):
        if position == 0:
            # we ignore the PixelShuffle
            continue
        if isinstance(module, nn.ConvTranspose2d):
            torch.nn.init.xavier_uniform_(module.weight, gain=0.01)
            torch.nn.init.zeros_(module.bias)
            continue
        if isinstance(module, nn.BatchNorm2d):
            torch.nn.init.constant_(module.weight, 1.0)
            torch.nn.init.constant_(module.bias, 0.0)
def upsample(
    inputs: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"],
) -> TensorType["batch", "num_keypoints", "two_x_heatmap_height", "two_x_heatmap_width"]:
    """Double the spatial size of a batch of heatmaps with interpolation (no learned weights).

    The heatmaps are resized with bicubic interpolation and then filtered with kornia's
    pyramid gaussian kernel. This is a copy of kornia's pyrup function but with better
    defaults.

    Args:
        inputs: batch of heatmaps

    Returns:
        heatmaps with twice the height and twice the width of ``inputs``

    """
    smoothing_kernel = _get_pyramid_gaussian_kernel()
    _, _, rows, cols = inputs.shape
    # align_corners=False is important!! otherwise the grid offsets applied in
    # `run_subpixelmaxima` don't hold
    doubled = nn.functional.interpolate(
        inputs,
        size=(rows * 2, cols * 2),
        mode="bicubic",
        align_corners=False,
    )
    return filter2d(doubled, smoothing_kernel, border_type="constant")
def run_subpixelmaxima(
    heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"],
    downsample_factor: int,
    temperature: torch.tensor,
) -> Tuple[TensorType["batch", "num_targets"], TensorType["batch", "num_keypoints"]]:
    """Locate keypoints with a soft argmax over upsampled heatmaps.

    Args:
        heatmaps: output of upsampling layers
        downsample_factor: number of times the heatmaps are upsampled (by a factor of two
            each time) before the soft argmax is taken
        temperature: temperature parameter of softmax; higher leads to tighter peaks

    Returns:
        tuple
            - soft argmax of shape (batch, num_targets)
            - confidences of shape (batch, num_keypoints)

    """
    # bring the heatmaps up in resolution, one doubling per downsample step
    for _ in range(downsample_factor):
        heatmaps = upsample(heatmaps)

    # soft argmax: normalize each heatmap, then take the expected (x, y) location
    normalized = spatial_softmax2d(heatmaps, temperature=temperature)
    preds = spatial_expectation2d(normalized, normalized_coordinates=False)

    # confidences are the softmax values pooled around each prediction; they are computed
    # before the offset correction below is applied to `preds`
    confidences = evaluate_heatmaps_at_location(heatmaps=normalized, locs=preds)

    # fix grid offsets from upsampling; factors outside this table receive no correction
    # NOTE(review): 0.5 and 1.5 equal (2**k - 1) / 2 for k = 1, 2, but 2.5 does not (k = 3
    # would give 3.5) -- confirm the k = 3 offset is intentional
    grid_offset = {1: 0.5, 2: 1.5, 3: 2.5}.get(downsample_factor)
    if grid_offset is not None:
        preds -= grid_offset

    # NOTE: we cannot use
    # `preds.reshape(-1, self.num_targets)`
    # This works fine for the non-multiview case
    # This works fine for multiview training
    # This fails during multiview inference when we might have an arbitrary number of views
    # that we are processing (self.num_targets is tied to the labeled data)
    num_targets = heatmaps.shape[1] * 2
    return preds.reshape(-1, num_targets), confidences
class HeatmapHead(nn.Module):
    """Deconvolution head that turns 2D feature maps into per-keypoint heatmaps.

    This is the standard heatmap head used in the Lightning Pose package. Low-resolution
    feature maps are passed through a stack of deconvolutional layers, and the result is
    normalized with a 2D spatial softmax to produce the heatmaps.

    """

    def __init__(
        self,
        backbone_arch: str,
        in_channels: int,
        out_channels: int,
        deconv_out_channels: int | None = None,
        downsample_factor: int = 2,
    ):
        """Store the configuration and build the upsampling layers.

        Args:
            backbone_arch: string denoting backbone architecture; to remove in future release
            in_channels: number of channels in the input feature map
            out_channels: number of channels in the output heatmap (i.e. number of keypoints)
            deconv_out_channels: output channel number for each intermediate deconv layer;
                defaults to number of keypoints
            downsample_factor: make heatmaps smaller than input frames by this factor;
                subpixel operations are performed for increased precision

        """
        super().__init__()

        self.backbone_arch = backbone_arch
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deconv_out_channels = deconv_out_channels
        self.downsample_factor = downsample_factor

        # soft argmax temp
        # TODO: temp=1000 works for 64x64 heatmaps, need to generalize to other shapes
        self.temperature = torch.tensor(1000.0)

        # the SAM ViT backbones get one fewer deconv layer than the other backbones
        uses_sam_backbone = self.backbone_arch in ("vit_h_sam", "vit_b_sam")
        n_layers = 4 - self.downsample_factor - (1 if uses_sam_backbone else 0)

        # intermediate layers fall back to the number of keypoints when no width is given
        intermediate_channels = deconv_out_channels or out_channels
        self.upsampling_layers = make_upsampling_layers(
            in_channels=in_channels,
            out_channels=out_channels,
            int_channels=intermediate_channels,
            n_layers=n_layers,
        )
        initialize_upsampling_layers(self.upsampling_layers)

    def forward(
        self,
        features: TensorType["batch", "features", "features_height", "features_width"],
    ) -> TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"]:
        """Upsample representations and normalize to get final heatmaps."""
        unnormalized = self.upsampling_layers(features)
        # softmax temp stays 1 here; to modify for model predictions, see constructor
        unit_temperature = torch.tensor([1.0])
        return spatial_softmax2d(unnormalized, temperature=unit_temperature)

    def run_subpixelmaxima(self, heatmaps):
        """Run the module-level soft argmax with this head's downsample factor and temperature."""
        return run_subpixelmaxima(
            heatmaps,
            downsample_factor=self.downsample_factor,
            temperature=self.temperature,
        )