predict_dataset

lightning_pose.utils.predictions.predict_dataset(cfg: DictConfig, data_module: BaseDataModule, preds_file: str | list[str], ckpt_file: str | None = None, trainer: Trainer | None = None, model: HeatmapTracker | SemiSupervisedHeatmapTracker | HeatmapTrackerMHCRNN | SemiSupervisedHeatmapTrackerMHCRNN | HeatmapTrackerMultiviewTransformer | SemiSupervisedHeatmapTrackerMultiviewTransformer | RegressionTracker | SemiSupervisedRegressionTracker | None = None) → DataFrame | dict[str, DataFrame]

Save predicted keypoints for a labeled dataset.

Parameters:
  • cfg – hydra config

  • data_module – data module that contains dataloaders for train, val, test splits

  • preds_file – path for the predictions .csv file, or a list of such paths

  • ckpt_file – absolute path to the checkpoint of your trained model; requires .ckpt suffix

  • trainer – pl.Trainer object

  • model – Lightning Module
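
The ckpt_file requirement above (absolute path, .ckpt suffix) can be checked before the function is called. The following is a minimal sketch using only the standard library; validate_ckpt_path is a hypothetical helper and is not part of lightning_pose, and it does not check that the file exists.

```python
from pathlib import Path


def validate_ckpt_path(ckpt_file: str) -> str:
    """Hypothetical helper (not part of lightning_pose).

    Checks the two documented requirements for ckpt_file:
    the path is absolute and ends with the .ckpt suffix.
    """
    path = Path(ckpt_file)
    if not path.is_absolute():
        raise ValueError(f"ckpt_file must be an absolute path, got: {ckpt_file}")
    if path.suffix != ".ckpt":
        raise ValueError(f"ckpt_file must end in .ckpt, got: {ckpt_file}")
    return str(path)
```

The returned string could then be passed as the ckpt_file argument. Whether predict_dataset performs similar checks internally is not stated here.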

Returns:

pandas dataframe with predictions, or a dict containing a dataframe of predictions for each view
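
Because the return value is either a single dataframe or a dict of dataframes, calling code may want to treat both cases the same way. A minimal sketch under stated assumptions: as_view_dict and the default key "single_view" are illustrative names, not part of lightning_pose, and the dict keys are presumed to identify views.

```python
from typing import Any


def as_view_dict(preds: Any, default_view: str = "single_view") -> dict[str, Any]:
    """Hypothetical helper (not part of lightning_pose).

    Wraps a single result in a one-entry dict so that single-view and
    multi-view results can be iterated the same way.
    """
    if isinstance(preds, dict):
        # Multi-view case: return a shallow copy, leaving the input untouched.
        return dict(preds)
    # Single-view case: wrap the lone result under a placeholder key.
    return {default_view: preds}
```

With this wrapper, a loop such as `for view, df in as_view_dict(result).items()` handles either return shape.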