Source code for lightning_pose.callbacks

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import Callback

# to ignore imports for sphix-autoapidoc
__all__ = [
    "AnnealWeight",
]


[docs]class AnnealWeight(Callback): """Callback to change weight value during training.""" def __init__( self, attr_name: str, init_val: float = 0.0, increase_factor: float = 0.01, final_val: float = 1.0, freeze_until_epoch: int = 0, ) -> None: super().__init__() self.init_val = init_val self.increase_factor = increase_factor self.final_val = final_val self.freeze_until_epoch = freeze_until_epoch self.attr_name = attr_name
[docs] def on_train_start(self, trainer, pl_module) -> None: # Dan: removed buffer; seems to complicate checkpoint loading # pl_module.register_buffer(self.attr_name, torch.tensor(self.init_val)) setattr(pl_module, self.attr_name, torch.tensor(self.init_val))
[docs] def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if pl_module.current_epoch <= self.freeze_until_epoch: pass else: eff_epoch: int = pl_module.current_epoch - self.freeze_until_epoch value: float = min( self.init_val + eff_epoch * self.increase_factor, self.final_val ) # Dan: removed buffer; seems to complicate checkpoint loading # pl_module.register_buffer(self.attr_name, torch.tensor(value)) setattr(pl_module, self.attr_name, torch.tensor(value))
class UnfreezeBackbone(Callback): """Callback that ramps up the backbone learning rate from 0 to `upsampling_lr` on `unfreeze_epoch`. Starts LR at `initial_ratio * upsampling_lr`. Grows lr by a factor of `epoch_ratio` per epoch. Once LR reaches `upsampling_lr`, keeps it in sync with `upsampling_lr`. Use instead of pl.callbacks.BackboneFinetuning in order to use multi-GPU (DDP). See lightning-ai/pytorch-lightning#20340 for context. """ def __init__( self, unfreeze_epoch, initial_ratio=0.1, epoch_ratio=1.5, ): self.unfreeze_epoch = unfreeze_epoch self.initial_ratio = initial_ratio self.epoch_ratio = epoch_ratio self._warmed_up = False def on_train_epoch_start(self, trainer, pl_module): # This callback is only applicable to heatmap models but we # might encounter a RegressionModel. if not hasattr(pl_module, "upsampling_layers"): return # Once backbone_lr warms up to upsampling_lr, this callback does nothing. # Control of backbone lr is then the sole job of the main lr scheduler. if self._warmed_up: return optimizer = pl_module.optimizers() # Check our assumptions about param group indices assert optimizer.param_groups[0]["name"] == "backbone" assert optimizer.param_groups[1]["name"].startswith("upsampling") upsampling_lr = optimizer.param_groups[1]["lr"] optimizer.param_groups[0]["lr"] = self._get_backbone_lr( pl_module.current_epoch, upsampling_lr ) def _get_backbone_lr(self, current_epoch, upsampling_lr): assert not self._warmed_up # Before unfreeze, learning_rate is 0. if current_epoch < self.unfreeze_epoch: return 0.0 # On unfreeze, initialize learning rate. # Remember this initial value for warm up. if current_epoch == self.unfreeze_epoch: self._initial_lr = self.initial_ratio * upsampling_lr return self._initial_lr # Warm up: compute inital_ratio * epoch_ratio ** epochs_since_thaw. # Use stored initial_ratio rather than recomputing it since # upsampling_lr is subject to change via the scheduler. if current_epoch > self.unfreeze_epoch: epochs_since_thaw = current_epoch - self.unfreeze_epoch next_lr = min( self._initial_lr * self.epoch_ratio**epochs_since_thaw, upsampling_lr ) if next_lr == upsampling_lr: self._warmed_up = True return next_lr