Source code for lightning_pose.callbacks

import json
import os
import time
from pathlib import Path
from typing import Any, Dict, Tuple

import lightning.pytorch as pl
import pytest
import torch
from lightning import Trainer, LightningModule
from lightning.pytorch.callbacks import Callback

# to ignore imports for sphinx-autoapidoc
__all__ = [
    "AnnealWeight",
    "UnfreezeBackbone",
    "PatchMasking",
]


class AnnealWeight(Callback):
    """Callback that anneals a scalar weight stored on the module during training.

    The weight lives on the LightningModule under ``attr_name`` as a 0-dim
    tensor. It is set to ``init_val`` when training starts and, once the
    current epoch is past ``freeze_until_epoch``, grows by ``increase_factor``
    per epoch, capped from above at ``final_val``.
    """

    def __init__(
        self,
        attr_name: str,
        init_val: float = 0.0,
        increase_factor: float = 0.01,
        final_val: float = 1.0,
        freeze_until_epoch: int = 0,
    ) -> None:
        """Store the annealing schedule.

        Args:
            attr_name: name of the attribute to set on the module.
            init_val: value assigned at the start of training.
            increase_factor: amount added per epoch past ``freeze_until_epoch``.
            final_val: upper cap for the annealed value.
            freeze_until_epoch: last epoch (inclusive) at which the value is
                left untouched.
        """
        super().__init__()
        self.attr_name = attr_name
        self.init_val = init_val
        self.final_val = final_val
        self.increase_factor = increase_factor
        self.freeze_until_epoch = freeze_until_epoch

    def _assign(self, pl_module, value: float) -> None:
        """Write ``value`` onto the module as a fresh tensor."""
        # Dan: removed buffer; seems to complicate checkpoint loading
        # (a plain attribute is used instead of ``register_buffer``).
        setattr(pl_module, self.attr_name, torch.tensor(value))

    def on_train_start(self, trainer, pl_module) -> None:
        """Initialise the weight on the module to ``init_val``."""
        self._assign(pl_module, self.init_val)

    def on_train_epoch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        """Update the weight for the epoch that is about to start."""
        # Nothing to do while the weight is still frozen.
        if pl_module.current_epoch <= self.freeze_until_epoch:
            return

        epochs_past_freeze: int = pl_module.current_epoch - self.freeze_until_epoch
        annealed: float = min(
            self.init_val + epochs_past_freeze * self.increase_factor,
            self.final_val,
        )
        self._assign(pl_module, annealed)
class UnfreezeBackbone(Callback):
    """Callback that ramps up the backbone learning rate from 0 to `upsampling_lr` on
    `unfreeze_epoch` or `unfreeze_step`.

    Starts LR at `initial_ratio * upsampling_lr`. Grows lr by a factor of
    `warm_up_ratio` per epoch or step. Once LR reaches `upsampling_lr`, keeps it
    in sync with `upsampling_lr`.

    Use instead of pl.callbacks.BackboneFinetuning in order to use multi-GPU (DDP).
    See lightning-ai/pytorch-lightning#20340 for context.
    """

    # Backbone LR recorded at the unfreeze point; None until it has been computed.
    _initial_lr: float | None

    def __init__(
        self,
        unfreeze_epoch: int | None = None,
        unfreeze_step: int | None = None,
        initial_ratio=0.1,
        warm_up_ratio=1.5,
    ):
        """Configure when and how fast the backbone is unfrozen.

        Args:
            unfreeze_epoch: epoch at which the backbone LR becomes non-zero.
            unfreeze_step: step at which the backbone LR becomes non-zero.
            initial_ratio: fraction of the head LR used at the unfreeze point.
            warm_up_ratio: multiplicative LR growth per epoch (or step).

        Raises:
            AssertionError: unless exactly one of `unfreeze_epoch` and
                `unfreeze_step` is provided.
        """
        super().__init__()
        assert (unfreeze_epoch is None) != (
            unfreeze_step is None
        ), "Exactly one must be provided."
        self.unfreeze_epoch = unfreeze_epoch
        self.unfreeze_step = unfreeze_step
        self.initial_ratio = initial_ratio
        self.warm_up_ratio = warm_up_ratio
        self._warmed_up = False
        self._initial_lr = None

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx) -> None:
        """Set the backbone param group's LR for the upcoming batch."""
        # Once backbone_lr warms up to upsampling_lr, this callback does nothing.
        # Control of backbone lr is then the sole job of the main lr scheduler.
        if self._warmed_up:
            return

        optimizer = pl_module.optimizers()
        # Check our assumptions about param group indices
        assert optimizer.param_groups[0]["name"] == "backbone"
        # Param group 1 supplies the LR that the backbone warms up towards.
        head_lr = optimizer.param_groups[1]["lr"]
        optimizer.param_groups[0]["lr"] = self._get_backbone_lr(
            pl_module.global_step, pl_module.current_epoch, head_lr
        )

    def _get_backbone_lr(self, current_step, current_epoch, upsampling_lr):
        """Returns what the backbone LR should be at this point in time.

        Args:
            current_step: current global step; used only in "step mode"
                (`self.unfreeze_step` is not None).
            current_epoch: current epoch; used only in "epoch mode"
                (`self.unfreeze_epoch` is not None).
            upsampling_lr: current head LR; the warm-up target and cap.

        Returns:
            0.0 before the unfreeze point, `initial_ratio * upsampling_lr` at
            it, and afterwards the initial LR grown by `warm_up_ratio` per
            elapsed epoch/step, capped at `upsampling_lr`.
        """
        assert not self._warmed_up

        # The same logic applies to epochs and steps, so pick the pair of
        # (unfreeze point, current position) that matches the configured mode.
        if self.unfreeze_step is not None:
            unfreeze_at, now = self.unfreeze_step, current_step
        else:
            unfreeze_at, now = self.unfreeze_epoch, current_epoch

        # Before unfreeze, learning_rate is 0.
        if now < unfreeze_at:
            return 0.0

        # On unfreeze, initialize learning rate and remember it for warm up.
        # The stored value is reused later rather than recomputed because
        # upsampling_lr is subject to change via the scheduler.
        # If the first call happens after the unfreeze point (e.g. a run that
        # starts past it), the value was never recorded; seed it from the
        # current head LR instead of failing on a missing attribute.
        if now == unfreeze_at or self._initial_lr is None:
            self._initial_lr = self.initial_ratio * upsampling_lr
        if now == unfreeze_at:
            return self._initial_lr

        # Warm up: initial_lr * warm_up_ratio ** (epochs or steps since thaw).
        elapsed = now - unfreeze_at
        next_lr = min(self._initial_lr * self.warm_up_ratio**elapsed, upsampling_lr)
        if next_lr == upsampling_lr:
            self._warmed_up = True
        return next_lr
class PatchMasking(Callback):
    """Callback to apply curriculum patch masking during training."""

    def __init__(
        self,
        patch_mask_config: dict | None = None,
        patch_seed: int = 0,
    ):
        """Build the underlying ``PatchMasker``.

        Args:
            patch_mask_config: masking schedule configuration, forwarded
                unchanged to ``PatchMasker``.
            patch_seed: seed forwarded unchanged to ``PatchMasker``.
        """
        super().__init__()
        # Initialize curriculum masking
        self.curriculum_masking = PatchMasker(
            patch_mask_config=patch_mask_config,
            patch_seed=patch_seed,
        )

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Apply patch masking to the batch before it goes to the model.

        For dict batches the images are read from (and written back to) the
        ``"images"`` key, or ``"frames"`` if ``"images"`` is absent; a dict with
        neither key is left untouched. Any other batch is treated as the image
        tensor itself. The resulting patch mask is stored on the module as
        ``current_patch_mask``.
        """
        if not self.curriculum_masking.use_patch_masking:
            return

        # Locate the images inside the batch.
        image_key = None
        if isinstance(batch, dict):
            image_key = next((k for k in ("images", "frames") if k in batch), None)
            if image_key is None:
                return
            images = batch[image_key]
        else:
            # Handle case where batch is just images
            images = batch

        # Apply masking
        masked_images, patch_mask = self.curriculum_masking.apply_patch_masking(
            images,
            training_step=trainer.global_step,
            is_training=True,
        )

        # Update the batch with masked images
        if image_key is not None:
            batch[image_key] = masked_images
        elif isinstance(batch, torch.Tensor):
            # Rebinding the local name would not reach the caller, so the
            # masked pixels are written into the batch tensor in place.
            # NOTE(review): this mutates the tensor the dataloader produced.
            with torch.no_grad():
                batch.copy_(masked_images)

        # Store patch mask for potential use in loss computation
        pl_module.current_patch_mask = patch_mask

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Log curriculum progress."""
        if not self.curriculum_masking.use_patch_masking:
            return

        schedule_info = self.curriculum_masking.get_training_schedule_info(
            trainer.global_step
        )

        # Log curriculum information - only log numeric values
        pl_module.log(
            "patch_mask_ratio",
            schedule_info["mask_ratio"],
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
class PatchMasker:
    """Handles curriculum learning and masking for multiview transformer training."""

    def __init__(
        self,
        patch_mask_config: dict | None = None,
        patch_seed: int | None = 0,
    ):
        """Initialize curriculum masking parameters.

        Args:
            patch_mask_config: Dictionary containing patch masking configuration
                - init_step: Step to start patch masking (default 700)
                - final_step: Step when patch masking reaches maximum (default 5000)
                - init_ratio: Initial masking ratio (default 0.1)
                - final_ratio: Final masking ratio (default 0.5)
            patch_seed: Seed for deterministic patch masking to allow
                reproducibility. ``None`` draws patches from torch's global RNG.
        """
        self.patch_seed = patch_seed

        # Parse patch masking configuration
        if patch_mask_config is None:
            patch_mask_config = {}
        self.patch_init_step = patch_mask_config.get("init_step", 700)
        self.patch_final_step = patch_mask_config.get("final_step", 5000)
        self.patch_init_ratio = patch_mask_config.get("init_ratio", 0.1)
        self.patch_final_ratio = patch_mask_config.get("final_ratio", 0.5)

        # Automatically enable patch masking if final_ratio > 0
        self.use_patch_masking = self.patch_final_ratio > 0.0

        # Validate patch_seed is set for reproducibility
        if self.use_patch_masking and patch_seed is None:
            print(
                "Warning: patch_seed is None but patch masking is enabled. "
                "Results may not be reproducible."
            )

    def _curriculum_progress(self, step: int) -> float | None:
        """Return ramp progress (at most 1.0) at ``step``, or None before ``init_step``."""
        if step < self.patch_init_step:
            return None
        ramp_steps = self.patch_final_step - self.patch_init_step
        # A non-positive ramp length has no meaningful interpolation (and a
        # zero length would divide by zero): jump straight to the final ratio.
        if ramp_steps <= 0:
            return 1.0
        return min((step - self.patch_init_step) / ramp_steps, 1.0)

    def _mask_ratio(self, progress: float | None) -> float:
        """Interpolate between the initial and final ratio; 0.0 before the ramp."""
        if progress is None:
            return 0.0
        return self.patch_init_ratio + progress * (
            self.patch_final_ratio - self.patch_init_ratio
        )

    def apply_patch_masking(
        self,
        images: torch.Tensor,
        training_step: int = 0,
        is_training: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply random patch masking with curriculum learning.

        Args:
            images: 5-dimensional tensor unpacked as
                (batch, views, channels, height, width).
            training_step: step used to evaluate the masking schedule and,
                when a seed is set, to derive the per-sample seeds.
            is_training: when False, nothing is masked.

        Returns:
            ``(masked_images, patch_mask)``. ``patch_mask`` has shape
            (batch, views, patches_per_view) with 1 for kept patches and 0 for
            masked ones. When ``is_training`` is False the input tensor itself
            is returned; otherwise a masked copy is returned.
        """
        batch_size, num_views, _channels, height, width = images.shape
        device = images.device

        # Patch geometry (assuming 16x16 patches for ViT)
        patch_size = 16
        num_patches_h = height // patch_size
        num_patches_w = width // patch_size
        total_patches_per_view = num_patches_h * num_patches_w

        # Start with all patches kept (1)
        patch_mask = torch.ones(
            batch_size, num_views, total_patches_per_view, device=device
        )
        if not is_training:
            return images, patch_mask

        # No masking until patch_init_step, then ramp up until patch_final_step
        mask_ratio = self._mask_ratio(self._curriculum_progress(training_step))
        patches_to_mask_per_view = int(mask_ratio * total_patches_per_view)

        masked_images = images.clone()
        if patches_to_mask_per_view <= 0:
            return masked_images, patch_mask

        # Choose the masked patches per batch sample and view
        for batch_idx in range(batch_size):
            for view_idx in range(num_views):
                generator = None
                if self.patch_seed is not None:
                    # Same patches are masked for the same training step,
                    # batch index and view. A local generator is used so the
                    # global torch seed is not modified.
                    # NOTE(review): this additive seed can coincide for
                    # different (step, batch, view) combinations.
                    generator = torch.Generator(device=device)
                    generator.manual_seed(
                        self.patch_seed
                        + training_step
                        + batch_idx * 1000
                        + view_idx * 100
                    )
                patch_indices = torch.randperm(
                    total_patches_per_view, device=device, generator=generator
                )[:patches_to_mask_per_view]
                patch_mask[batch_idx, view_idx, patch_indices] = 0

        # Expand the per-patch mask to pixel resolution and zero the selected
        # patches in one operation. Patch index p sits at row p // num_patches_w
        # and column p % num_patches_w, which is the row-major layout of the
        # view below. Pixels beyond the last full patch are never masked.
        hidden = (patch_mask == 0).view(
            batch_size, num_views, 1, num_patches_h, num_patches_w
        )
        hidden = hidden.repeat_interleave(patch_size, dim=3).repeat_interleave(
            patch_size, dim=4
        )
        covered_h = num_patches_h * patch_size
        covered_w = num_patches_w * patch_size
        masked_images[..., :covered_h, :covered_w].masked_fill_(hidden, 0)

        return masked_images, patch_mask

    def apply_masking(
        self,
        images: torch.Tensor,
        training_step: int = 0,
        is_training: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply patch masking if enabled, otherwise return original images.

        When masking is disabled the returned mask is a dummy all-ones tensor
        of shape (batch, views).
        """
        if self.use_patch_masking:
            return self.apply_patch_masking(images, training_step, is_training)

        # No masking - return original images and dummy mask
        batch_size, num_views = images.shape[:2]
        dummy_mask = torch.ones(batch_size, num_views, device=images.device)
        return images, dummy_mask

    def get_training_schedule_info(self, current_step: int) -> Dict[str, Any]:
        """Get information about current training schedule progress.

        Returns:
            Dict with keys ``step``, ``mask_ratio``, ``curriculum_progress``
            (a percentage string), ``steps_to_patch_masking`` and
            ``steps_to_max_masking``.
        """
        if not self.use_patch_masking:
            progress = None
            steps_to_patch_masking = 0
            steps_to_max_masking = 0
        else:
            progress = self._curriculum_progress(current_step)
            if progress is None:
                # Masking has not started yet.
                steps_to_patch_masking = self.patch_init_step - current_step
                steps_to_max_masking = self.patch_final_step - current_step
            else:
                steps_to_patch_masking = 0
                steps_to_max_masking = max(0, self.patch_final_step - current_step)

        if progress is None:
            curriculum_progress = "0.0%"
        else:
            curriculum_progress = f"{progress*100:.1f}%"

        return {
            "step": current_step,
            "mask_ratio": self._mask_ratio(progress),
            "curriculum_progress": curriculum_progress,
            "steps_to_patch_masking": steps_to_patch_masking,
            "steps_to_max_masking": steps_to_max_masking,
        }

    def should_start_patch_masking(self, current_step: int) -> bool:
        """Check if patch masking should start at current step."""
        return self.use_patch_masking and current_step == self.patch_init_step


def _atomic_write_json(filepath, payload: dict) -> None:
    """Best-effort write of ``payload`` as JSON to ``filepath``.

    Writes to ``<filepath>.tmp`` and renames it over the target so external
    readers never see a half-written file. Failures are reported with
    ``print`` and never raised, so progress reporting cannot interrupt a run.
    """
    temp_filepath = f"{filepath}.tmp"
    try:
        with open(temp_filepath, "w") as f:
            json.dump(payload, f, indent=4)
        os.replace(temp_filepath, filepath)
    except Exception as e:
        # Handle potential file I/O errors gracefully
        print(f"\n[Error saving progress to JSON]: {e}")
        try:
            os.remove(temp_filepath)
        except OSError:
            # The temp file is already gone or cannot be removed; nothing
            # more can be done from here.
            pass


class JSONInferenceProgressTracker(Callback):
    """
    A PyTorch Lightning callback that tracks prediction progress
    and saves it to a specified JSON file.
    """

    def __init__(self, filepath: Path):
        """Create the target directory and write a 0-of-1 placeholder.

        Args:
            filepath: path of the JSON progress file.
        """
        super().__init__()
        self.filepath = filepath
        self.current_step = 0
        self.total_steps = 0
        # Ensure the file exists (or is cleared) and the directory is available
        os.makedirs(os.path.dirname(self.filepath) or ".", exist_ok=True)
        self._save_progress(0, 1)

    def _save_progress(self, current: int, total: int):
        """Helper function to write the progress dictionary to the JSON file."""
        progress_data = {
            "completed": current,
            "total": total,
            "timestamp": time.time(),
        }
        _atomic_write_json(self.filepath, progress_data)

    def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when prediction starts."""
        # Calculate the total number of batches to predict
        self.total_steps = trainer.num_predict_batches[0]  # Assumes one dataloader
        self.current_step = 0
        # Save initial state
        self._save_progress(self.current_step, self.total_steps)

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs,
        batch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when a prediction batch ends."""
        self.current_step += 1
        # Save updated progress
        self._save_progress(self.current_step, self.total_steps)

    def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when prediction finishes."""
        # Save final state
        self._save_progress(self.total_steps, self.total_steps)


class JSONTrainingProgressTracker(Callback):
    """
    Tracks training progress (epochs or steps) and saves it to a JSON file.
    """

    # True when progress is counted in steps, False when counted in epochs.
    steps_mode: bool

    def __init__(self, filepath: Path):
        """Create the target directory and write a 0-of-1 placeholder.

        Args:
            filepath: path of the JSON progress file.
        """
        super().__init__()
        self.filepath = filepath
        self.current = 0
        self.total = 0
        # Default until on_train_start picks the real mode.
        self.steps_mode = False
        os.makedirs(os.path.dirname(self.filepath) or ".", exist_ok=True)
        # Initialize with a base state (0 completed out of 1 total placeholder)
        self._save_progress(0, 1)

    def _load_existing(self) -> dict:
        """Return the current file contents as a dict, or {} if unavailable."""
        try:
            with open(self.filepath) as f:
                loaded = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as e:
            # Unreadable or malformed file: report it and start from scratch
            # rather than interrupting training.
            print(f"\n[Error reading progress JSON]: {e}")
            return {}
        return loaded if isinstance(loaded, dict) else {}

    def _save_progress(self, completed: int, total: int):
        """Helper function to write the progress dictionary to the JSON file.

        Training is different from inference because the existing file has
        pid and status information that we should not entirely overwrite:
        existing top-level keys are kept and only ``status`` and ``progress``
        are replaced.
        """
        progress_data = {
            "status": "TRAINING" if completed < total else "EVALUATING",
            "progress": {
                "completed": completed,
                "total": total,
                "timestamp": time.time(),
            },
        }
        _atomic_write_json(self.filepath, {**self._load_existing(), **progress_data})

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when training starts."""
        # Determine tracking mode based on Trainer configuration
        max_epochs = trainer.max_epochs if trainer.max_epochs is not None else 0
        max_steps = trainer.max_steps if trainer.max_steps is not None else 0

        # Default to epoch tracking unless max_epochs is 0 or -1 (unlimited)
        self.total, self.steps_mode = max_epochs, False
        if not self.total or max_epochs == -1:
            self.total, self.steps_mode = max_steps, True

        self.current = 0
        # Save initial state
        self._save_progress(self.current, self.total)

    def on_train_batch_end(
        self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int
    ) -> None:
        """Called when a training batch ends, used for step mode."""
        if self.steps_mode and self.total > 0:
            # trainer.global_step is 0-indexed
            self.current = trainer.global_step + 1
            self._save_progress(self.current, self.total)

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called after an epoch finishes, used for epoch mode."""
        if not self.steps_mode and self.total > 0:
            # trainer.current_epoch is 0-indexed, so we add 1 for "completed" count
            self.current = trainer.current_epoch + 1
            self._save_progress(self.current, self.total)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when training finishes."""
        self.current = self.total  # Ensure completed == total
        self._save_progress(self.current, self.total)
        print(
            f"\n[JSONTrainingProgressTracker] Training finished. Final status saved to {self.filepath}"
        )