Source code for lightning_pose.models.factory

"""Factory function for creating pose estimation models from config."""

from __future__ import annotations

import glob
import os
from collections import OrderedDict
from typing import TYPE_CHECKING

import torch
from omegaconf import DictConfig, ListConfig

from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.models.base import (
    _apply_defaults_for_lr_scheduler_params,
    _apply_defaults_for_optimizer_params,
    check_if_semi_supervised,
)

if TYPE_CHECKING:
    from lightning_pose.losses.factory import LossFactory
    from lightning_pose.models import ALLOWED_MODELS

__all__ = ['get_model']


def _get_model_class(model_type: str, semi_supervised: bool) -> type:
    """Return the tracker class matching ``model_type`` and the supervision mode.

    Imports are kept function-local on purpose.
    NOTE(review): presumably this avoids an import cycle with
    ``lightning_pose.models`` - confirm before hoisting them to module level.

    Raises:
        NotImplementedError: if ``model_type`` is not one of ``'regression'``,
            ``'heatmap'``, ``'heatmap_mhcrnn'`` or
            ``'heatmap_multiview_transformer'``.

    """
    if not semi_supervised:
        if model_type == 'regression':
            from lightning_pose.models import RegressionTracker
            return RegressionTracker
        if model_type == 'heatmap':
            from lightning_pose.models import HeatmapTracker
            return HeatmapTracker
        if model_type == 'heatmap_mhcrnn':
            from lightning_pose.models import HeatmapTrackerMHCRNN
            return HeatmapTrackerMHCRNN
        if model_type == 'heatmap_multiview_transformer':
            from lightning_pose.models import HeatmapTrackerMultiviewTransformer
            return HeatmapTrackerMultiviewTransformer
        raise NotImplementedError(
            f'{model_type} is an invalid cfg.model.model_type for a fully '
            f'supervised model'
        )

    if model_type == 'regression':
        from lightning_pose.models import SemiSupervisedRegressionTracker
        return SemiSupervisedRegressionTracker
    if model_type == 'heatmap':
        from lightning_pose.models import SemiSupervisedHeatmapTracker
        return SemiSupervisedHeatmapTracker
    if model_type == 'heatmap_mhcrnn':
        from lightning_pose.models import SemiSupervisedHeatmapTrackerMHCRNN
        return SemiSupervisedHeatmapTrackerMHCRNN
    if model_type == 'heatmap_multiview_transformer':
        from lightning_pose.models import SemiSupervisedHeatmapTrackerMultiviewTransformer
        return SemiSupervisedHeatmapTrackerMultiviewTransformer
    raise NotImplementedError(
        f'{model_type} invalid cfg.model.model_type for a semi-supervised model'
    )


def _resolve_checkpoint_path(ckpt: str) -> str:
    """Return a ``.ckpt`` file path for ``ckpt``.

    A value ending in ``.ckpt`` is returned unchanged. Anything else is treated
    as a directory that is searched recursively for ``*.ckpt`` files, and the
    first match reported by ``glob`` is used.

    Raises:
        FileNotFoundError: if the directory search finds no ``.ckpt`` file.

    """
    if ckpt.endswith('.ckpt'):
        return ckpt
    matches = glob.glob(os.path.join(ckpt, '**', '*.ckpt'), recursive=True)
    if not matches:
        # previously this surfaced as an opaque IndexError from `[0]`
        raise FileNotFoundError(f'No .ckpt file found under checkpoint directory {ckpt}')
    return matches[0]


def _load_checkpoint_weights(model, ckpt: str) -> None:
    """Load the ``'state_dict'`` stored in checkpoint ``ckpt`` into ``model`` in place.

    Loading is non-strict. If the full state dict is rejected with a
    ``RuntimeError``, only the entries whose key contains ``'backbone'`` are
    loaded instead.

    """
    print(f'Loading weights from {ckpt}')
    ckpt = _resolve_checkpoint_path(ckpt)

    try:
        state_dict = torch.load(ckpt)['state_dict']
    except Exception as e:
        print(f'Warning: Failed to load checkpoint with default settings: {e}')
        print('Attempting to load with weights_only=False...')
        # SECURITY: weights_only=False performs full unpickling, which can execute
        # arbitrary code embedded in the file; only use checkpoints from trusted sources.
        state_dict = torch.load(ckpt, weights_only=False)['state_dict']

    try:
        model.load_state_dict(state_dict, strict=False)
    except RuntimeError:
        # strict=False tolerates missing/unexpected keys but the load can still raise
        # RuntimeError; fall back to transferring the backbone weights only
        backbone_state_dict = OrderedDict(
            (key, val) for key, val in state_dict.items() if 'backbone' in key
        )
        model.load_state_dict(backbone_state_dict, strict=False)


def get_model(
    cfg: DictConfig | ListConfig,
    data_module: BaseDataModule | UnlabeledDataModule | None,
    loss_factories: dict[str, LossFactory] | dict[str, None],
) -> ALLOWED_MODELS:
    """Create model: regression or heatmap based, supervised or semi-supervised.

    Args:
        cfg: config with ``data``, ``model`` and ``training`` sections.
            ``cfg.model.model_type`` selects the architecture and
            ``cfg.model.losses_to_use`` determines (via
            ``check_if_semi_supervised``) whether the semi-supervised variant
            is built.
        data_module: only consulted for the fully supervised ``'heatmap'``
            model, where ``data_module.dataset.num_targets`` is forwarded as
            ``num_targets`` (``None`` when no data module is given).
        loss_factories: must contain the key ``'supervised'``; the key
            ``'unsupervised'`` is additionally read for semi-supervised models.

    Returns:
        The constructed model. If ``cfg.model.checkpoint`` is set, weights from
        that checkpoint have been loaded into it (non-strict).

    Raises:
        RuntimeError: if the backbone name contains ``'vit'`` and the resize
            height and width differ.
        NotImplementedError: if ``cfg.model.model_type`` is not supported.
        FileNotFoundError: if ``cfg.model.checkpoint`` does not end in
            ``.ckpt`` and no ``.ckpt`` file is found beneath it.

    """
    optimizer = cfg.training.get('optimizer', 'Adam')
    optimizer_params = _apply_defaults_for_optimizer_params(
        optimizer,
        cfg.training.get('optimizer_params'),
    )

    lr_scheduler = cfg.training.get('lr_scheduler', 'multisteplr')
    # `or {}` also covers a key that is present but explicitly null, which would
    # otherwise fail on `None.get(...)`
    all_lr_scheduler_params = cfg.training.get('lr_scheduler_params') or {}
    lr_scheduler_params = _apply_defaults_for_lr_scheduler_params(
        lr_scheduler,
        all_lr_scheduler_params.get(lr_scheduler),
    )

    semi_supervised = check_if_semi_supervised(cfg.model.losses_to_use)

    image_h = cfg.data.image_resize_dims.height
    image_w = cfg.data.image_resize_dims.width
    if 'vit' in cfg.model.backbone and image_h != image_w:
        raise RuntimeError('ViT model requires resized height and width to be equal')

    backbone_pretrained = cfg.model.get('backbone_pretrained', True)

    # validate the model type (and import the class) before touching any other
    # config entry, so an unsupported type always surfaces as NotImplementedError
    model_type = cfg.model.model_type
    model_class = _get_model_class(model_type, semi_supervised)

    is_regression = model_type == 'regression'
    is_multiview = model_type == 'heatmap_multiview_transformer'

    # arguments accepted by every tracker
    model_kwargs = dict(
        num_keypoints=cfg.data.num_keypoints,
        loss_factory=loss_factories['supervised'],
        backbone=cfg.model.backbone,
        pretrained=backbone_pretrained,
        torch_seed=cfg.training.rng_seed_model_pt,
        optimizer=optimizer,
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        lr_scheduler_params=lr_scheduler_params,
        # NOTE(review): only the height is forwarded; equality with the width is
        # enforced above for ViT backbones only
        image_size=image_h,
    )
    if semi_supervised:
        model_kwargs['loss_factory_unsupervised'] = loss_factories['unsupervised']
    if not is_regression:
        model_kwargs['downsample_factor'] = cfg.data.get('downsample_factor', 2)
    if model_type == 'heatmap' and not semi_supervised:
        model_kwargs['num_targets'] = data_module.dataset.num_targets if data_module else None
    if is_multiview:
        model_kwargs['num_views'] = len(cfg.data.view_names)
        model_kwargs['head'] = cfg.model.get('head', 'heatmap_cnn')
    if is_multiview and semi_supervised:
        # NOTE(review): this variant receives `patch_mask_config` but, unlike every
        # other heatmap model, no `backbone_checkpoint` - confirm this is intended
        model_kwargs['patch_mask_config'] = cfg.training.get('patch_mask', {})
    elif not is_regression:
        model_kwargs['backbone_checkpoint'] = cfg.model.get('backbone_checkpoint')

    model = model_class(**model_kwargs)

    # optionally initialize weights from a previously trained checkpoint
    if cfg.model.get('checkpoint', None):
        _load_checkpoint_weights(model, cfg.model.checkpoint)

    return model