PredictionHandler

class lightning_pose.utils.predictions.PredictionHandler

Bases: object

Convert batches of model outputs into a prediction dataframe.

Attributes Summary

do_context

frame_count

Returns the number of frames in the video or the labeled dataset.

keypoint_names

Methods Summary

__call__(preds[, is_multiview_video])

Return a pandas DataFrame of the predictions for a single video.

add_split_indices_to_df(df)

Add split indices to the dataframe.

fix_context_preds_confs(stacked_preds[, ...])

In the context model, ind=0 is associated with image[2] and ind=1 with image[3], so the predictions and confidences must be shifted by two and the edges eliminated.

make_dlc_pandas_index([keypoint_names])

make_pred_arr_undo_resize(keypoints_np, ...)

Resize keypoints and combine them with confidences into one NumPy array.

unpack_preds(preds)

Unpack the list of (preds, confs) tuples returned by pl.trainer.predict into tensors.

Attributes Documentation

do_context

frame_count

Returns the number of frames in the video or the labeled dataset.

keypoint_names

Methods Documentation

__call__(preds: list[Tuple[TensorType["batch", "two_times_num_keypoints"], TensorType["batch", "num_keypoints"]]], is_multiview_video: bool = False) -> DataFrame | dict[str, DataFrame]

Return a pandas DataFrame of the predictions for a single video. This assumes you have already run trainer.predict() and have a list of (predictions, confidences) tuples.

Parameters:
  • preds – list of tuples of (predictions, confidences)

  • is_multiview_video – set to True when using the multiview video prediction dataloader, i.e. for heatmap_multiview.

Returns:

index is (frame, bodypart, x, y, likelihood)

Return type:

pd.DataFrame | dict[str, pd.DataFrame]
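The returned DataFrame can be post-processed with ordinary pandas operations. The sketch below masks low-confidence keypoints. It assumes a DeepLabCut-style column MultiIndex with levels ("scorer", "bodyparts", "coords"); the helper name mask_low_confidence and the 0.9 threshold are hypothetical and not part of lightning_pose. When the call returns a dict[str, DataFrame], the same helper could be applied to each value.

```python
import numpy as np
import pandas as pd


def mask_low_confidence(df: pd.DataFrame, threshold: float = 0.9) -> pd.DataFrame:
    """Hypothetical helper: set x/y to NaN wherever likelihood < threshold.

    Assumes DLC-style columns with levels ("scorer", "bodyparts", "coords").
    """
    out = df.copy()  # leave the caller's DataFrame untouched
    scorer = out.columns.get_level_values("scorer")[0]
    for bp in out.columns.get_level_values("bodyparts").unique():
        low = out[(scorer, bp, "likelihood")] < threshold
        # Hide unreliable coordinates but keep the likelihood for inspection.
        out.loc[low, (scorer, bp, "x")] = np.nan
        out.loc[low, (scorer, bp, "y")] = np.nan
    return out


# Usage on a tiny hand-built frame (values are made up for illustration).
columns = pd.MultiIndex.from_product(
    [["scorer"], ["nose", "tail"], ["x", "y", "likelihood"]],
    names=["scorer", "bodyparts", "coords"],
)
df = pd.DataFrame(
    [[1.0, 2.0, 0.95, 3.0, 4.0, 0.5],
     [5.0, 6.0, 0.2, 7.0, 8.0, 0.99]],
    columns=columns,
)
masked = mask_low_confidence(df)
```

Masking rather than dropping rows keeps one row per frame, which preserves alignment with the video.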

add_split_indices_to_df(df: DataFrame) -> DataFrame

Add split indices to the dataframe.
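The docstring does not say how the split indices are stored. As an illustration only, the sketch below assumes they arrive as lists of row positions per split and are written to a single label column. The helper add_split_column, the column name "set", the split names, and the "unused" default are all hypothetical, and a flat-column DataFrame is used for simplicity.

```python
import pandas as pd


def add_split_column(df: pd.DataFrame, splits: dict, column: str = "set",
                     default: str = "unused") -> pd.DataFrame:
    """Hypothetical helper: label each row with the split it belongs to."""
    out = df.copy()
    out[column] = default  # rows not listed in any split keep the default
    col_pos = out.columns.get_loc(column)
    for split_name, positions in splits.items():
        # Positional assignment, so `positions` are row numbers, not labels.
        out.iloc[list(positions), col_pos] = split_name
    return out


frames = pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0, 4.0]})
labeled = add_split_column(frames, {"train": [0, 1], "validation": [2], "test": [3]})
```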

fix_context_preds_confs(stacked_preds: TensorType, zero_pad_confidence: bool = False)

In the context model, ind=0 is associated with image[2] and ind=1 with image[3], so the predictions and confidences must be shifted by two and the edges eliminated. NOTE: confidences are not zero in the first and last two images; they are instead replicas of images[-2] and images[-3].
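The shift-and-pad idea can be sketched with NumPy. This is a hypothetical illustration, not the library's code: it assumes row i of the stacked output belongs to frame i + 2, that only the first n − 4 rows are valid, and that edge frames reuse the nearest valid row (np.pad with mode="edge"). The zero_pad_confidence flag here mirrors the method's parameter name but its behavior (zeroing the four edge confidences) is an assumption.

```python
import numpy as np


def realign_context_outputs(preds: np.ndarray, confs: np.ndarray,
                            zero_pad_confidence: bool = False):
    """Hypothetical sketch of shifting context-model outputs back onto frames.

    Assumes row i of `preds`/`confs` holds the output for frame i + 2 and that
    only the first n - 4 rows are valid (n = number of frames).
    """
    n = preds.shape[0]
    if n < 5:
        raise ValueError("need at least 5 rows")
    valid_preds, valid_confs = preds[: n - 4], confs[: n - 4]
    pad = ((2, 2), (0, 0))  # two rows before, two rows after, no column padding
    # Edge frames reuse the nearest valid prediction.
    preds_out = np.pad(valid_preds, pad, mode="edge")
    if zero_pad_confidence:
        confs_out = np.pad(valid_confs, pad, mode="constant", constant_values=0.0)
    else:
        confs_out = np.pad(valid_confs, pad, mode="edge")
    return preds_out, confs_out


preds = np.arange(12, dtype=float).reshape(6, 2)    # 6 frames, 1 keypoint (x, y)
confs = np.arange(1, 7, dtype=float).reshape(6, 1)  # 6 frames, 1 keypoint
p, c = realign_context_outputs(preds, confs)
```

Edge replication keeps the output length equal to the frame count, so downstream code can index predictions by frame number directly.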

make_dlc_pandas_index(keypoint_names: list | None = None) -> MultiIndex

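make_dlc_pandas_index has no docstring. Given the (bp0_x, bp0_y, bp0_likelihood, bp1_x, …) layout that make_pred_arr_undo_resize documents, a plausible reading is a DeepLabCut-style column MultiIndex. The sketch below builds one with pandas under that assumption; the function name, the scorer value, and the level names ("scorer", "bodyparts", "coords") are assumptions, not the library's code.

```python
import numpy as np
import pandas as pd


def make_dlc_columns(keypoint_names, scorer="hypothetical_scorer"):
    """Hypothetical sketch: DLC-style columns (scorer, bodyparts, coords).

    from_product varies the last level fastest, giving the order
    bp0_x, bp0_y, bp0_likelihood, bp1_x, ...
    """
    return pd.MultiIndex.from_product(
        [[scorer], list(keypoint_names), ["x", "y", "likelihood"]],
        names=["scorer", "bodyparts", "coords"],
    )


cols = make_dlc_columns(["nose", "tail"])
# Wrap an (n_frames, n_keypoints * 3) array: here one frame of placeholder zeros.
df = pd.DataFrame(np.zeros((1, len(cols))), columns=cols)
```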
static make_pred_arr_undo_resize(keypoints_np: array, confidence_np: array) -> array

Resize keypoints and combine them with confidences into one NumPy array.

Parameters:
  • keypoints_np – shape (n_frames, n_keypoints * 2)

  • confidence_np – shape (n_frames, n_keypoints)

Returns:

cols are (bp0_x, bp0_y, bp0_likelihood, bp1_x, bp1_y, …)

Return type:

np.ndarray
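A minimal sketch of the two steps, assuming keypoints are expressed in the resized image's pixel coordinates with columns (bp0_x, bp0_y, bp1_x, bp1_y, …), and that undoing the resize is a plain per-axis linear rescale. The function name and the resize_wh / orig_wh parameters are hypothetical; the library reads its dimensions from its own configuration.

```python
import numpy as np


def undo_resize_and_interleave(keypoints_np, confidence_np, resize_wh, orig_wh):
    """Hypothetical sketch: rescale keypoints and interleave confidences.

    Assumes keypoints_np columns are (bp0_x, bp0_y, bp1_x, bp1_y, ...) in the
    resized image's pixel coordinates, and a plain per-axis linear rescale.
    """
    n_frames, two_k = keypoints_np.shape
    n_keypoints = two_k // 2
    if two_k != 2 * n_keypoints or confidence_np.shape != (n_frames, n_keypoints):
        raise ValueError("expected shapes (n, 2k) and (n, k)")
    # Per-axis scale factors from resized back to original resolution.
    scale = np.array([orig_wh[0] / resize_wh[0], orig_wh[1] / resize_wh[1]])
    xy = keypoints_np.reshape(n_frames, n_keypoints, 2) * scale
    # Append likelihood as a third value per keypoint, then flatten back.
    out = np.concatenate([xy, confidence_np[..., None]], axis=2)
    return out.reshape(n_frames, n_keypoints * 3)


kp = np.array([[10.0, 20.0, 30.0, 40.0]])  # one frame, two keypoints
conf = np.array([[0.9, 0.5]])
arr = undo_resize_and_interleave(kp, conf, resize_wh=(100, 200), orig_wh=(200, 100))
```

Reshaping to (n_frames, n_keypoints, 2) lets a single broadcast multiply scale every x and every y at once, and produces the (bp0_x, bp0_y, bp0_likelihood, …) column order documented above.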

unpack_preds(preds: list[Tuple[TensorType["batch", "two_times_num_keypoints"], TensorType["batch", "num_keypoints"]]]) -> …TensorType["num_frames", "num_keypoints"]

Unpack the list of (preds, confs) tuples returned by pl.trainer.predict into tensors. The output still contains unnecessary final rows, which should be discarded at the DataFrame stage. This works for the output of predict_loader with batch_size=1, sequence_length=16, step=16.
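As a rough illustration, assuming each element of preds is a (keypoints, confidences) pair whose first axis indexes frames, unpacking reduces to concatenating along that axis. The sketch uses NumPy in place of torch and a hypothetical helper unpack_batches; the optional frame_count trim stands in for the later DataFrame-stage discard and is not the lightning_pose implementation.

```python
import numpy as np


def unpack_batches(preds, frame_count=None):
    """Hypothetical sketch: stack per-batch (keypoints, confidences) pairs.

    Uses NumPy arrays in place of torch tensors. Passing frame_count trims the
    surplus final rows, standing in for the later DataFrame-stage discard.
    """
    keypoints = np.concatenate([kp for kp, _ in preds], axis=0)
    confidences = np.concatenate([cf for _, cf in preds], axis=0)
    if frame_count is not None:
        keypoints, confidences = keypoints[:frame_count], confidences[:frame_count]
    return keypoints, confidences


# Two sequences of 16 frames each, as with sequence_length=16, step=16.
batches = [(np.zeros((16, 4)), np.ones((16, 2))) for _ in range(2)]
kp_all, cf_all = unpack_batches(batches)                     # 32 rows, some surplus
kp_trim, cf_trim = unpack_batches(batches, frame_count=20)  # only real frames kept
```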

__init__(cfg: DictConfig, data_module: LightningDataModule | None = None, video_file: str | None = None) -> None

Parameters:
  • cfg – configuration object (DictConfig)

  • data_module – only required for prediction on CSV files

  • video_file – for prediction on video, the path to the video file; used to get frame_count

__new__(**kwargs)