"""Custom Lightning callbacks for training schedule, backbone unfreezing, and augmentation."""
import json
import os
import time
from pathlib import Path
from typing import Any
import lightning.pytorch as pl
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import (
Callback,
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from omegaconf import DictConfig, ListConfig
# Declare the public names explicitly so that sphinx-autoapidoc ignores the imports above.
# Classes defined in this module but not listed here are left out of this list.
__all__ = [
"AnnealWeight",
"UnfreezeBackbone",
"PatchMasking",
"get_callbacks",
]
class AnnealWeight(Callback):
    """Callback that linearly ramps a scalar weight stored on the LightningModule."""

    def __init__(
        self,
        attr_name: str,
        init_val: float = 0.0,
        increase_factor: float = 0.01,
        final_val: float = 1.0,
        freeze_until_epoch: int = 0,
    ) -> None:
        """Store the annealing schedule.

        Args:
            attr_name: name of the attribute on the pl_module that is overwritten each epoch.
            init_val: value the weight holds until the ramp starts.
            increase_factor: amount added to the weight per epoch once the ramp has started.
            final_val: upper bound on the weight.
            freeze_until_epoch: last epoch at which the weight is left untouched; the ramp
                begins on the following epoch.
        """
        super().__init__()
        self.attr_name = attr_name
        self.init_val = init_val
        self.increase_factor = increase_factor
        self.final_val = final_val
        self.freeze_until_epoch = freeze_until_epoch

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Write the initial weight onto the module when training begins."""
        # A plain attribute is used on purpose: registering a buffer was found to
        # complicate checkpoint loading.
        setattr(pl_module, self.attr_name, torch.tensor(self.init_val))

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Advance the weight along the ramp at the start of each training epoch."""
        epochs_into_ramp: int = pl_module.current_epoch - self.freeze_until_epoch
        if epochs_into_ramp <= 0:
            # Still frozen: keep whatever value is currently on the module.
            return
        ramped: float = min(
            self.init_val + epochs_into_ramp * self.increase_factor, self.final_val
        )
        setattr(pl_module, self.attr_name, torch.tensor(ramped))
class UnfreezeBackbone(Callback):
    """Callback that thaws the backbone by ramping its learning rate up to the head's.

    Before `unfreeze_epoch` (or `unfreeze_step`) the backbone LR is held at 0. At the unfreeze
    point it is set to `initial_ratio * upsampling_lr`, where `upsampling_lr` is the current LR
    of optimizer param group 1. It then grows by a factor of `warm_up_ratio` per epoch (or
    step) until it reaches `upsampling_lr`; from then on this callback no longer touches the
    backbone LR and leaves it to the main LR scheduler.

    Use instead of pl.callbacks.BackboneFinetuning in order to use multi-GPU (DDP). See
    lightning-ai/pytorch-lightning#20340 for context.
    """

    # LR assigned to the backbone at the unfreeze point; None until it has been computed.
    _initial_lr: float | None

    def __init__(
        self,
        unfreeze_epoch: int | None = None,
        unfreeze_step: int | None = None,
        initial_ratio: float = 0.1,
        warm_up_ratio: float = 1.5,
    ) -> None:
        """Initialize UnfreezeBackbone callback.

        Exactly one of ``unfreeze_epoch`` or ``unfreeze_step`` must be provided.

        Args:
            unfreeze_epoch: epoch at which to begin unfreezing the backbone.
            unfreeze_step: global step at which to begin unfreezing the backbone.
            initial_ratio: backbone LR starts at ``initial_ratio * upsampling_lr``.
            warm_up_ratio: backbone LR is multiplied by this factor each epoch/step during
                warm-up.

        Raises:
            AssertionError: if neither or both of ``unfreeze_epoch`` / ``unfreeze_step`` are
                given.
        """
        super().__init__()
        assert (unfreeze_epoch is None) != (
            unfreeze_step is None
        ), "Exactly one must be provided."
        self.unfreeze_epoch = unfreeze_epoch
        self.unfreeze_step = unfreeze_step
        self.initial_ratio = initial_ratio
        self.warm_up_ratio = warm_up_ratio
        self._warmed_up = False
        self._initial_lr = None

    def on_train_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Adjust the backbone learning rate at the start of each training batch."""
        # Once backbone_lr warms up to upsampling_lr, this callback does nothing.
        # Control of backbone lr is then the sole job of the main lr scheduler.
        if self._warmed_up:
            return
        optimizer = pl_module.optimizers()
        assert not isinstance(optimizer, list)
        # Check our assumptions about param group indices: group 0 is the backbone and
        # group 1 supplies the reference ("upsampling") learning rate.
        assert optimizer.param_groups[0]["name"] == "backbone"
        head_lr = optimizer.param_groups[1]["lr"]
        optimizer.param_groups[0]["lr"] = self._get_backbone_lr(
            pl_module.global_step, pl_module.current_epoch, head_lr
        )

    def _get_backbone_lr(
        self, current_step: int | None, current_epoch: int, upsampling_lr: float
    ) -> float:
        """Return what the backbone LR should be at this point in time.

        Only one of `current_step` and `current_epoch` is used: `current_step` when the
        callback was configured with `unfreeze_step`, otherwise `current_epoch`.

        Args:
            current_step: current global step (required in step mode).
            current_epoch: current epoch.
            upsampling_lr: LR the backbone is warming up towards.

        Returns:
            0.0 before the unfreeze point, ``initial_ratio * upsampling_lr`` at the unfreeze
            point, and the geometrically grown LR capped at ``upsampling_lr`` afterwards.
        """
        assert not self._warmed_up
        # The same logic serves both modes; pick the clock and the thaw point to compare.
        if self.unfreeze_step is not None:
            assert current_step is not None
            now, thaw_at = current_step, self.unfreeze_step
        else:
            now, thaw_at = current_epoch, self.unfreeze_epoch
        assert thaw_at is not None

        # Before unfreeze, learning rate is 0.
        if now < thaw_at:
            return 0.0

        # On unfreeze, initialize the learning rate and remember it for the warm-up.
        if now == thaw_at:
            self._initial_lr = self.initial_ratio * upsampling_lr
            return self._initial_lr

        # Warm up: initial_lr * warm_up_ratio ** (time since thaw), capped at upsampling_lr.
        # The stored initial LR is reused rather than recomputed because upsampling_lr is
        # subject to change via the scheduler.
        if self._initial_lr is None:
            # The thaw point was never observed by this instance (e.g. training resumed from
            # a checkpoint taken after it), so derive the starting LR from the current one.
            self._initial_lr = self.initial_ratio * upsampling_lr
        try:
            ramped_lr = self._initial_lr * self.warm_up_ratio ** (now - thaw_at)
        except OverflowError:
            # The growth factor exceeds float range, so the cap certainly applies.
            ramped_lr = upsampling_lr
        next_lr = min(ramped_lr, upsampling_lr)
        if next_lr == upsampling_lr:
            self._warmed_up = True
        return next_lr
class PatchMasking(Callback):
    """Callback to apply curriculum patch masking during training."""

    def __init__(
        self,
        patch_mask_config: dict | None = None,
        patch_seed: int = 0,
    ) -> None:
        """Initialize PatchMasking callback.

        Args:
            patch_mask_config: dictionary configuring the masking curriculum, with optional keys
                ``init_step``, ``final_step``, ``init_ratio``, and ``final_ratio``.
            patch_seed: seed for reproducible patch selection.
        """
        super().__init__()
        # The masker owns the schedule and the actual pixel masking.
        self.curriculum_masking = PatchMasker(
            patch_mask_config=patch_mask_config,
            patch_seed=patch_seed,
        )

    def on_train_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Mask patches of the batch images in place before the batch reaches the model.

        For a dict batch the images are read from the ``"images"`` key, or else the
        ``"frames"`` key, and the masked tensor is written back under the same key; a dict
        with neither key is left untouched. Any other batch is treated as the image tensor
        itself and is overwritten in place. The patch mask is stored on
        ``pl_module.current_patch_mask``.
        """
        if not self.curriculum_masking.use_patch_masking:
            return

        # Locate the images, remembering which key (if any) they came from.
        image_key = None
        if isinstance(batch, dict):
            image_key = next((key for key in ("images", "frames") if key in batch), None)
            if image_key is None:
                return
            images = batch[image_key]
        else:
            # Handle case where batch is just images
            images = batch

        masked_images, patch_mask = self.curriculum_masking.apply_patch_masking(
            images,
            training_step=trainer.global_step,
            is_training=True,
        )

        # Write the masked images back so that the training step actually sees them.
        if image_key is not None:
            batch[image_key] = masked_images
        else:
            # Rebinding the local name `batch` would not be visible to the caller, so the
            # tensor contents are overwritten in place instead.
            batch.copy_(masked_images)

        # Store patch mask for potential use in loss computation
        pl_module.current_patch_mask = patch_mask

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Log the current mask ratio of the curriculum at the end of each epoch."""
        if not self.curriculum_masking.use_patch_masking:
            return
        schedule_info = self.curriculum_masking.get_training_schedule_info(trainer.global_step)
        # Only the numeric mask ratio is logged; the other entries are not all numeric.
        pl_module.log(
            "patch_mask_ratio", schedule_info['mask_ratio'],
            on_step=False, on_epoch=True, prog_bar=True,
        )
class PatchMasker:
"""Handles curriculum learning and masking for multiview transformer training."""
def __init__(
self,
patch_mask_config: dict | None = None,
patch_seed: int = 0,
) -> None:
"""Initialize curriculum masking parameters.
Args:
patch_mask_config: Dictionary containing patch masking configuration
- init_step: Step to start patch masking
- final_step: Step when patch masking reaches maximum
- init_ratio: Initial masking ratio
- final_ratio: Final masking ratio
patch_seed: Seed for deterministic patch masking to allow reproducibility.
"""
self.patch_seed = patch_seed
# Parse patch masking configuration
if patch_mask_config is None:
patch_mask_config = {}
self.patch_init_step = patch_mask_config.get("init_step", 700)
self.patch_final_step = patch_mask_config.get("final_step", 5000)
self.patch_init_ratio = patch_mask_config.get("init_ratio", 0.1)
self.patch_final_ratio = patch_mask_config.get("final_ratio", 0.5)
# Automatically enable patch masking if final_ratio > 0
self.use_patch_masking = self.patch_final_ratio > 0.0
# Validate patch_seed is set for reproducibility
if self.use_patch_masking and patch_seed is None:
print(
"Warning: patch_seed is None but patch masking is enabled. "
"Results may not be reproducible."
)
def apply_patch_masking(
self,
images: torch.Tensor,
training_step: int = 0,
is_training: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply random patch masking with curriculum learning."""
# during training, apply masking
batch_size, num_views, channels, height, width = images.shape
device = images.device
patch_size = 16
num_patches_h = height // patch_size
num_patches_w = width // patch_size
if not is_training: # Not in training mode
# Calculate patch size (assuming 16x16 patches for ViT)
total_patches_per_view = num_patches_h * num_patches_w
# Create patch mask with all patches kept (1)
patch_mask = torch.ones(batch_size, num_views, total_patches_per_view, device=device)
return images, patch_mask
# Calculate current mask ratio
# start with no masking until patch_init_step
if training_step < self.patch_init_step:
mask_ratio = 0.0
else:
# start patch masking at patch_init_step, reach max by patch_final_step
curr_steps_for_patch = self.patch_final_step - self.patch_init_step
progress = min((training_step - self.patch_init_step) / curr_steps_for_patch, 1.0)
mask_ratio = (
self.patch_init_ratio
+ progress * (self.patch_final_ratio - self.patch_init_ratio)
)
# Calculate patch dimensions (assuming 16x16 patches for ViT)
total_patches_per_view = num_patches_h * num_patches_w
patches_to_mask_per_view = int(mask_ratio * total_patches_per_view)
# Initialize masks
patch_mask = torch.ones(batch_size, num_views, total_patches_per_view, device=device)
masked_images = images.clone()
# Apply patch masking per batch sample and view
for batch_idx in range(batch_size):
for view_idx in range(num_views):
if patches_to_mask_per_view > 0:
# Create a deterministic seed for this specific combination
# Same patches are masked for the same training step, batch, and view
# Using multiplication to avoid seed collisions between different combinations
deterministic_seed = (
self.patch_seed
+ training_step
+ batch_idx * 1000
+ view_idx * 100
)
# Create a local generator instead of modifying global torch seed
local_generator = torch.Generator(device=device)
local_generator.manual_seed(deterministic_seed)
# Random patch selection with local generator
patch_indices = torch.randperm(
total_patches_per_view,
device=device,
generator=local_generator
)[:patches_to_mask_per_view]
patch_mask[batch_idx, view_idx, patch_indices] = 0
# Zero out the selected patches
for patch_idx in patch_indices:
# Convert patch index to spatial coordinates
patch_h = (patch_idx // num_patches_w) * patch_size
patch_w = (patch_idx % num_patches_w) * patch_size
# Zero out the patch region
masked_images[
batch_idx,
view_idx,
:,
patch_h:patch_h + patch_size,
patch_w:patch_w + patch_size
] = 0
return masked_images, patch_mask
def apply_masking(
self,
images: torch.Tensor,
training_step: int = 0,
is_training: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply patch masking if enabled, otherwise return original images."""
if self.use_patch_masking:
return self.apply_patch_masking(images, training_step, is_training)
else:
# No masking - return original images and dummy mask
batch_size, num_views = images.shape[:2]
device = images.device
dummy_mask = torch.ones(batch_size, num_views, device=device)
return images, dummy_mask
def get_training_schedule_info(
self,
current_step: int
) -> dict[str, Any]:
"""Get information about current training schedule progress."""
if self.use_patch_masking:
if current_step < self.patch_init_step:
current_mask_ratio = 0.0
curriculum_progress = "0.0%"
steps_to_patch_masking = self.patch_init_step - current_step
steps_to_max_masking = self.patch_final_step - current_step
else:
curr_steps_for_patch = self.patch_final_step - self.patch_init_step
progress = min((current_step - self.patch_init_step) / curr_steps_for_patch, 1.0)
current_mask_ratio = (
self.patch_init_ratio
+ progress * (self.patch_final_ratio - self.patch_init_ratio)
)
curriculum_progress = f"{progress * 100:.1f}%"
steps_to_patch_masking = 0
steps_to_max_masking = max(0, self.patch_final_step - current_step)
else:
current_mask_ratio = 0.0
curriculum_progress = "0.0%"
steps_to_max_masking = 0
steps_to_patch_masking = 0
return {
"step": current_step,
"mask_ratio": current_mask_ratio,
"curriculum_progress": curriculum_progress,
"steps_to_patch_masking": steps_to_patch_masking,
"steps_to_max_masking": steps_to_max_masking
}
def should_start_patch_masking(
self,
current_step: int
) -> bool:
"""Check if patch masking should start at current step."""
return self.use_patch_masking and current_step == self.patch_init_step
class JSONInferenceProgressTracker(Callback):
    """
    A PyTorch Lightning callback that tracks prediction progress and saves it
    to a specified JSON file.
    """

    def __init__(self, filepath: Path) -> None:
        """Initialize JSONInferenceProgressTracker.

        Creates the parent directory if needed and writes a placeholder state
        (0 completed out of 1) so that the file exists before prediction starts.

        Args:
            filepath: path to the JSON file where progress will be written.
        """
        super().__init__()
        self.filepath = filepath
        self.current_step = 0
        self.total_steps = 0
        # Ensure the file exists (or is cleared) and the directory is available
        os.makedirs(os.path.dirname(self.filepath) or ".", exist_ok=True)
        self._save_progress(0, 1)

    def _save_progress(self, current: int, total: int) -> None:
        """Write ``completed``, ``total`` and a timestamp to the JSON file.

        Writing is best-effort: failures are printed and never interrupt prediction.
        """
        progress_data = {
            "completed": current,
            "total": total,
            "timestamp": time.time(),
        }
        # Use a temporary file and rename to ensure atomic write,
        # preventing external readers from getting a half-written file.
        temp_filepath = f"{self.filepath}.tmp"
        try:
            with open(temp_filepath, "w") as f:
                json.dump(progress_data, f, indent=4)
            os.replace(temp_filepath, self.filepath)
        except Exception as e:
            # Handle potential file I/O errors gracefully
            print(f"\n[Error saving progress to JSON]: {e}")
            # Clean up the temporary file; a failure here must not escape either.
            try:
                os.remove(temp_filepath)
            except OSError:
                pass

    def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when prediction starts."""
        # Count the batches of every predict dataloader: on_predict_batch_end fires for
        # all of them, so counting only the first would let `completed` exceed `total`.
        self.total_steps = sum(int(num_batches) for num_batches in trainer.num_predict_batches)
        self.current_step = 0
        # Save initial state
        self._save_progress(self.current_step, self.total_steps)

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when a prediction batch ends."""
        self.current_step += 1
        # Save updated progress
        self._save_progress(self.current_step, self.total_steps)

    def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when prediction finishes."""
        # Save final state with completed == total
        self._save_progress(self.total_steps, self.total_steps)
class JSONTrainingProgressTracker(Callback):
    """
    Tracks training progress (epochs or steps) and saves it to a JSON file.
    """

    # True when progress is counted in optimizer steps, False when counted in epochs.
    steps_mode: bool

    def __init__(self, filepath: Path) -> None:
        """Initialize JSONTrainingProgressTracker.

        Args:
            filepath: path to the JSON file where progress will be written.
        """
        super().__init__()
        self.filepath = filepath
        self.current = 0
        self.total = 0
        # Epoch mode until on_train_start decides otherwise.
        self.steps_mode = False
        os.makedirs(os.path.dirname(self.filepath) or ".", exist_ok=True)
        # Initialize with a base state (0 completed out of 1 total placeholder)
        self._save_progress(0, 1)

    def _read_existing(self) -> dict:
        """Return the current contents of the status file, or {} if it cannot be used."""
        try:
            with open(self.filepath) as f:
                contents = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as e:
            # Unreadable or invalid JSON: nothing can be preserved, so start afresh
            # rather than crash the training run.
            print(f"\n[Error reading existing progress JSON]: {e}")
            return {}
        return contents if isinstance(contents, dict) else {}

    def _save_progress(self, completed: int, total: int) -> None:
        """Helper function to write the progress dictionary to the JSON file.

        Training is different from inference because the existing file has pid and status
        information that we should not entirely overwrite: existing top-level keys are kept
        and only ``status`` and ``progress`` are replaced. Reading and writing are both
        best-effort; failures are printed and never interrupt training.
        """
        progress_data = {
            "status": "TRAINING" if completed < total else "EVALUATING",
            "progress": {
                "completed": completed,
                "total": total,
                "timestamp": time.time(),
            },
        }
        new_file_contents = {**self._read_existing(), **progress_data}
        # Use a temporary file and rename to ensure atomic write,
        # preventing external readers from getting a half-written file.
        temp_filepath = f"{self.filepath}.tmp"
        try:
            with open(temp_filepath, "w") as f:
                json.dump(new_file_contents, f, indent=4)
            os.replace(temp_filepath, self.filepath)
        except Exception as e:
            # Handle potential file I/O errors gracefully
            print(f"\n[Error saving progress to JSON]: {e}")
            if os.path.exists(temp_filepath):
                os.remove(temp_filepath)

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when training starts; chooses between epoch and step tracking."""
        max_epochs = trainer.max_epochs if trainer.max_epochs is not None else 0
        max_steps = trainer.max_steps if trainer.max_steps is not None else 0
        # Default to epoch tracking unless max_epochs is 0 or -1 (unlimited)
        self.total, self.steps_mode = max_epochs, False
        if not self.total or max_epochs == -1:
            self.total, self.steps_mode = max_steps, True
        self.current = 0
        # Save initial state
        self._save_progress(self.current, self.total)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Called when a training batch ends, used for step mode."""
        if self.steps_mode and self.total > 0:
            # The original author treats trainer.global_step as 0-indexed here.
            # NOTE(review): in recent Lightning releases global_step appears to already
            # count the optimizer step just taken when this hook runs, which would make
            # the +1 overshoot by one - confirm against the installed version.
            self.current = trainer.global_step + 1
            self._save_progress(self.current, self.total)

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called after an epoch finishes, used for epoch mode."""
        if not self.steps_mode and self.total > 0:
            # trainer.current_epoch is 0-indexed, so we add 1 for "completed" count
            self.current = trainer.current_epoch + 1
            self._save_progress(self.current, self.total)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when training finishes."""
        self.current = self.total  # Ensure completed == total
        self._save_progress(self.current, self.total)
        print(
            f"\n[JSONTrainingProgressTracker] Training finished. "
            f"Final status saved to {self.filepath}"
        )
def get_callbacks(
    cfg: DictConfig | ListConfig,
    early_stopping: bool = False,
    checkpointing: bool = True,
    lr_monitor: bool = True,
    ckpt_every_n_epochs: int | None = None,
    backbone_unfreeze: bool = True,
    status_file: Path | None = None,
) -> list:
    """Build and return the list of training callbacks based on the config.

    Args:
        cfg: hydra config containing training and callback parameters.
        early_stopping: if True, add an ``EarlyStopping`` callback.
        checkpointing: if True, add a ``ModelCheckpoint`` callback that saves the best model.
        lr_monitor: if True, add a ``LearningRateMonitor`` callback.
        ckpt_every_n_epochs: if set to a non-zero value, also save a checkpoint every this
            many epochs.
        backbone_unfreeze: if True, add the ``UnfreezeBackbone`` callback.
        status_file: if not None, add a ``JSONTrainingProgressTracker`` callback writing to this
            path.

    Returns:
        List of callback objects ready to pass to a ``pl.Trainer``.
    """
    callbacks = []

    if early_stopping:
        callbacks.append(
            EarlyStopping(
                monitor='val_supervised_loss',
                patience=cfg.training.early_stop_patience,
                mode='min',
            )
        )

    if backbone_unfreeze:
        # UnfreezeBackbone itself enforces that exactly one of the two is configured.
        callbacks.append(
            UnfreezeBackbone(
                unfreeze_step=cfg.training.get('unfreezing_step'),
                unfreeze_epoch=cfg.training.get('unfreezing_epoch'),
            )
        )

    if lr_monitor:
        # this callback should be added after UnfreezeBackbone in order to log its learning rate
        callbacks.append(LearningRateMonitor(logging_interval='epoch'))

    if checkpointing:
        callbacks.append(
            ModelCheckpoint(
                monitor='val_supervised_loss',
                mode='min',
                filename='{epoch}-{step}-best',
            )
        )

    if ckpt_every_n_epochs:
        # Periodic checkpoints are all kept (save_top_k=-1) and not tied to any metric.
        callbacks.append(
            ModelCheckpoint(
                monitor=None,
                every_n_epochs=ckpt_every_n_epochs,
                save_top_k=-1,
            )
        )

    # we need this callback for both supervised and unsupervised losses
    has_supervised_loss = any(
        loss_config.get('log_weight') is not None
        for loss_name, loss_config in cfg.losses.items() if loss_name.startswith('supervised_')
    )
    losses_to_use = cfg.model.losses_to_use
    has_losses_to_use = (losses_to_use != []) and (losses_to_use is not None)
    if has_losses_to_use or has_supervised_loss:
        callbacks.append(AnnealWeight(**cfg.callbacks.anneal_weight))

    if cfg.model.model_type == 'heatmap_multiview_transformer':
        # A key that is present but null comes back as None rather than the default, so
        # treat a null `patch_mask` or `final_ratio` as "masking disabled".
        patch_mask_config = cfg.training.get('patch_mask') or {}
        final_ratio = patch_mask_config.get('final_ratio') or 0.0
        if final_ratio > 0.0:
            callbacks.append(
                PatchMasking(
                    patch_mask_config=patch_mask_config,
                    patch_seed=cfg.training.rng_seed_model_pt,
                )
            )

    if status_file is not None:
        callbacks.append(JSONTrainingProgressTracker(status_file))

    return callbacks