Lightning Pose API
Train function
- lightning_pose.train.train(cfg: DictConfig | ListConfig, model_dir: str | Path | None = None, skip_evaluation: bool = False) -> Model
Train a model using the configuration cfg, saving outputs to model_dir.
- Parameters:
cfg: hydra config object.
model_dir: directory to save model outputs; defaults to cwd if unspecified.
skip_evaluation: if True, skip post-training evaluation.
- Returns:
trained Model instance.
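- To train a model using config.yaml and output to outputs/doc_model:
import os
from lightning_pose.train import train
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
# model_dir is not passed, so outputs go to the current working directory;
# outputs/doc_model must already exist for os.chdir to succeed.
os.chdir("outputs/doc_model")
train(cfg)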
- To override settings before training:
cfg = OmegaConf.load("config.yaml")
overrides = {
    "training": {
        "min_epochs": 5,
        "max_epochs": 5,
    }
}
cfg = OmegaConf.merge(cfg, overrides)
train(cfg)
Training returns a Model object, which is described next.
Model class
The Model class provides an easy-to-use interface to a lightning-pose
model. It supports running inference and accessing model metadata.
The set of supported Model operations will expand as we continue development.
You create a model object using Model.from_dir:
from lightning_pose.api.model import Model
model = Model.from_dir("outputs/doc_model")
Then, to predict on new data:
model.predict_on_video_file("path/to/video.mp4")
or:
model.predict_on_label_csv("path/to/csv_file.csv")
To predict on a single numpy frame (no file I/O):
import numpy as np
frame = np.array(...) # (H, W, 3) uint8 RGB
result = model.predict_frame(frame)
keypoints = result["keypoints"] # (num_kp, 2) float32
confidence = result["confidence"] # (num_kp,) float32
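Because model weights load lazily on the first prediction call, the first predict_frame call is slower than later ones. The sketch below shows one way to warm up before a latency-sensitive loop. Both warm_up and FakeModel are hypothetical stand-ins introduced here for illustration and are not part of lightning-pose; with a real model you would pass model.predict_frame and a dummy frame of the expected shape.

```python
import time
from typing import Any, Callable


def warm_up(predict_fn: Callable[[Any], Any], dummy_input: Any) -> float:
    """Call predict_fn once on a dummy input and return the elapsed seconds."""
    start = time.perf_counter()
    predict_fn(dummy_input)
    return time.perf_counter() - start


class FakeModel:
    """Hypothetical stand-in that mimics lazy loading on the first call."""

    def __init__(self) -> None:
        self.loaded = False
        self.load_count = 0

    def predict_frame(self, frame: Any) -> dict:
        if not self.loaded:
            # A real model would load weights and initialize CUDA here.
            self.load_count += 1
            self.loaded = True
        return {"keypoints": [], "confidence": []}


model = FakeModel()
first_call_seconds = warm_up(model.predict_frame, dummy_input=None)

# Later calls reuse the already-loaded model.
for _ in range(3):
    model.predict_frame(None)
```

Running the warm-up outside the timed loop keeps the one-time loading cost out of per-frame latency measurements.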
API Reference
- class lightning_pose.api.model.Model
High-level interface for inference with a trained lightning-pose model.
Load a saved model with Model.from_dir, then call prediction methods directly. Model weights are loaded lazily on the first prediction call.
- model_dir
absolute path to the directory the model is stored in.
- Type:
pathlib.Path
- config
the model configuration as a ModelConfig object.
- Type:
lightning_pose.api.model_config.ModelConfig
- model
the underlying PyTorch model; None until the first prediction call.
- Type:
lightning_pose.models.HeatmapTracker | lightning_pose.models.SemiSupervisedHeatmapTracker | lightning_pose.models.HeatmapTrackerMHCRNN | lightning_pose.models.SemiSupervisedHeatmapTrackerMHCRNN | lightning_pose.models.HeatmapTrackerMultiviewTransformer | lightning_pose.models.SemiSupervisedHeatmapTrackerMultiviewTransformer | lightning_pose.models.RegressionTracker | lightning_pose.models.SemiSupervisedRegressionTracker | None
Examples
>>> from lightning_pose.api import Model
>>> model = Model.from_dir("outputs/2024-01-01/12-00-00")

Single-frame inference (no file I/O):
>>> import numpy as np
>>> frame = np.zeros((256, 256, 3), dtype=np.uint8)
>>> result = model.predict_frame(frame)
>>> result["keypoints"].shape  # (num_keypoints, 2)
>>> result["confidence"].shape  # (num_keypoints,)

Predict on a video file:
>>> pred_result = model.predict_on_video_file("path/to/video.mp4")
>>> pred_result.predictions  # pd.DataFrame with MultiIndex columns
>>> pred_result.metrics  # ComputeMetricsSingleResult or None

Predict on a labeled CSV (also computes pixel error):
>>> pred_result = model.predict_on_label_csv("path/to/CollectedData.csv")
- property cfg: DictConfig | ListConfig
The model configuration as an omegaconf.DictConfig.
- config: ModelConfig
The model configuration stored as a ModelConfig object. ModelConfig wraps the omegaconf.DictConfig and provides utility functions over it.
- cropped_csv_file_path(csv_file_path: str | Path) -> Path
Return the path where a cropzoom-adjusted CSV file will be saved.
- Parameters:
csv_file_path: path to the original labeled CSV file.
- Returns:
path of the form {model_dir}/image_preds/{csv_name}/cropped_{csv_name}.
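The documented path form can be illustrated with pathlib. This is only a sketch: sketch_cropped_csv_path is a hypothetical helper, not the library's implementation, and it assumes that {csv_name} means the CSV's file name including its extension.

```python
from __future__ import annotations

from pathlib import Path


def sketch_cropped_csv_path(model_dir: str | Path, csv_file_path: str | Path) -> Path:
    """Build a path of the documented form.

    Assumption: {csv_name} is the CSV's file name including its extension.
    """
    csv_name = Path(csv_file_path).name
    return Path(model_dir) / "image_preds" / csv_name / f"cropped_{csv_name}"


example = sketch_cropped_csv_path("outputs/doc_model", "data/CollectedData.csv")
print(example.as_posix())
```

For real use, call model.cropped_csv_file_path(...) rather than rebuilding the path by hand, since the method is the source of truth for the layout.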
- static from_dir(model_dir: str | Path) -> Model
Create a Model instance for a model stored at model_dir.
- Parameters:
model_dir: path to a model output directory containing config.yaml and a .ckpt checkpoint file.
- Returns:
Model ready for inference. Weights are loaded lazily on the first prediction call.
Examples
>>> from lightning_pose.api import Model
>>> model = Model.from_dir("outputs/2024-01-01/12-00-00")
>>> model.config.is_multi_view()
False
- labeled_videos_dir() -> Path
Return the directory where prediction-annotated videos are saved.
- model_dir: Path
Directory the model is stored in.
- predict_frame(frame_rgb: ndarray, bbox: tuple[int, int, int, int] | None = None) -> dict[str, ndarray]
Single-frame inference. No file I/O, no DALI.
Preprocessing uses cv2 (not DALI). Results will differ numerically from predict_on_video_file due to interpolation and normalization differences. Do not mix results from the two paths in quantitative analysis.
For MHCRNN (context) models, pass a (T, H, W, 3) array where T is the temporal context length (typically 5). Passing a single frame to a context model raises ValueError; use predict_on_video_file for proper temporal inference.
The first call triggers model loading and CUDA initialization, which may take several seconds. Subsequent calls are fast (~5-50 ms depending on backbone). For latency-sensitive loops, call once on a dummy frame before entering the loop.
- Parameters:
frame_rgb: (H, W, 3) uint8 RGB array for standard models, or (T, H, W, 3) uint8 RGB array for context (MHCRNN) models.
bbox: optional (x, y, w, h) crop region. Note: this is (x, y, width, height), NOT (x1, y1, x2, y2). If provided, the frame is cropped first and the keypoints are then remapped back to original frame coordinates.
- Returns:
dict with keys:
"keypoints": (num_kp, 2) float32 array of (x, y) in original frame coords.
"confidence": (num_kp,) float32 array in [0, 1], the likelihood/confidence per keypoint. For regression models, confidence is always 1.0.
- Return type:
dict[str, ndarray]
- Raises:
ValueError: if frame_rgb has wrong shape/dtype, bbox has non-positive dimensions, bbox produces an empty crop, or a context model receives single-frame input.
Examples
>>> import numpy as np
>>> frame = np.zeros((256, 256, 3), dtype=np.uint8)
>>> result = model.predict_frame(frame)
>>> result["keypoints"].shape  # (num_keypoints, 2)
>>> result["confidence"].shape  # (num_keypoints,)

With a bounding-box crop (x, y, width, height):
>>> result = model.predict_frame(frame, bbox=(100, 50, 128, 128))
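Since bbox is (x, y, width, height) rather than corner coordinates, a small conversion and pre-check can catch mistakes before calling predict_frame. The helpers below are hypothetical and only approximate the documented behavior: they assume the ValueError cases for bbox correspond to a non-positive size or no overlap with the frame, and that keypoints are expressed in crop pixel coordinates with no rescaling, so mapping back is a plain offset. The library's internal remapping may involve additional scaling.

```python
from __future__ import annotations


def xyxy_to_xywh(x1: int, y1: int, x2: int, y2: int) -> tuple[int, int, int, int]:
    """Convert corner coordinates to the (x, y, width, height) form."""
    return (x1, y1, x2 - x1, y2 - y1)


def check_bbox(bbox: tuple[int, int, int, int], frame_w: int, frame_h: int) -> None:
    """Raise ValueError for a non-positive size or a crop that misses the frame."""
    x, y, w, h = bbox
    if w <= 0 or h <= 0:
        raise ValueError(f"bbox width and height must be positive, got {bbox}")
    # Clip the box to the frame; an empty intersection means an empty crop.
    left, top = max(x, 0), max(y, 0)
    right, bottom = min(x + w, frame_w), min(y + h, frame_h)
    if right <= left or bottom <= top:
        raise ValueError(f"bbox {bbox} does not overlap a {frame_w}x{frame_h} frame")


def crop_to_frame_coords(kp_xy, bbox):
    """Shift (x, y) keypoints from crop coordinates back to frame coordinates."""
    x0, y0, _, _ = bbox
    return [(kx + x0, ky + y0) for kx, ky in kp_xy]


bbox = xyxy_to_xywh(100, 50, 228, 178)
check_bbox(bbox, frame_w=256, frame_h=256)
frame_kps = crop_to_frame_coords([(10.0, 20.0), (64.5, 0.0)], bbox)
```

Here the corner box (100, 50, 228, 178) becomes (100, 50, 128, 128), the same tuple used in the bbox example above.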
- predict_on_label_csv(csv_file: str | Path, data_dir: str | Path | None = None, compute_metrics: bool = True, add_train_val_test_set: bool = False, bbox_file: str | Path | None = None) -> PredictionResult
Predicts on a labeled dataset and computes error/loss metrics if applicable.
- Parameters:
csv_file: path to the CSV file of images and keypoint locations.
data_dir: root path for relative image paths in the CSV file. Defaults to the data_dir used during training.
compute_metrics: whether to compute pixel error and loss metrics on predictions.
add_train_val_test_set: set to True when predicting on the training dataset to add a set column to the output.
bbox_file: optional path to a bbox CSV produced by litpose create_bbox (or any compatible source). When provided, each frame is cropped to its bounding box before being passed to the model, and predictions are returned in the original (un-cropped) coordinate space.
- Returns:
A PredictionResult object containing the predictions and metrics.
- Return type:
PredictionResult
Examples
>>> result = model.predict_on_label_csv("path/to/CollectedData.csv")
>>> result.predictions  # pd.DataFrame with MultiIndex columns
>>> result.metrics.pixel_error  # mean pixel error per keypoint

Skip metric computation for faster inference:
>>> result = model.predict_on_label_csv(
...     "path/to/CollectedData.csv",
...     compute_metrics=False,
... )
- predict_on_label_csv_multiview(csv_file_per_view: list[str] | list[Path], bbox_file_per_view: list[str] | list[Path] | None = None, camera_params_file: str | Path | None = None, data_dir: str | Path | None = None, compute_metrics: bool = True, add_train_val_test_set: bool = False) -> MultiviewPredictionResult
Version of predict_on_label_csv that gives models access to all views of each frame.
- Parameters:
csv_file_per_view: a list of CSV files, each from a different view of the same session; order must match view_names in the config file.
See the predict_on_label_csv docstring for other arguments.
- predict_on_video_file(video_file: str | Path, output_dir: str | Path | None = 'unspecified', compute_metrics: bool = True, generate_labeled_video: bool = False, progress_file: Path | None = None, bbox_file: str | Path | None = None) -> PredictionResult
Predicts on a video file and computes unsupervised loss metrics if applicable.
- Parameters:
video_file (str | Path): Path to the video file.
output_dir (str | Path, optional): The directory to save outputs to. Defaults to {model_dir}/video_preds. If set to None, outputs are not saved.
compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on predictions. Defaults to True.
generate_labeled_video (bool, optional): Whether to save a labeled video. Defaults to False.
progress_file (Path, optional): Path to a file to save progress information for the App. Defaults to None.
bbox_file (str | Path, optional): Path to a per-frame bbox CSV (columns x, y, h, w; one row per frame). When provided, each frame is cropped to its bounding box before being passed to the model, and predictions are returned in the original coordinate space. Single-view only. Defaults to None.
- Returns:
A PredictionResult object containing the predictions and metrics.
- Return type:
PredictionResult
Examples
>>> result = model.predict_on_video_file("path/to/video.mp4")
>>> result.predictions  # pd.DataFrame, one row per frame

Save a keypoint-annotated video alongside the predictions CSV:
>>> result = model.predict_on_video_file(
...     "path/to/video.mp4",
...     generate_labeled_video=True,
... )
- predict_on_video_file_multiview(video_file_per_view: list[str] | list[Path], output_dir: str | Path | None = 'unspecified', compute_metrics: bool = True, generate_labeled_video: bool = False, progress_file: Path | None = None) -> MultiviewPredictionResult
Version of predict_on_video_file that accesses multiple camera views of each frame.
- Parameters:
video_file_per_view: a list of video files, each from a different view of the same session; the number of files must match view_names in the config; order does not matter, as files are matched to views by filename.
output_dir: directory to save outputs to; defaults to {model_dir}/video_preds; set to None to skip saving.
compute_metrics: whether to compute pixel error and loss metrics on predictions.
generate_labeled_video: whether to save a labeled video.
progress_file: path to a file to save progress information for the App.
- Returns:
object containing the predictions and metrics for each view.
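The parameter description says video files are matched to views by filename but does not state the rule. The sketch below shows one plausible pre-check you could run on your own file list. The substring rule and the match_files_to_views helper are assumptions made for illustration, not the library's actual matching logic.

```python
from __future__ import annotations

from pathlib import Path


def match_files_to_views(video_files: list[str], view_names: list[str]) -> dict[str, str]:
    """Map each view name to the single file whose stem contains it.

    Assumption: a file belongs to a view when the view name is a substring
    of the file stem. The library's real rule may differ.
    """
    if len(video_files) != len(view_names):
        raise ValueError("number of files must match number of views")
    matched: dict[str, str] = {}
    for view in view_names:
        hits = [f for f in video_files if view in Path(f).stem]
        if len(hits) != 1:
            raise ValueError(
                f"expected exactly one file for view {view!r}, found {len(hits)}"
            )
        matched[view] = hits[0]
    return matched


# File order differs from view order; the mapping is resolved by name.
files = ["session0_right.mp4", "session0_left.mp4"]
by_view = match_files_to_views(files, view_names=["left", "right"])
```

A check like this fails early on ambiguous or missing file names, before any inference time is spent.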
Return types
- class lightning_pose.data.datatypes.PredictionResult
- metrics: ComputeMetricsSingleResult | None = None
- predictions: DataFrame
- to_dict() -> dict[str, Any]
Return predictions and metrics as a flat dict of named numpy arrays.
All arrays have shape (n_frames, n_keypoints) and share the same row order. Metric arrays are None when the metric was not computed.
- Returns:
dict with keys:
keypoint_names: list of keypoint name strings.
index: list of frame identifiers (file paths or integer indices).
x: float array of predicted x coordinates.
y: float array of predicted y coordinates.
confidence: float array of per-keypoint likelihood in [0, 1].
pixel_error: float array or None.
temporal_norm: float array or None.
pca_singleview_error: float array or None.
pca_multiview_error: float array or None.
- Return type:
dict[str, Any]
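As an illustration of consuming the to_dict() layout, the sketch below flattens the (n_frames, n_keypoints) arrays into one record per frame and keypoint, optionally dropping low-confidence predictions. The to_long_rows helper and the hand-written fake dict are hypothetical; the dict only mimics the documented keys and uses nested lists in place of numpy arrays.

```python
from __future__ import annotations

from typing import Any


def to_long_rows(result: dict[str, Any], min_confidence: float = 0.0) -> list[dict[str, Any]]:
    """Flatten (n_frames, n_keypoints) arrays into one record per frame and keypoint."""
    rows = []
    for i, frame_id in enumerate(result["index"]):
        for j, name in enumerate(result["keypoint_names"]):
            conf = float(result["confidence"][i][j])
            if conf < min_confidence:
                continue  # skip predictions below the threshold
            rows.append({
                "frame": frame_id,
                "keypoint": name,
                "x": float(result["x"][i][j]),
                "y": float(result["y"][i][j]),
                "confidence": conf,
            })
    return rows


# Hand-written stand-in for a to_dict() result: 2 frames, 2 keypoints.
fake = {
    "keypoint_names": ["nose", "tail"],
    "index": [0, 1],
    "x": [[10.0, 50.0], [11.0, 52.0]],
    "y": [[20.0, 60.0], [21.0, 61.0]],
    "confidence": [[0.99, 0.40], [0.95, 0.97]],
    "pixel_error": None,
}
confident = to_long_rows(fake, min_confidence=0.9)
```

The helper only uses arr[i][j] indexing, so it should behave the same on numpy arrays as on the nested lists used here.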
- class lightning_pose.data.datatypes.MultiviewPredictionResult
- metrics: dict[str, ComputeMetricsSingleResult] | None = None
- predictions: dict[str, DataFrame]
- to_dict() -> dict[str, dict[str, Any]]
Return predictions and metrics for each view as a flat dict of named numpy arrays.
Wraps PredictionResult.to_dict() for each view.
- Returns:
dict keyed by view name, where each value is the to_dict() output for that view.
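Because the multiview result is a dict keyed by view name, per-view summaries reduce to a loop over its items. The sketch below computes a mean confidence per view. The mean_confidence_per_view helper and the fake input are hypothetical, and the input carries only the confidence key needed for the example.

```python
from statistics import fmean


def mean_confidence_per_view(multiview: dict) -> dict:
    """Average the (n_frames, n_keypoints) confidence values for each view."""
    summary = {}
    for view_name, view_dict in multiview.items():
        values = [float(c) for row in view_dict["confidence"] for c in row]
        summary[view_name] = fmean(values)
    return summary


# Hand-written stand-in for MultiviewPredictionResult.to_dict() output.
fake_multiview = {
    "left": {"confidence": [[1.0, 0.5], [0.5, 1.0]]},
    "right": {"confidence": [[0.25, 0.25], [0.25, 0.25]]},
}
summary = mean_confidence_per_view(fake_multiview)
```

A summary like this can help spot a camera view whose predictions are consistently less confident than the others.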