Temporal Context Network¶
The use of context frames is another core component of the Lightning Pose algorithm - rather than predicting keypoints at time t solely from the frame at time t, the Temporal Context Network (TCN) uses frames at times [t-2, t-1, t, t+1, t+2] (but only requires labels for time t). This temporal context can be especially helpful for resolving brief occlusions. This page describes updates to the data and config file in order to properly use the TCN.
Data¶
The TCN requires the addition of context frames in the labeled data directory (referred to as <LABELED_DATA_DIR> in Organizing your data). For example, if the labels csv file contains a frame named labeled-data/session_00/img009.png, then you will need to add the frames img007.png, img008.png, img010.png, and img011.png to the directory labeled-data/session_00. You do not need to change the labels csv file.
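To make the naming convention concrete, the small helper below (illustrative only, not part of Lightning Pose) lists the context frame filenames that must accompany a given labeled frame, assuming the standard zero-padded imgNNN.png scheme:

```python
import os
import re


def context_frame_names(labeled_frame, n_context=2):
    """List the context frames needed for one labeled frame.

    Illustrative helper; assumes filenames of the form imgNNN.png
    with zero-padded frame indices.
    """
    dirname, fname = os.path.split(labeled_frame)
    match = re.fullmatch(r"img(\d+)\.png", fname)
    idx_str = match.group(1)
    idx, width = int(idx_str), len(idx_str)
    return [
        os.path.join(dirname, f"img{idx + offset:0{width}d}.png")
        for offset in range(-n_context, n_context + 1)
        if offset != 0  # the labeled frame itself is already present
    ]


print(context_frame_names("labeled-data/session_00/img009.png"))
```

For img009.png this returns img007.png, img008.png, img010.png, and img011.png, matching the example above.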
To extract specific frames from a video file, you can use the following Python function:
import cv2
import numpy as np

def get_frames_from_idxs(cap, idxs):
    """Helper function to load video segments.

    Note
    ----
    To create the VideoCapture object:
    >>> import cv2
    >>> cap = cv2.VideoCapture('/path/to/video_file')

    Parameters
    ----------
    cap : cv2.VideoCapture object
    idxs : array-like
        frame indices into video

    Returns
    -------
    np.ndarray
        returned frames of shape (n_frames, n_channels, ypix, xpix)

    """
    is_contiguous = np.sum(np.diff(idxs)) == (len(idxs) - 1)
    n_frames = len(idxs)
    frames = None
    for fr, i in enumerate(idxs):
        # only seek when necessary; sequential reads are much faster
        if fr == 0 or not is_contiguous:
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if ret:
            if fr == 0:
                height, width, _ = frame.shape
                frames = np.zeros((n_frames, 1, height, width), dtype='uint8')
            frames[fr, 0, :, :] = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        else:
            print(
                'warning! reached end of video; returning blank frames for remainder of '
                'requested indices')
            break
    return frames
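One detail of the function above is worth highlighting: seeking with cap.set is slow, so the function seeks only once when the requested indices are consecutive. A minimal sketch of that contiguity check, isolated for clarity:

```python
import numpy as np


def indices_are_contiguous(idxs):
    # consecutive indices differ by exactly 1, so the diffs of a
    # sorted, contiguous run sum to len(idxs) - 1
    return np.sum(np.diff(idxs)) == (len(idxs) - 1)


print(indices_are_contiguous([7, 8, 9, 10, 11]))  # consecutive context frames
print(indices_are_contiguous([7, 9, 11]))         # gaps force a seek per frame
```

Note that the check assumes the indices are sorted, which is the case when loading a run of context frames.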
Warning
Check that this function returns the correct frames! Select a labeled frame that lives in the <LABELED_DATA_DIR> directory, load that exact frame from the raw video, and make sure the two match.
Config file¶
Note
Recall that any of the config options can be updated directly from the command line; see the Training section.
There is only one field of the config file that must be updated to properly fit the TCN model, found in the model section:
model:
  model_type: heatmap_mhcrnn
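Since config options can also be set from the command line (see the Training section), the same change can be made as a hydra-style override; the exact entry-point script may differ depending on your installation, so treat this as a sketch:

```shell
# override the model type at launch time instead of editing the config file
python scripts/train_hydra.py model.model_type=heatmap_mhcrnn
```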
Batch sizes¶
Supervised training:
The supervised TCN model requires 5x more memory than the standard supervised model, due to the context frames. You might need to reduce the labeled batch size in training.train_batch_size to avoid out-of-memory errors.
Semi-supervised training:
Context frames can be trivially combined with unsupervised losses to produce a semi-supervised
context model; all that is required is to set model.losses_to_use
as described in the
Unsupervised losses section.
The semi-supervised context model requires at least 5x more memory than the supervised model,
depending on the unlabeled batch size.
The unlabeled batch size for the context model can be set with dali.context.train.batch_size
.
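Because each training example now carries five frames instead of one, a rough starting point when tuning these batch sizes (a heuristic for illustration, not an official Lightning Pose recommendation) is to divide the batch size that worked for the single-frame model by five:

```python
def suggested_context_batch_size(single_frame_batch_size, n_context_frames=5):
    # each example now loads n_context_frames frames, so keep the total
    # number of frames per batch roughly constant
    return max(1, single_frame_batch_size // n_context_frames)


print(suggested_context_batch_size(16))  # a batch of 16 single frames -> 3
```

From there, increase the batch size until you approach the memory limit of your GPU.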
Supervised/semi-supervised inference:
Inference in the TCN model (supervised or semi-supervised) is efficiently implemented so that each frame in a sequence is only processed once; therefore you may not need to adjust the inference batch size, which is found at dali.context.predict.sequence_length.