AnnealWeight
- class lightning_pose.callbacks.AnnealWeight
Bases: Callback
Callback to change a weight value, stored as an attribute of the LightningModule, during training.
Methods Summary
on_train_epoch_start(trainer, pl_module) – Increment the annealed weight attribute at the start of each training epoch.
on_train_start(trainer, pl_module) – Set the annealed weight attribute to its initial value at training start.
Methods Documentation
- on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) → None
Increment the annealed weight attribute at the start of each training epoch.
- on_train_start(trainer: Trainer, pl_module: LightningModule) → None
Set the annealed weight attribute to its initial value at training start.
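To illustrate how the two hooks cooperate, here is a minimal sketch that runs without Lightning. It is not the library's source: `AnnealWeightSketch`, `DummyModule`, and the attribute name `unsupervised_weight` are hypothetical, and the sketch assumes a linear increase of increase_factor per epoch after freeze_until_epoch, capped at final_val and stored as a plain float (the real callback may store a different type). on_train_start writes init_val onto the module with setattr, and on_train_epoch_start overwrites it based on the module's current_epoch.

```python
from dataclasses import dataclass


@dataclass
class DummyModule:
    """Hypothetical stand-in for a LightningModule; only tracks the epoch."""
    current_epoch: int = 0


class AnnealWeightSketch:
    """Hypothetical re-implementation of the AnnealWeight hooks.

    Assumes a linear schedule: +increase_factor per epoch after
    freeze_until_epoch, capped at final_val.
    """

    def __init__(self, attr_name, init_val=0.0, increase_factor=0.01,
                 final_val=1.0, freeze_until_epoch=0):
        self.attr_name = attr_name
        self.init_val = init_val
        self.increase_factor = increase_factor
        self.final_val = final_val
        self.freeze_until_epoch = freeze_until_epoch

    def on_train_start(self, trainer, pl_module):
        # Write the initial weight onto the module before the first epoch.
        setattr(pl_module, self.attr_name, self.init_val)

    def on_train_epoch_start(self, trainer, pl_module):
        # Keep the weight at its current value until freeze_until_epoch has passed.
        epochs_since_unfreeze = pl_module.current_epoch - self.freeze_until_epoch
        if epochs_since_unfreeze <= 0:
            return
        # Assumed: linear increase from init_val, capped at final_val.
        value = self.init_val + epochs_since_unfreeze * self.increase_factor
        setattr(pl_module, self.attr_name, min(value, self.final_val))


# Simulate eight epochs by hand (a Trainer would normally call these hooks).
module = DummyModule()
callback = AnnealWeightSketch("unsupervised_weight", init_val=0.0,
                              increase_factor=0.25, final_val=1.0,
                              freeze_until_epoch=2)
callback.on_train_start(None, module)
history = []
for epoch in range(8):
    module.current_epoch = epoch
    callback.on_train_epoch_start(None, module)
    history.append(module.unsupervised_weight)
print(history)  # [0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

In real training, the Lightning Trainer calls these hooks itself; the hand-written loop only stands in for successive epochs.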
- __init__(attr_name: str, init_val: float = 0.0, increase_factor: float = 0.01, final_val: float = 1.0, freeze_until_epoch: int = 0) → None
Initialize the AnnealWeight callback.
- Parameters:
attr_name – name of the attribute on the pl_module to update each epoch.
init_val – initial value of the weight.
increase_factor – amount to increase the weight per epoch after unfreezing.
final_val – maximum value the weight can reach.
freeze_until_epoch – epoch at which the weight begins to increase.
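Read together, and assuming a linear schedule capped at final_val (inferred from the parameter descriptions above, not confirmed against the source), the weight at epoch $e$ can be sketched as follows, with $w_{\text{init}}$ = init_val, $\Delta$ = increase_factor, $w_{\text{final}}$ = final_val, and $e_0$ = freeze_until_epoch:

$$
w(e) = \min\bigl(w_{\text{final}},\; w_{\text{init}} + \Delta \cdot \max(0,\, e - e_0)\bigr)
$$

Under this reading, the defaults ($w_{\text{init}} = 0$, $\Delta = 0.01$, $w_{\text{final}} = 1$, $e_0 = 0$) give $w(100) = \min(1,\; 0 + 0.01 \cdot 100) = 1$, so the weight reaches final_val after 100 epochs and stays there.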
- __new__(**kwargs)