"""Dataset objects store images, labels, and functions for manipulation."""
import os
from pathlib import Path
from typing import Callable, List, Literal, Tuple
import imgaug.augmenters as iaa
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchtyping import TensorType
from torchvision import transforms
from lightning_pose.data import _IMAGENET_MEAN, _IMAGENET_STD
from lightning_pose.data.datatypes import (
BaseLabeledExampleDict,
HeatmapLabeledExampleDict,
MultiviewHeatmapLabeledExampleDict,
)
from lightning_pose.data.utils import generate_heatmaps
from lightning_pose.utils import io as io_utils
# list the public names explicitly so that sphinx-autoapidoc ignores the imported names
__all__ = [
    "BaseTrackingDataset",
    "HeatmapDataset",
    "MultiviewHeatmapDataset",
]
class BaseTrackingDataset(torch.utils.data.Dataset):
    """Base dataset that contains images and keypoints as (x, y) pairs."""

    def __init__(
        self,
        root_directory: str,
        csv_path: str,
        image_resize_height: int,
        image_resize_width: int,
        header_rows: list[int] | tuple[int, ...] | None = (0, 1, 2),
        imgaug_transform: Callable | None = None,
        do_context: bool = False,
        resize: bool = True,
    ) -> None:
        """Initialize a dataset for regression (rather than heatmap) models.

        The csv file of labels is located as follows:

        1. if `csv_path` points to an existing file (absolute, or relative to the current
           working directory) it is used as is
        2. otherwise the csv is assumed to be located at `root_directory/csv_path`

        In both cases the image paths stored inside the csv must be relative to
        `root_directory`.

        Args:
            root_directory: path to data directory
            csv_path: path to CSV file (resolved as described above). CSV file should be in
                the form (image_path, bodypart_1_x, bodypart_1_y, ..., bodypart_n_y)
                Note: image_path is relative to the given root_directory
            image_resize_height: height to resize images before sending to network
            image_resize_width: width to resize images before sending to network
            header_rows: which rows in the csv are header rows; a tuple (such as the default)
                is converted to a list before use
            imgaug_transform: imgaug transform pipeline to apply to images. If None and
                `resize` is True, a pipeline containing only the resize step is created.
                If a pipeline is given and `resize` is True, the resize step is appended to
                that (caller-owned) pipeline in place.
            do_context: include additional frames of context if possible.
            resize: True to add final resizing augmentation before sending data to network. This
                can be set to False if inheritors of this class need to implement more
                sophisticated augmentations before resizing (e.g. 3d augmentations). Note that
                when this is False, it is up to the child class to perform this resizing on both
                images and keypoints before returning a batch of data.

        Raises:
            FileNotFoundError: if the csv file cannot be found

        """
        # copy tuples (including the default) into a fresh list so that pandas and io_utils
        # receive a list, exactly as when a list is passed, and no state is shared
        if isinstance(header_rows, tuple):
            header_rows = list(header_rows)

        self.root_directory = Path(root_directory)
        self.image_resize_height = image_resize_height
        self.image_resize_width = image_resize_width
        self.csv_path = csv_path
        self.header_rows = header_rows
        self.do_context = do_context

        if resize:
            resize_aug = iaa.Resize({
                "height": image_resize_height,
                "width": image_resize_width,
            })
            if imgaug_transform is None:
                # no user pipeline: build one that only resizes
                imgaug_transform = iaa.Sequential([resize_aug])
            else:
                # NOTE: appends to the caller's pipeline object in place
                imgaug_transform.add(resize_aug)
        self.imgaug_transform = imgaug_transform

        # load csv data
        if os.path.isfile(csv_path):
            csv_file = csv_path
        else:
            csv_file = os.path.join(root_directory, csv_path)
            if not os.path.exists(csv_file):
                raise FileNotFoundError(f"Could not find csv file at {csv_file}!")
        csv_data = pd.read_csv(csv_file, header=header_rows, index_col=0)
        csv_data = io_utils.fix_empty_first_row(csv_data)
        self.keypoint_names = io_utils.get_keypoint_names(csv_file=csv_file, header_rows=header_rows)
        self.image_names = list(csv_data.index)
        self.keypoints = torch.tensor(csv_data.to_numpy(), dtype=torch.float32)
        # convert flat (x1, y1, x2, y2, ...) rows to (num_frames, num_keypoints, 2)
        self.keypoints = self.keypoints.reshape(self.keypoints.shape[0], -1, 2)

        # send image to tensor and normalize
        self.pytorch_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ])

        # keypoints has been already reshaped above
        self.num_targets = self.keypoints.shape[1] * 2
        self.num_keypoints = self.keypoints.shape[1]
        self.data_length = len(self.image_names)

    @property
    def height(self) -> int:
        return self.image_resize_height

    @property
    def width(self) -> int:
        return self.image_resize_width

    def __len__(self) -> int:
        return self.data_length

    @staticmethod
    def _load_rgb_image(path) -> Image.Image:
        """Read an image from disk and convert it to 3-channel RGB (e.g. from grayscale)."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx: int) -> BaseLabeledExampleDict:
        """Load one labeled example, apply augmentations, and normalize the image(s).

        Returns:
            dict with
            - images: tensor of shape (3, H, W), or (n_context, 3, H, W) when `do_context`
            - keypoints: flattened (x, y) coordinates, shape (num_targets,)
            - idxs: the requested index
            - bbox: [0, 0, image height, image width] of the loaded (un-augmented) image

        """
        img_name = self.image_names[idx]
        # clone so that in-place edits of the returned keypoints by downstream code (e.g.
        # marking out-of-frame keypoints as NaN) can never write back into self.keypoints
        keypoints_on_image = self.keypoints[idx].clone()
        img_path = self.root_directory / img_name

        if not self.do_context:
            image = self._load_rgb_image(img_path)
            if self.imgaug_transform is not None:
                # imgaug operates on batches, so add a batch dim to images and keypoints...
                transformed_images, transformed_keypoints = self.imgaug_transform(
                    images=np.expand_dims(image, axis=0),
                    keypoints=np.expand_dims(keypoints_on_image.numpy(), axis=0),
                )
                # ...and get rid of it again
                transformed_images = transformed_images[0]
                transformed_keypoints = transformed_keypoints[0].reshape(-1)
            else:
                # no batch dim here: ToTensor only accepts a single HWC image
                transformed_images = np.asarray(image)
                transformed_keypoints = keypoints_on_image.numpy().reshape(-1)
            transformed_images = self.pytorch_transform(transformed_images)
        else:
            context_img_paths = io_utils.get_context_img_paths(img_path)
            # read the images from image list to create dataset
            images = []
            for path in context_img_paths:
                if not path.exists():
                    # revert to center frame
                    path = context_img_paths[2]
                image = self._load_rgb_image(path)
                images.append(np.asarray(image))

            if self.imgaug_transform is not None:
                # reseed before every frame so all context frames get the same random transform
                seed = np.random.randint(low=0, high=123456)
                augmented = []
                for frame in images:
                    self.imgaug_transform.seed_(seed)
                    aug_frame, transformed_keypoints = self.imgaug_transform(
                        images=[frame], keypoints=[keypoints_on_image.numpy()]
                    )
                    augmented.append(aug_frame[0])
                transformed_images = np.asarray(augmented)
                transformed_keypoints = transformed_keypoints[0].reshape(-1)
            else:
                transformed_images = np.asarray(images)
                transformed_keypoints = keypoints_on_image.numpy().reshape(-1)

            # ToTensor only operates on single images: convert each frame, then stack
            transformed_images = torch.stack(
                [self.pytorch_transform(frame) for frame in transformed_images], dim=0
            )

        assert transformed_keypoints.shape == (self.num_targets,)

        return BaseLabeledExampleDict(
            images=transformed_images,  # shape (3, img_height, img_width) or (5, 3, H, W)
            keypoints=torch.from_numpy(transformed_keypoints),  # shape (n_targets,)
            idxs=idx,
            bbox=torch.tensor([0, 0, image.height, image.width]),  # x,y,h,w of bounding box
        )
# HeatmapDataset only adds heatmap creation on top of BaseTrackingDataset.
class HeatmapDataset(BaseTrackingDataset):
    """Heatmap dataset that contains the images and keypoints in 2D arrays."""

    def __init__(
        self,
        root_directory: str,
        csv_path: str,
        image_resize_height: int,
        image_resize_width: int,
        header_rows: list[int] | tuple[int, ...] | None = (0, 1, 2),
        imgaug_transform: Callable | None = None,
        downsample_factor: Literal[1, 2, 3] = 2,
        do_context: bool = False,
        resize: bool = True,
        uniform_heatmaps: bool = False,
    ) -> None:
        """Initialize the Heatmap Dataset.

        Args:
            root_directory: path to data directory
            csv_path: path to CSV or h5 file (within root_directory). CSV file
                should be in the form
                (image_path, bodypart_1_x, bodypart_1_y, ..., bodypart_n_y)
                Note: image_path is relative to the given root_directory
            image_resize_height: height to resize images before sending to network;
                must be a multiple of 128
            image_resize_width: width to resize images before sending to network;
                must be a multiple of 128
            header_rows: which rows in the csv are header rows; a tuple (such as the default)
                is converted to a list before use
            imgaug_transform: imgaug transform pipeline to apply to images
            downsample_factor: factor by which to downsample original image dims to have a
                smaller heatmap (heatmap dims are image dims // 2**downsample_factor)
            do_context: include additional frames of context if possible
            resize: True to add final resizing augmentation before sending data to network. This
                can be set to False if inheritors of this class need to implement more
                sophisticated augmentations before resizing (e.g. 3d augmentations). Note that
                when this is False, it is up to the child class to perform this resizing on both
                images and keypoints before returning a batch of data.
            uniform_heatmaps: True to force the model to output uniform heatmaps for missing data;
                False will output all-zero heatmaps

        Raises:
            ValueError: if the resize height or width is not a multiple of 128

        """
        # hand a fresh list (never the shared tuple default) to the base class / pandas
        if isinstance(header_rows, tuple):
            header_rows = list(header_rows)

        super().__init__(
            root_directory=root_directory,
            csv_path=csv_path,
            image_resize_height=image_resize_height,
            image_resize_width=image_resize_width,
            header_rows=header_rows,
            imgaug_transform=imgaug_transform,
            do_context=do_context,
            resize=resize,
        )

        # BOTH dims must be multiples of 128 (=2**7), i.e. repeatably divisible by 2
        if self.height % 128 != 0 or self.width % 128 != 0:
            raise ValueError(
                "image dimensions (after transformation) must be repeatably divisible by 2 "
                "(i.e. multiples of 128)! current image dimensions after transformation are: "
                f"height={self.height}, width={self.width}"
            )

        self.downsample_factor = downsample_factor
        self.output_sigma = 1.25  # should be sigma/2 ^downsample factor
        self.uniform_heatmaps = uniform_heatmaps
        self.num_targets = torch.numel(self.keypoints[0])
        self.num_keypoints = self.num_targets // 2

    @property
    def output_shape(self) -> tuple:
        """(height, width) of the heatmaps: image dims divided by 2**downsample_factor."""
        return (
            self.height // 2**self.downsample_factor,
            self.width // 2**self.downsample_factor,
        )

    def compute_heatmap(
        self, example_dict: BaseLabeledExampleDict
    ) -> TensorType["num_keypoints", "heatmap_height", "heatmap_width"]:
        """Compute 2D heatmaps from arbitrary (x, y) coordinates.

        Keypoints that lie outside the [0, width) x [0, height) frame (e.g. moved there by data
        augmentation) are set to NaN in place on the reshaped keypoint tensor; when `reshape`
        returns a view (as for contiguous tensors) this is visible in
        `example_dict["keypoints"]` as well.
        """
        keypoints = example_dict["keypoints"].reshape(self.num_keypoints, 2)
        xs, ys = keypoints[:, 0], keypoints[:, 1]
        # NaN coordinates compare False everywhere, so existing NaNs are left untouched
        out_of_frame = (xs < 0) | (ys < 0) | (xs >= self.width) | (ys >= self.height)
        keypoints[out_of_frame, :] = torch.nan

        y_heatmap = generate_heatmaps(
            keypoints=keypoints.unsqueeze(0),  # add batch dim
            height=self.height,
            width=self.width,
            output_shape=self.output_shape,
            sigma=self.output_sigma,
            uniform_heatmaps=self.uniform_heatmaps,
        )
        return y_heatmap[0]  # drop batch dim

    def compute_heatmaps(self):
        """Compute initial 2D heatmaps for all labeled data. Note this will apply augmentations.

        original image dims e.g., (406, 396) ->
        resized image dims e.g., (384, 384) ->
        potentially downsampled heatmaps e.g., (96, 96)

        Returns:
            tensor of shape (num_frames, num_keypoints, *output_shape)

        """
        label_heatmaps = torch.empty(
            size=(len(self.image_names), self.num_keypoints, *self.output_shape)
        )
        for idx in range(len(self.image_names)):
            example_dict: BaseLabeledExampleDict = super().__getitem__(idx)
            label_heatmaps[idx] = self.compute_heatmap(example_dict)
        return label_heatmaps

    def __getitem__(self, idx: int) -> HeatmapLabeledExampleDict:
        """Get an example from the dataset, with its heatmaps under the "heatmaps" key."""
        # call base dataset to get an image and labels
        example_dict: BaseLabeledExampleDict = super().__getitem__(idx)
        # compute the corresponding heatmaps
        example_dict["heatmaps"] = self.compute_heatmap(example_dict)
        return example_dict
class MultiviewHeatmapDataset(torch.utils.data.Dataset):
    """Heatmap dataset that contains the images and keypoints in 2D arrays from all the cameras."""

    def __init__(
        self,
        root_directory: str,
        csv_paths: list[str],
        view_names: list[str],
        image_resize_height: int,
        image_resize_width: int,
        header_rows: list[int] | tuple[int, ...] | None = (0, 1, 2),
        imgaug_transform: Callable | None = None,
        downsample_factor: Literal[1, 2, 3] = 2,
        do_context: bool = False,
        resize: bool = True,
        uniform_heatmaps: bool = False,
    ) -> None:
        """Initialize the MultiViewHeatmap Dataset.

        Args:
            root_directory: path to data directory
            csv_paths: paths to CSV files, one per view (absolute, or relative to
                root_directory). CSV files should be in this form
                (image_path, bodypart_1_x, bodypart_1_y, ..., bodypart_n_y)
                the keypoints and image file names must match in all CSV files
                Note: image_path is relative to the given root_directory
                we suggest that these CSV files start with the view names
            view_names: names of the views, one per csv file (same order as csv_paths)
            image_resize_height: height to resize images before sending to network
            image_resize_width: width to resize images before sending to network
            header_rows: which rows in the csv are header rows; a tuple (such as the default)
                is converted to a list before use
            imgaug_transform: imgaug transform pipeline to apply to images. If None and
                `resize` is True, a pipeline containing only the resize step is created.
                If a pipeline is given and `resize` is True, the resize step is appended to
                that (caller-owned) pipeline in place.
            downsample_factor: factor by which to downsample original image dims to have a
                smaller heatmap
            do_context: include additional frames of context if possible
            resize: True to add final resizing augmentation before sending data to network. This
                can be set to False if inheritors of this class need to implement more
                sophisticated augmentations before resizing (e.g. 3d augmentations). Note that
                when this is False, it is up to the child class to perform this resizing on both
                images and keypoints before returning a batch of data.
            uniform_heatmaps: True to force the model to output uniform heatmaps for missing data;
                False will output all-zero heatmaps

        Raises:
            ValueError: if the number of view names and csv files differ
            ImportError: if the csv files disagree in rows, keypoints or image file names

        """
        if len(view_names) != len(csv_paths):
            raise ValueError("number of names does not match with the number of files!")

        # hand a fresh list (never the shared tuple default) to the per-view datasets
        if isinstance(header_rows, tuple):
            header_rows = list(header_rows)

        self.root_directory = root_directory
        self.csv_paths = csv_paths
        self.view_names = view_names
        self.image_resize_height = image_resize_height
        self.image_resize_width = image_resize_width
        self.do_context = do_context

        # add the resize step once here (each per-view dataset gets resize=False) so that it
        # is not appended to the shared pipeline once per view
        if resize:
            resize_aug = iaa.Resize({
                "height": image_resize_height,
                "width": image_resize_width,
            })
            if imgaug_transform is None:
                imgaug_transform = iaa.Sequential([resize_aug])
            else:
                imgaug_transform.add(resize_aug)
        self.imgaug_transform = imgaug_transform
        self.downsample_factor = downsample_factor

        self.dataset = {}
        self.keypoint_names = {}
        self.data_length = {}
        self.num_keypoints = {}
        for view, csv_path in zip(view_names, csv_paths):
            self.dataset[view] = HeatmapDataset(
                root_directory=root_directory,
                csv_path=csv_path,
                image_resize_height=image_resize_height,
                image_resize_width=image_resize_width,
                header_rows=header_rows,
                imgaug_transform=imgaug_transform,
                downsample_factor=downsample_factor,
                do_context=do_context,
                resize=False,
                uniform_heatmaps=uniform_heatmaps,
            )
            self.keypoint_names[view] = self.dataset[view].keypoint_names
            self.data_length[view] = len(self.dataset[view])
            self.num_keypoints[view] = self.dataset[view].num_keypoints

        # total number of keypoints summed over all views (per-view keypoints are
        # concatenated in `fusion`)
        self.num_keypoints = sum(self.num_keypoints.values())
        # check that rows, keypoints and image names line up across views; this also turns
        # self.data_length from a per-view dict into a single int
        self.check_data_images_names()
        self.num_targets = self.num_keypoints * 2

    def check_data_images_names(self):
        """Data checking.

        Each object in self.dataset is a HeatmapDataset and therefore has the attribute
        image_names (i.e. self.dataset['top'].image_names). Check that all views have the same
        number of rows, the same keypoints in the same order, and the same image file name at
        every index, so that when element n is loaded from each individual view we know these
        are properly matched. On success, self.data_length is replaced by the common row count.

        Raises:
            ImportError: on any mismatch between views

        """
        # check if all CSV files have the same number of rows
        if len(set(self.data_length.values())) != 1:
            raise ImportError("the CSV files do not match in row numbers!")

        # check if all CSV files list the same keypoints in the same order; comparing whole
        # lists also catches views with a different number of keypoints
        reference_view = self.view_names[0]
        reference_names = list(self.keypoint_names[reference_view])
        for view, names in self.keypoint_names.items():
            names = list(names)
            if names != reference_names:
                raise ImportError(
                    "the keypoints are not in correct order! "
                    f"view: {reference_view} vs {view} | {reference_names} != {names}"
                )

        self.data_length = next(iter(self.data_length.values()))

        # check that the n-th image has the same file name in every view
        for idx in range(self.data_length):
            img_file_names = {
                Path(view_dataset.image_names[idx]).name
                for view_dataset in self.dataset.values()
            }
            if len(img_file_names) > 1:
                raise ImportError(
                    "Discrepancy in image file names across CSV files! "
                    f"index:{idx}, image file names:{img_file_names}"
                )

    @property
    def height(self) -> int:
        return self.image_resize_height

    @property
    def width(self) -> int:
        return self.image_resize_width

    def __len__(self) -> int:
        return self.data_length

    @property
    def output_shape(self) -> tuple:
        """(height, width) of the heatmaps: image dims divided by 2**downsample_factor."""
        return (
            self.height // 2**self.downsample_factor,
            self.width // 2**self.downsample_factor,
        )

    @property
    def num_views(self) -> int:
        return len(self.view_names)

    def fusion(
        self, datadict: dict
    ) -> Tuple[
        TensorType["num_views", "RGB":3, "image_height", "image_width", float]
        | TensorType[
            "num_views", "frames", "RGB":3, "image_height", "image_width", float
        ],
        TensorType["keypoints"],
        TensorType["num_views", "heatmap_height", "heatmap_width", float],
        TensorType["num_views * xyhw", float],
        List,
    ]:
        """Merge images, heatmaps, keypoints, and bboxes across views.

        Args:
            datadict: maps view name -> output of HeatmapDataset.__getitem__(idx) for that view.
                Each entry's "keypoints" is replaced in place by its (num_keypoints, 2) reshape.

        Returns:
            tuple
            - images: per-view images stacked along a new leading view dim
            - keypoints: all views' flattened keypoints concatenated, shape (num_targets,)
            - heatmaps: per-view heatmaps concatenated along dim 0 (keypoint dim)
            - bboxes: per-view bboxes concatenated into one flat tensor
            - concat order: view names in the order they were concatenated

        """
        images = []
        keypoints = []
        heatmaps = []
        bboxes = []
        concat_order = []
        for view, data in datadict.items():
            images.append(data["images"].unsqueeze(0))
            data["keypoints"] = data["keypoints"].reshape(data["keypoints"].shape[0] // 2, 2)
            keypoints.append(data["keypoints"])
            heatmaps.append(data["heatmaps"])
            bboxes.append(data["bbox"])
            concat_order.append(view)

        images = torch.cat(images, dim=0)
        keypoints = torch.cat(keypoints, dim=0).reshape(-1)
        heatmaps = torch.cat(heatmaps, dim=0)
        bboxes = torch.cat(bboxes, dim=0)

        assert keypoints.shape == (self.num_targets,)

        return images, keypoints, heatmaps, bboxes, concat_order

    def __getitem__(self, idx: int) -> MultiviewHeatmapLabeledExampleDict:
        """Get an example from the dataset.

        Calls the HeatmapDataset of each view to get images and their heatmaps, then merges
        them across views with `fusion`.
        """
        datadict = {view: self.dataset[view][idx] for view in self.view_names}
        images, keypoints, heatmaps, bboxes, concat_order = self.fusion(datadict)

        return MultiviewHeatmapLabeledExampleDict(
            images=images,  # (num_views, 3, H, W) or, with context, (num_views, 5, 3, H, W)
            keypoints=keypoints,  # shape (n_targets,)
            heatmaps=heatmaps,
            bbox=bboxes,
            idxs=idx,
            num_views=self.num_views,  # int
            concat_order=concat_order,  # list[str]
            view_names=self.view_names,  # list[str]
        )