AnnealWeight

class lightning_pose.callbacks.AnnealWeight

Bases: Callback

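Callback that anneals a weight attribute on the LightningModule during training. The weight starts at init_val, is held fixed until freeze_until_epoch, and then increases by increase_factor each epoch until it reaches final_val.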

Methods Summary

on_train_epoch_start(trainer, pl_module)

Increment the annealed weight attribute at the start of each training epoch.

on_train_start(trainer, pl_module)

Set the annealed weight attribute to its initial value at training start.

Methods Documentation

on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) → None

Increment the annealed weight attribute at the start of each training epoch once freeze_until_epoch has been reached. The value never exceeds final_val.

on_train_start(trainer: Trainer, pl_module: LightningModule) → None

Set the annealed weight attribute to its initial value at training start.
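To illustrate how the two hooks cooperate, here is a minimal, Lightning-free sketch. The class AnnealWeightSketch and the attribute name loss_weight are hypothetical, a SimpleNamespace stands in for the LightningModule, and the capped linear update in on_train_epoch_start is an assumption inferred from the parameter descriptions of __init__, not the library's actual source.

```python
from types import SimpleNamespace


class AnnealWeightSketch:
    """Hypothetical stand-in for AnnealWeight; mirrors its two hooks without Lightning."""

    def __init__(self, attr_name, init_val=0.0, increase_factor=0.01,
                 final_val=1.0, freeze_until_epoch=0):
        self.attr_name = attr_name
        self.init_val = init_val
        self.increase_factor = increase_factor
        self.final_val = final_val
        self.freeze_until_epoch = freeze_until_epoch

    def on_train_start(self, trainer, pl_module):
        # Reset the attribute to its initial value when training begins.
        setattr(pl_module, self.attr_name, self.init_val)

    def on_train_epoch_start(self, trainer, pl_module):
        # Assumption: linear growth counted from freeze_until_epoch, capped at final_val.
        eff_epoch = max(0, pl_module.current_epoch - self.freeze_until_epoch)
        value = min(self.init_val + eff_epoch * self.increase_factor, self.final_val)
        setattr(pl_module, self.attr_name, value)


# Simulate eight epochs with a plain object standing in for the LightningModule.
module = SimpleNamespace(current_epoch=0)
cb = AnnealWeightSketch("loss_weight", init_val=0.0, increase_factor=0.25,
                        final_val=1.0, freeze_until_epoch=2)
cb.on_train_start(None, module)

history = []
for epoch in range(8):
    module.current_epoch = epoch
    cb.on_train_epoch_start(None, module)
    history.append(module.loss_weight)

print(history)  # → [0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

In this run the weight stays at 0.0 through epoch 2 (the freeze epoch), then climbs by 0.25 per epoch and is capped at 1.0 from epoch 6 onward.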

__init__(attr_name: str, init_val: float = 0.0, increase_factor: float = 0.01, final_val: float = 1.0, freeze_until_epoch: int = 0) → None

Initialize the AnnealWeight callback.

Parameters:
  • attr_name – name of the attribute on the pl_module to update each epoch.

  • init_val – initial value of the weight.

  • increase_factor – amount to increase the weight per epoch after unfreezing.

  • final_val – maximum value the weight can reach.

  • freeze_until_epoch – epoch at which the weight begins to increase.
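Taken together, these parameters suggest a capped linear schedule. The formula below is a sketch under the assumption that growth is linear and counted from freeze_until_epoch (so the weight still equals init_val at that epoch); the symbols w_0 = init_val, α = increase_factor, w_max = final_val, and e_f = freeze_until_epoch are introduced here for illustration only.

```latex
w(e) = \min\bigl( w_0 + \alpha \, \max(0,\; e - e_f),\; w_{\max} \bigr)
```

With the default arguments (w_0 = 0.0, α = 0.01, w_max = 1.0, e_f = 0), the weight would reach final_val after (1.0 - 0.0) / 0.01 = 100 epochs.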

__new__(**kwargs)