LitDaliWrapper

class lightning_pose.data.dali.LitDaliWrapper

Bases: DALIGenericIterator

Typed wrapper around a DALI pipeline iterator for Lightning Pose models.

Converts the raw list-of-dicts that DALIGenericIterator yields into UnlabeledBatchDict or MultiviewUnlabeledBatchDict instances. When a bbox_df is provided, each frame in a batch is also cropped to its own bounding box and resized to the model's input dimensions before the batch is returned.
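DALIGenericIterator yields a list with one dict per pipeline, keyed by the names in its output_map. A minimal sketch of the conversion step, assuming a single pipeline whose outputs are named "frames" and "transforms"; UnlabeledBatchSketch and to_batch_dict are hypothetical simplifications, not the real UnlabeledBatchDict or the wrapper's own method:

```python
from typing import Any, TypedDict


class UnlabeledBatchSketch(TypedDict):
    """Hypothetical, simplified stand-in for UnlabeledBatchDict."""
    frames: Any      # assumed output name for the decoded frames
    transforms: Any  # assumed output name for the per-batch transforms


def to_batch_dict(dali_output: list[dict[str, Any]]) -> UnlabeledBatchSketch:
    """Convert raw DALIGenericIterator output into a typed batch dict."""
    # DALIGenericIterator yields one dict per pipeline; this sketch assumes exactly one
    if len(dali_output) != 1:
        raise ValueError("sketch assumes a single DALI pipeline")
    outputs = dali_output[0]
    return UnlabeledBatchSketch(frames=outputs["frames"], transforms=outputs["transforms"])


batch = to_batch_dict([{"frames": "frames-tensor", "transforms": "transforms-tensor"}])
print(batch["frames"])  # frames-tensor
```

A TypedDict adds no runtime cost (the result is a plain dict) while letting type checkers catch misspelled keys downstream.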

_frame_idx tracks the iterator's position in the video (as a frame number) so that _apply_bbox_crop reads the correct rows from bbox_df. It advances by seq_len per batch for base models, whose windows do not overlap, and by seq_len - 4 for context models. The DALI reader for context prediction uses a step of seq_len - 4, so consecutive sequences overlap by 4 frames: the 2 frames of context needed on each side of a 5-frame window.
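The bookkeeping can be sketched as follows, assuming each batch holds seq_len consecutive frames and ignoring end-of-video padding (the last window may extend past the final frame). frame_step, bbox_row_ranges, and batches_to_cover are hypothetical helpers; batches_to_cover shows one possible way to compute a count like num_iters outside the class, not necessarily how Lightning Pose does it:

```python
import math


def frame_step(seq_len: int, do_context: bool) -> int:
    """Frames the read position advances per batch (hypothetical helper)."""
    # Base models: non-overlapping windows. Context models: sequences overlap by 4 frames.
    step = seq_len - 4 if do_context else seq_len
    if step <= 0:
        raise ValueError("seq_len too small for the chosen step")
    return step


def bbox_row_ranges(n_batches: int, seq_len: int, do_context: bool) -> list[tuple[int, int]]:
    """[start, stop) rows of bbox_df that each batch would read."""
    step = frame_step(seq_len, do_context)
    frame_idx = 0  # plays the role of _frame_idx
    ranges = []
    for _ in range(n_batches):
        ranges.append((frame_idx, frame_idx + seq_len))
        frame_idx += step
    return ranges


def batches_to_cover(n_frames: int, seq_len: int, do_context: bool) -> int:
    """Smallest number of batches whose windows together include every frame."""
    if n_frames < 1:
        raise ValueError("n_frames must be positive")
    step = frame_step(seq_len, do_context)
    return max(1, math.ceil((n_frames - seq_len) / step) + 1)


print(bbox_row_ranges(3, 16, do_context=False))  # [(0, 16), (16, 32), (32, 48)]
print(bbox_row_ranges(3, 16, do_context=True))   # [(0, 16), (12, 28), (24, 40)]
```

With seq_len = 16, a 100-frame video needs 7 base batches but 8 context batches, because each context batch contributes only 12 new frames.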

Testing without a GPU

This class can be instantiated without a real DALI pipeline or GPU for unit tests that only exercise the PyTorch post-processing logic (e.g. _apply_bbox_crop):

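```python
import pandas as pd
from lightning_pose.data.dali import LitDaliWrapper
# Example bboxes (hypothetical values): one row per video frame, columns x, y, h, w
my_df = pd.DataFrame({"x": [0, 8], "y": [0, 4], "h": [128, 128], "w": [128, 128]})
# Allocate the instance without calling __init__, so no DALI pipeline or GPU is created
wrapper = object.__new__(LitDaliWrapper)
wrapper.do_context = False
wrapper.bbox_df = my_df
wrapper.resize_dims = [256, 256]  # [height, width]
wrapper._frame_idx = 0
```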
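The snippet relies on object.__new__ allocating an instance without running __init__. The same mechanism can be seen with only the standard library; NeedsGPU is a hypothetical stand-in for a class whose constructor requires hardware:

```python
class NeedsGPU:
    """Hypothetical stand-in for a class whose __init__ builds a DALI pipeline."""

    def __init__(self) -> None:
        raise RuntimeError("no GPU available")

    def scale(self, x: int) -> int:
        # Pure-Python logic that reads only attributes set by hand in the test
        return x * self.factor


obj = object.__new__(NeedsGPU)  # allocates the instance; __init__ is never called
obj.factor = 3                  # set only the attributes the method under test reads
print(obj.scale(2))  # 6
```

Any attribute that a method reads but the test did not assign raises AttributeError, which is why the snippet sets do_context, bbox_df, resize_dims, and _frame_idx explicitly.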

__init__(*args: Any, eval_mode: Literal['train', 'predict'], num_iters: int = 1, do_context: bool = False, bbox_df: DataFrame | None = None, resize_dims: list[int] | None = None, **kwargs: Any) → None

Wrap a DALIGenericIterator so that it yields batches PyTorch Lightning (pl) can consume.

Parameters:
  • eval_mode – "train" or "predict".

  • num_iters – number of iterations (batches) the dataloader yields. For now this must be computed outside the class; ideally the Lightning/DALI integration would handle it.

  • do_context – whether the model and data loader use 5-frame temporal context.

  • bbox_df – optional DataFrame with columns ["x", "y", "h", "w"], one row per frame. When provided, each batch’s frames are cropped per-frame and resized to resize_dims before being returned, and the bbox field of the batch dict is populated with the actual bbox coordinates.

  • resize_dims – target [height, width] for post-crop resize; required when bbox_df is not None.
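A minimal sketch of the per-frame crop and resize driven by bbox_df and resize_dims. It assumes x and y are the top-left column and row of the box in pixels, h and w its height and width, all integers with the box lying inside the frame; frames are single-channel nested lists, and resampling is nearest-neighbor (the real implementation works on PyTorch tensors and may interpolate differently). crop_and_resize and crop_batch are hypothetical names:

```python
Frame = list[list[int]]  # one single-channel frame, indexed [row][col]


def crop_and_resize(frame: Frame, x: int, y: int, h: int, w: int,
                    resize_dims: list[int]) -> Frame:
    """Crop the box (x, y, h, w) from frame, then resize it to resize_dims = [height, width]."""
    out_h, out_w = resize_dims
    # Assumes the box lies fully inside the frame
    crop = [row[x:x + w] for row in frame[y:y + h]]
    # Nearest-neighbor: map each output pixel back to a source pixel in the crop
    return [[crop[(r * h) // out_h][(c * w) // out_w] for c in range(out_w)]
            for r in range(out_h)]


def crop_batch(frames: list[Frame], bbox_rows: list[dict[str, int]],
               resize_dims: list[int]) -> list[Frame]:
    """Crop each frame with its own bbox row (one row per frame), then resize."""
    if len(frames) != len(bbox_rows):
        raise ValueError("need exactly one bbox row per frame")
    return [crop_and_resize(f, b["x"], b["y"], b["h"], b["w"], resize_dims)
            for f, b in zip(frames, bbox_rows)]


frame = [[r * 4 + c for c in range(4)] for r in range(4)]
print(crop_and_resize(frame, x=1, y=1, h=2, w=2, resize_dims=[4, 4]))
# [[5, 5, 6, 6], [5, 5, 6, 6], [9, 9, 10, 10], [9, 9, 10, 10]]
```

In practice the bbox rows for a batch could be taken from a slice of bbox_df, for example with bbox_df.iloc[start:stop].to_dict("records"), where start and stop come from the iterator's frame position.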

__new__(**kwargs)