"""Training callbacks (module ``lightning_pose.callbacks``)."""

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import Callback

# Explicit public API: restricting the exported names keeps sphinx autoapi
# from documenting the names imported above.
__all__ = ["AnnealWeight", "UnfreezeBackbone"]


class AnnealWeight(Callback):
    """Callback that anneals a scalar attribute of the LightningModule per epoch.

    The attribute named ``attr_name`` is set to ``init_val`` when training starts
    and left untouched through epoch ``freeze_until_epoch`` (inclusive). From the
    next epoch on it is set to ``init_val + k * increase_factor`` (``k`` = epochs
    past the freeze point), never exceeding ``final_val``. The value is stored on
    the module as a 0-d ``torch.tensor``.
    """

    def __init__(
        self,
        attr_name: str,
        init_val: float = 0.0,
        increase_factor: float = 0.01,
        final_val: float = 1.0,
        freeze_until_epoch: int = 0,
    ) -> None:
        """
        Args:
            attr_name: name of the attribute written on the LightningModule.
            init_val: value assigned when training starts.
            increase_factor: amount added per epoch after the freeze period.
            final_val: upper bound on the annealed value.
            freeze_until_epoch: last epoch (inclusive) before the value starts
                to change.
        """
        super().__init__()
        self.attr_name = attr_name
        self.init_val = init_val
        self.increase_factor = increase_factor
        self.final_val = final_val
        self.freeze_until_epoch = freeze_until_epoch

    def on_train_start(self, trainer, pl_module) -> None:
        """Initialize the annealed attribute to ``init_val``."""
        # Stored as a plain attribute rather than a registered buffer: a buffer
        # was found to complicate checkpoint loading.
        setattr(pl_module, self.attr_name, torch.tensor(self.init_val))

    def on_train_epoch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        """Advance the annealed attribute once the freeze period is over."""
        epochs_past_freeze = pl_module.current_epoch - self.freeze_until_epoch
        if epochs_past_freeze <= 0:
            # Still frozen: keep the value that is currently set.
            return
        ramped = self.init_val + epochs_past_freeze * self.increase_factor
        annealed = min(ramped, self.final_val)
        # Plain attribute rather than a registered buffer (see on_train_start).
        setattr(pl_module, self.attr_name, torch.tensor(annealed))
class UnfreezeBackbone(Callback):
    """Callback that ramps the backbone learning rate up to the head learning rate.

    The backbone LR (optimizer param group 0, expected to be named ``"backbone"``)
    is held at 0 until ``unfreeze_epoch`` or ``unfreeze_step``. At that point it is
    set to ``initial_ratio * upsampling_lr``, where ``upsampling_lr`` is the LR of
    param group 1. It then grows by a factor of ``warm_up_ratio`` per epoch or
    step, capped at the current ``upsampling_lr``. Once the cap is reached the
    callback stops touching the optimizer, so the backbone LR is then controlled
    solely by the main LR scheduler.

    Use instead of pl.callbacks.BackboneFinetuning in order to use multi-GPU
    (DDP). See lightning-ai/pytorch-lightning#20340 for context.
    """

    def __init__(
        self,
        unfreeze_epoch: int | None = None,
        unfreeze_step: int | None = None,
        initial_ratio=0.1,
        warm_up_ratio=1.5,
    ):
        """
        Args:
            unfreeze_epoch: epoch at which the backbone LR becomes non-zero.
            unfreeze_step: global step at which the backbone LR becomes non-zero.
                Exactly one of ``unfreeze_epoch`` / ``unfreeze_step`` must be given.
            initial_ratio: backbone LR at unfreeze, as a fraction of the head LR.
            warm_up_ratio: multiplicative growth of the backbone LR per epoch
                (epoch mode) or per step (step mode).
        """
        super().__init__()
        assert (unfreeze_epoch is None) != (
            unfreeze_step is None
        ), "Exactly one must be provided."
        self.unfreeze_epoch = unfreeze_epoch
        self.unfreeze_step = unfreeze_step
        self.initial_ratio = initial_ratio
        self.warm_up_ratio = warm_up_ratio
        self._warmed_up = False
        # Backbone LR at the moment of unfreezing (the warm-up base); assigned
        # by _get_backbone_lr.
        self._initial_lr: float | None = None

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Set the backbone LR (param group 0) before each training batch."""
        # Once backbone_lr warms up to upsampling_lr, this callback does nothing.
        # Control of backbone lr is then the sole job of the main lr scheduler.
        if self._warmed_up:
            return
        optimizer = pl_module.optimizers()
        # Check our assumptions about param group indices.
        assert optimizer.param_groups[0]["name"] == "backbone"
        head_lr = optimizer.param_groups[1]["lr"]
        optimizer.param_groups[0]["lr"] = self._get_backbone_lr(
            pl_module.global_step, pl_module.current_epoch, head_lr
        )

    def _get_backbone_lr(self, current_step, current_epoch, upsampling_lr):
        """Return what the backbone LR should be at this point in training.

        Args:
            current_step: global step; used only when ``unfreeze_step`` was given.
            current_epoch: current epoch; used only when ``unfreeze_epoch`` was given.
            upsampling_lr: current head LR; also the cap on the backbone LR.

        Returns:
            0.0 before unfreezing, otherwise the warmed-up backbone LR. Sets
            ``self._warmed_up`` once the returned value reaches ``upsampling_lr``.
        """
        assert not self._warmed_up

        # The logic below is the same for epochs and steps; in "step mode" plug
        # the step counters into the same variables.
        if self.unfreeze_step is not None:
            unfreeze_at, now = self.unfreeze_step, current_step
        else:
            unfreeze_at, now = self.unfreeze_epoch, current_epoch

        # Before unfreeze, learning rate is 0.
        if now < unfreeze_at:
            return 0.0

        # On unfreeze, initialize learning rate and remember it for warm up.
        if now == unfreeze_at:
            self._initial_lr = self.initial_ratio * upsampling_lr
            return self._initial_lr

        # Past unfreeze but no base recorded: this instance never saw the
        # unfreeze point (e.g. training resumed from a checkpoint after it).
        # Re-derive the base from the current head LR instead of failing with
        # AttributeError; if warm-up had already finished, the min() below
        # immediately caps at upsampling_lr and marks us warmed up.
        if self._initial_lr is None:
            self._initial_lr = self.initial_ratio * upsampling_lr

        # Warm up: _initial_lr * warm_up_ratio ** periods_since_thaw, capped at
        # the head LR. Use the stored base rather than recomputing it, since
        # upsampling_lr is subject to change via the scheduler.
        periods_since_thaw = now - unfreeze_at
        next_lr = min(
            self._initial_lr * self.warm_up_ratio**periods_since_thaw, upsampling_lr
        )
        if next_lr == upsampling_lr:
            self._warmed_up = True
        return next_lr