"""Factory functions to build data pipeline components from config."""
import warnings
import imgaug.augmenters as iaa
import numpy as np
from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf.errors import ValidationError
from lightning_pose.data.augmentations import (
expand_imgaug_str_to_dict,
imgaug_transform,
)
from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.datasets import (
BaseTrackingDataset,
HeatmapDataset,
MultiviewHeatmapDataset,
)
# to ignore imports for sphinx-autoapidoc
__all__ = [
'get_imgaug_transform',
'get_dataset',
'get_data_module',
]
[docs]
def get_dataset(
    cfg: DictConfig | ListConfig,
    data_dir: str,
    imgaug_transform: iaa.Sequential,
) -> BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset:
    """Create a dataset that contains labeled data.

    The dataset class is selected from ``cfg.model.model_type`` and, for
    heatmap-based models, from whether ``cfg.data.view_names`` lists more than
    one view.

    Args:
        cfg: experiment config; the ``model``, ``data`` and ``training``
            sections are read.
        data_dir: forwarded to the dataset as ``root_directory``.
        imgaug_transform: augmentation pipeline forwarded to the dataset.

    Returns:
        A ``BaseTrackingDataset`` for ``'regression'`` models, a
        ``MultiviewHeatmapDataset`` for heatmap models with more than one view,
        and a ``HeatmapDataset`` for all other heatmap models.

    Raises:
        NotImplementedError: if a regression model is requested together with
            more than one view, or if ``cfg.model.model_type`` is neither
            ``'regression'`` nor a string containing ``'heatmap'``.
    """
    model_type = cfg.model.model_type

    if model_type == 'regression':
        if cfg.data.get('view_names', None) and len(cfg.data.view_names) > 1:
            raise NotImplementedError(
                'Multi-view support only available for heatmap-based models'
            )
        dataset = BaseTrackingDataset(
            root_directory=data_dir,
            csv_path=cfg.data.csv_file,
            image_resize_height=cfg.data.image_resize_dims.height,
            image_resize_width=cfg.data.image_resize_dims.width,
            imgaug_transform=imgaug_transform,
            do_context=False,  # no context for regression models
        )

    elif 'heatmap' in model_type:
        if cfg.data.get('view_names', None) and len(cfg.data.view_names) > 1:
            # Actually emit the warning; merely instantiating UserWarning(...)
            # creates an exception object and silently discards it.
            warnings.warn(
                'No precautions regarding the size of the images were considered here, '
                'images will be resized accordingly to configs!',
                UserWarning,
                stacklevel=2,
            )
            # we are either
            # 1. running inference on un-augmented data, and need to make sure to resize
            # 2. using a multiview model w/o camera params, and need to take care of resizing
            resize = (
                cfg.training.imgaug in ['default', 'none']
                or not cfg.data.get('camera_params_file')
            )
            dataset = MultiviewHeatmapDataset(
                root_directory=data_dir,
                csv_paths=cfg.data.csv_file,
                view_names=list(cfg.data.view_names),
                image_resize_height=cfg.data.image_resize_dims.height,
                image_resize_width=cfg.data.image_resize_dims.width,
                imgaug_transform=imgaug_transform,
                downsample_factor=cfg.data.get('downsample_factor', 2),
                do_context=model_type == 'heatmap_mhcrnn',  # context only for mhcrnn
                resize=resize,
                uniform_heatmaps=cfg.training.get('uniform_heatmaps_for_nan_keypoints', False),
                camera_params_path=cfg.data.get('camera_params_file', None),
                bbox_paths=cfg.data.get('bbox_file', None),
            )
        else:
            dataset = HeatmapDataset(
                root_directory=data_dir,
                csv_path=cfg.data.csv_file,
                image_resize_height=cfg.data.image_resize_dims.height,
                image_resize_width=cfg.data.image_resize_dims.width,
                imgaug_transform=imgaug_transform,
                downsample_factor=cfg.data.get('downsample_factor', 2),
                do_context=model_type == 'heatmap_mhcrnn',  # context only for mhcrnn
                uniform_heatmaps=cfg.training.get('uniform_heatmaps_for_nan_keypoints', False),
            )

    else:
        raise NotImplementedError(f'{cfg.model.model_type} is an invalid cfg.model.model_type')

    return dataset
[docs]
def get_data_module(
    cfg: DictConfig | ListConfig,
    dataset: BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset,
    video_dir: str | None = None,
) -> BaseDataModule | UnlabeledDataModule:
    """Create a data module that splits a dataset into train/val/test iterators.

    Per-GPU batch sizes are derived by dividing the configured train/val batch
    sizes (and, for semi-supervised models, the DALI sequence length and
    context batch size) by ``cfg.training.num_gpus``, rounding up.

    Note:
        ``cfg.training.num_gpus`` is modified in place: values below 1 are
        raised to 1.

    Args:
        cfg: experiment config; the ``training``, ``model``, ``data`` and
            ``dali`` sections are read.
        dataset: labeled dataset to wrap.
        video_dir: forwarded to ``UnlabeledDataModule`` as ``video_paths_list``;
            must not be ``None`` when the configured losses make the model
            semi-supervised.

    Returns:
        A ``BaseDataModule`` for fully supervised models, otherwise an
        ``UnlabeledDataModule``.

    Raises:
        ValidationError: if ``cfg.model.model_type`` is ``'heatmap_mhcrnn'``,
            the model is semi-supervised, and the per-GPU context batch size
            ends up below 5.
        AssertionError: if the model is semi-supervised and ``video_dir`` is
            ``None``.
    """
    # Old configs may have num_gpus: 0. We will remove support in a future release.
    if cfg.training.num_gpus == 0:
        warnings.warn(
            'Config contains unsupported value num_gpus: 0. '
            'Update num_gpus to 1 in your config.',
            stacklevel=2,
        )
    cfg.training.num_gpus = max(cfg.training.num_gpus, 1)

    # Divide config batch_size by num_gpus to maintain the same effective batch
    # size in a multi-gpu setting.
    train_batch_size = int(
        np.ceil(cfg.training.train_batch_size / cfg.training.num_gpus)
    )
    val_batch_size = int(np.ceil(cfg.training.val_batch_size / cfg.training.num_gpus))

    # Imported at call time rather than at module level
    # (presumably to avoid a circular import -- TODO confirm).
    from lightning_pose.models.base import check_if_semi_supervised

    semi_supervised = check_if_semi_supervised(cfg.model.losses_to_use)

    if not semi_supervised:
        data_module = BaseDataModule(
            dataset=dataset,
            train_batch_size=train_batch_size,
            val_batch_size=val_batch_size,
            test_batch_size=cfg.training.test_batch_size,
            num_workers=cfg.training.get('num_workers'),
            train_probability=cfg.training.train_prob,
            val_probability=cfg.training.val_prob,
            train_frames=cfg.training.train_frames,
            torch_seed=cfg.training.rng_seed_data_pt,
        )
    else:
        # Divide config batch_size by num_gpus to maintain the same effective batch
        # size in a multi-gpu setting.
        base_sequence_length = int(
            np.ceil(cfg.dali.base.train.sequence_length / cfg.training.num_gpus)
        )
        # Maintain effective context batch size in num_gpus adjustment,
        # otherwise the effective context batch size will be too small due to the
        # 2 context frames on each side of center.
        _effective_context_batch_size = max(cfg.dali.context.train.batch_size - 4, 0)
        # Each GPU should get the effective batch size / num_gpus, + 4 for context frames.
        context_batch_size = int(
            np.ceil(_effective_context_batch_size / cfg.training.num_gpus + 4)
        )
        # NOTE(review): with the arithmetic above this check only fails when the
        # configured batch size is <= 4, independent of num_gpus, whereas the
        # message states a `5 * num_gpus` bound -- confirm which is intended.
        if cfg.model.model_type == 'heatmap_mhcrnn' and context_batch_size < 5:
            raise ValidationError(
                'dali.context.train.batch_size must be >= 5 * num_gpus for '
                'semi-supervised context models. '
                # f-prefix is required so the offending value is interpolated.
                f'Found {cfg.dali.context.train.batch_size}'
            )

        # Overlay the per-GPU values on a merged copy of the DALI config.
        dali_config = OmegaConf.merge(
            cfg.dali,
            {
                'base': {'train': {'sequence_length': base_sequence_length}},
                'context': {'train': {'batch_size': context_batch_size}},
            },
        )

        assert video_dir is not None, 'video_dir must be provided for semi-supervised training'

        view_names = cfg.data.get('view_names', None)
        view_names = list(view_names) if view_names is not None else None

        data_module = UnlabeledDataModule(
            dataset=dataset,
            video_paths_list=video_dir,
            view_names=view_names,
            train_batch_size=train_batch_size,
            val_batch_size=val_batch_size,
            test_batch_size=cfg.training.test_batch_size,
            num_workers=cfg.training.get('num_workers'),
            train_probability=cfg.training.train_prob,
            val_probability=cfg.training.val_prob,
            train_frames=cfg.training.train_frames,
            dali_config=dali_config,
            torch_seed=cfg.training.rng_seed_data_pt,
            imgaug=cfg.training.get('imgaug', 'default'),
        )

    return data_module