predict_dataset
- lightning_pose.utils.predictions.predict_dataset(cfg: DictConfig, data_module: BaseDataModule, preds_file: str, ckpt_file: str | None = None, trainer: Trainer | None = None, model: HeatmapTracker | SemiSupervisedHeatmapTracker | HeatmapTrackerMHCRNN | SemiSupervisedHeatmapTrackerMHCRNN | RegressionTracker | SemiSupervisedRegressionTracker | None = None) → DataFrame | Dict[str, DataFrame]
Save predicted keypoints for a labeled dataset.
- Parameters:
cfg – hydra config
data_module – data module that contains dataloaders for train, val, test splits
preds_file – absolute filename for the predictions .csv file
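ckpt_file – absolute path to the checkpoint of your trained model; requires .ckpt suffix (optional; defaults to None)
trainer – pl.Trainer object (optional; defaults to None)
model – Lightning module, i.e. one of the tracker classes listed in the signature (optional; defaults to None)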
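The path requirements above (absolute paths, a .csv predictions file, a .ckpt checkpoint) can be checked before calling the function. The following is a minimal sketch using only the standard library; `check_prediction_paths` is a hypothetical helper and not part of lightning_pose, and treating a non-.csv `preds_file` as an error is an assumption based on the parameter description.

```python
from pathlib import Path


def check_prediction_paths(preds_file, ckpt_file=None):
    """Hypothetical pre-flight check; NOT part of lightning_pose.

    Returns (preds_path, ckpt_path) as Path objects; ckpt_path is None
    when no checkpoint is given.
    """
    preds_path = Path(preds_file)
    if not preds_path.is_absolute():
        raise ValueError(f"preds_file must be an absolute path, got {preds_file!r}")
    # Assumption: the docs describe preds_file as a .csv file, so enforce the suffix.
    if preds_path.suffix != ".csv":
        raise ValueError(f"preds_file should end in .csv, got {preds_file!r}")

    ckpt_path = None
    if ckpt_file is not None:
        ckpt_path = Path(ckpt_file)
        if not ckpt_path.is_absolute():
            raise ValueError(f"ckpt_file must be an absolute path, got {ckpt_file!r}")
        # The docs state that the checkpoint requires a .ckpt suffix.
        if ckpt_path.suffix != ".ckpt":
            raise ValueError(f"ckpt_file requires a .ckpt suffix, got {ckpt_file!r}")
    return preds_path, ckpt_path
```

Calling this helper first turns a malformed path into an immediate, readable error instead of a failure partway through prediction.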
- Returns:
A pandas DataFrame of predictions, or a dict containing one DataFrame of predictions per view.
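Because the return value is either a single DataFrame or a dict of DataFrames, calling code may want to normalize it to one shape. The sketch below shows one way to do that; `predictions_by_view`, `predict_and_group`, and the `"all"` key are hypothetical names chosen for illustration, and the keyword arguments passed to `predict_dataset` are taken from the signature above.

```python
def predictions_by_view(result, single_view_key="all"):
    """Hypothetical helper: always return a {view_name: predictions} dict."""
    if isinstance(result, dict):
        # Multi-view case: copy so the caller can modify the mapping freely.
        return dict(result)
    # Single-view case: wrap the lone DataFrame under a placeholder key.
    return {single_view_key: result}


def predict_and_group(cfg, data_module, preds_file, ckpt_file):
    """Run predict_dataset and return its predictions grouped by view."""
    # Imported lazily so this sketch can be loaded without lightning_pose installed.
    from lightning_pose.utils.predictions import predict_dataset

    result = predict_dataset(
        cfg=cfg,
        data_module=data_module,
        preds_file=preds_file,
        ckpt_file=ckpt_file,
    )
    return predictions_by_view(result)
```

Downstream code can then loop over `predict_and_group(...).items()` without branching on the return type.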