calculate_train_batches

lightning_pose.utils.scripts.calculate_train_batches(cfg: DictConfig, dataset: BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset | None = None) -> int

Compute the number of training batches per epoch. For semi-supervised models, this tells us how many batches to take from each dataloader (labeled and unlabeled) during a given epoch. By default, the value is set so that all batches from the labeled dataloader are exhausted, which often leaves many video frames untouched. The unlabeled dataloader, however, is randomly reset for the next epoch, so frames skipped in one epoch may be sampled in a later one. We also enforce a minimum value of 10 so that models with a small number of labeled frames cycle through the labeled dataset multiple times per epoch, which we have found to be useful empirically.
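The rule described above can be illustrated with a minimal sketch. It assumes that "exhausting the labeled dataloader" means ceil(number of training frames / training batch size) batches, and that this count is then clamped from below by 10. The function `sketch_train_batches` and its parameters are hypothetical. It works on plain integers rather than a `DictConfig` and a dataset, and it is not the library's actual implementation.

```python
import math


def sketch_train_batches(
    num_train_frames: int, train_batch_size: int, min_batches: int = 10
) -> int:
    """Hypothetical sketch of the per-epoch batch count (not the library's code).

    Assumes one full pass over the labeled training frames per epoch,
    i.e. ceil(num_train_frames / train_batch_size) batches, clamped from
    below by ``min_batches``.
    """
    if num_train_frames <= 0 or train_batch_size <= 0:
        raise ValueError("frame count and batch size must be positive")
    # Batches needed to see every labeled training frame once (last batch may be partial).
    batches_to_exhaust = math.ceil(num_train_frames / train_batch_size)
    # Enforce the minimum so small labeled sets are cycled through several times per epoch.
    return max(batches_to_exhaust, min_batches)


# A large labeled set is limited only by exhausting the loader...
print(sketch_train_batches(1000, 16))  # 63
# ...while a small one is raised to the minimum of 10 batches.
print(sketch_train_batches(40, 16))  # 10
```

In the second call, 40 labeled frames with a batch size of 16 would need only 3 batches to be exhausted. The minimum raises this to 10 batches, or 160 labeled samples, so under these assumptions each labeled frame is seen roughly four times per epoch.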