"""Path handling functions."""
from __future__ import annotations # python 3.8 compatibility for sphinx
import collections
import os
import re
import warnings
from pathlib import Path
import numpy as np
import pandas as pd
from omegaconf import DictConfig, ListConfig
from typeguard import typechecked
# Explicit public API of this module. Listing the names here also keeps
# sphinx-autoapidoc from documenting the names imported above as if they were
# defined in this module.
__all__ = [
    "ckpt_path_from_base_path",
    "check_if_semi_supervised",
    "get_keypoint_names",
    "return_absolute_path",
    "return_absolute_data_paths",
    "extract_session_name_from_video",
    "find_video_files_for_views",
    "get_videos_in_dir",
    "check_video_paths",
    "get_context_img_paths",
    "split_video_files_by_view",
]
[docs]
@typechecked
def ckpt_path_from_base_path(
    base_path: str,
    model_name: str,
    logging_dir_name: str = "tb_logs/",
) -> str | None:
    """Given a path to a hydra output with trained model, extract the model .ckpt file.

    Only checkpoints in the latest ``version_<N>`` directory (largest integer ``N``) are
    considered. Within that directory:

    - if exactly one filename contains ``-best.ckpt``, it is returned;
    - otherwise a warning is issued and the only checkpoint is returned, or, if there
      are several, the one with the highest ``step=<N>`` in its filename.

    Args:
        base_path (str): path to a folder with logs and checkpoint. for example,
            function will search base_path/logging_dir_name/model_name...
        model_name (str): the name you gave your model before training it; appears as
            model_name in lightning-pose/scripts/config/model_params.yaml
        logging_dir_name (str, optional): name of the folder in logs, controlled in
            train_hydra.py Defaults to "tb_logs/".

    Returns:
        str: path to model checkpoint, or None if none found.

    Raises:
        ValueError: If multiple 'best' checkpoint files are found in the latest version
            (e.g., from save_top_k > 1), or if several checkpoints exist, none is marked
            'best', and no step count can be parsed from their filenames.
    """
    import glob

    model_search_path = os.path.join(
        base_path,
        logging_dir_name,
        glob.escape(model_name),
        "version_*",
        "checkpoints",
        "*.ckpt",
    )
    # Sorted so the result (and tie-breaking below) does not depend on filesystem order.
    all_checkpoint_files = sorted(glob.glob(model_search_path))
    if not all_checkpoint_files:
        return None

    # Group checkpoints by version number. The number is parsed from the `version_<N>`
    # directory two levels above each file, rather than searched for anywhere in the
    # path. Otherwise a base_path or model_name containing "version_<digits>" would be
    # mistaken for the version.
    ckpt_files_by_version = collections.defaultdict(list)
    for f in all_checkpoint_files:
        version_dir = os.path.basename(os.path.dirname(os.path.dirname(f)))
        match = re.fullmatch(r"version_(\d+)", version_dir)
        if match:
            ckpt_files_by_version[int(match.group(1))].append(f)
    if not ckpt_files_by_version:
        # only non-numeric `version_*` directories matched the glob
        return None

    latest_version = max(ckpt_files_by_version)
    latest_version_files = ckpt_files_by_version[latest_version]
    latest_version_dir = os.path.dirname(latest_version_files[0])

    best_ckpt_files = [
        f for f in latest_version_files if "-best.ckpt" in os.path.basename(f)
    ]
    if len(best_ckpt_files) == 1:
        return best_ckpt_files[0]
    if len(best_ckpt_files) > 1:
        # Multiple 'best' checkpoints found (e.g., from save_top_k parameter)
        raise ValueError(
            f"Multiple 'best' checkpoint files found in {latest_version_dir}. "
            f"Found {len(best_ckpt_files)} files marked as 'best': {best_ckpt_files}. "
            "Cannot automatically select from multiple 'best' checkpoints."
        )

    warnings.warn(
        "No 'best' checkpoint found, falling back to latest checkpoint.",
        stacklevel=2,
    )
    if len(latest_version_files) == 1:
        return latest_version_files[0]

    # Several checkpoints, none marked 'best': pick the highest step count, parsed from
    # the filename only. On ties the first file (in sorted order) wins.
    max_step = -1
    latest_ckpt = None
    for f in latest_version_files:
        match = re.search(r"step=(\d+)", os.path.basename(f))
        if match and int(match.group(1)) > max_step:
            max_step = int(match.group(1))
            latest_ckpt = f
    if latest_ckpt is None:
        # Could not determine which checkpoint to use
        raise ValueError(
            "Multiple checkpoint files found but cannot determine which "
            f"to use: {latest_version_files}. "
            "None are marked as 'best' and cannot parse step counts to determine latest. "
            "Please manually select the appropriate checkpoint."
        )
    return latest_ckpt
[docs]
@typechecked
def check_if_semi_supervised(losses_to_use: ListConfig | list | None = None) -> bool:
    """Decide from the ``losses_to_use`` config entry whether the model is semi-supervised.

    The model counts as semi-supervised unless the entry is effectively empty, i.e. it is
    ``None`` (null), an empty list, or a list holding a single empty string.

    Args:
        losses_to_use (ListConfig, list | None, optional): the cfg entry listing the
            semisupervised losses to use. Defaults to None.

    Returns:
        bool: True if the model is semi_supervised. False otherwise.
    """
    if losses_to_use is None or len(losses_to_use) == 0:
        return False
    only_blank_entry = len(losses_to_use) == 1 and losses_to_use[0] == ""
    return not only_blank_entry
[docs]
@typechecked
def get_keypoint_names(
    cfg: DictConfig | None = None,
    csv_file: str | None = None,
    header_rows: list | None = [0, 1, 2],
) -> list[str]:
    """Return the keypoint names, in column order.

    If ``csv_file`` points to an existing file, the names are parsed from its multi-row
    column header. Otherwise generic names ``bp_0, bp_1, ...`` are generated, one per
    pair of targets in ``cfg.data.num_targets``.

    Args:
        cfg: config. ``cfg.data.header_rows`` is read when ``header_rows`` is None, and
            ``cfg.data.num_targets`` is read when no csv file is available.
        csv_file: path to a labels csv file.
        header_rows: header row indices passed to ``pd.read_csv``. Supported values are
            ``[0, 1, 2]`` (names on the second header level, coordinates on the third)
            and ``[0, 1]`` / ``[1, 2]`` (names on the first level, coordinates on the
            second). If None, ``cfg.data.header_rows`` is used when present, otherwise
            ``[0, 1, 2]``.

    Returns:
        list of keypoint names.

    Raises:
        ValueError: if ``header_rows`` is not a supported value, or if no csv file is
            available and ``cfg`` is None.
    """
    # NOTE: the list default above is only rebound below, never mutated, so sharing it
    # across calls is harmless.
    if csv_file is not None and os.path.exists(csv_file):
        if header_rows is None:
            if cfg is not None and "header_rows" in cfg.data:
                header_rows = list(cfg.data.header_rows)
            else:
                # assume dlc format
                header_rows = [0, 1, 2]
        if header_rows in ([0, 1], [1, 2]):
            name_level, coord_level = 0, 1
        elif header_rows == [0, 1, 2]:
            name_level, coord_level = 1, 2
        else:
            raise ValueError(
                f"Unsupported header_rows {header_rows}; "
                "expected [0, 1, 2], [0, 1] or [1, 2]"
            )
        # Only the column structure is needed, so read just a few rows.
        csv_data = pd.read_csv(csv_file, header=header_rows, nrows=5)
        # Walk the columns instead of using `csv_data.columns.levels`, which returns the
        # labels sorted rather than in column order.
        return [col[name_level] for col in csv_data.columns if col[coord_level] == "x"]
    if cfg is None:
        raise ValueError(
            "Cannot determine keypoint names: csv_file does not exist and cfg is None"
        )
    return [f"bp_{n}" for n in range(cfg.data.num_targets // 2)]
# --------------------------------------------------------------------------------------
# Path handling functions for running toy dataset
# --------------------------------------------------------------------------------------
[docs]
@typechecked
def return_absolute_path(possibly_relative_path: str, n_dirs_back: int = 3) -> str:
    """Return absolute path from possibly relative path.

    Absolute paths are returned unchanged. A relative path is resolved against the
    directory ``n_dirs_back`` levels above the current working directory. If that
    directory is named ``multirun`` (hydra multirun), one more level is dropped.

    Args:
        possibly_relative_path: absolute path, or path relative to the base directory
            described above.
        n_dirs_back: how many directory levels to go up from the current working
            directory; 0 resolves relative to the working directory itself.

    Returns:
        the absolute path.

    Raises:
        OSError: if the resulting path does not exist.
    """
    if os.path.isabs(possibly_relative_path):
        # absolute path already; do nothing
        abs_path = possibly_relative_path
    else:
        # Walk up with dirname instead of re-joining split components, which keeps
        # Windows drive prefixes intact and stops at the filesystem root.
        base_dir = os.getcwd()
        for _ in range(n_dirs_back):
            base_dir = os.path.dirname(base_dir)
        if os.path.basename(base_dir) == "multirun":
            # hydra multirun, go one dir back
            base_dir = os.path.dirname(base_dir)
        abs_path = os.path.join(base_dir, possibly_relative_path)
    if not os.path.exists(abs_path):
        raise OSError(f"{abs_path} is not a valid path")
    return abs_path
[docs]
@typechecked
def return_absolute_data_paths(
    data_cfg: DictConfig, n_dirs_back: int = 3
) -> tuple[str, str]:
    """Resolve the data and video locations of a data config to absolute paths.

    The @hydra.main decorator switches the cwd when executing the decorated function
    (e.g. our train()), so we run from inside some /outputs/YYYY-MM-DD/HH-MM-SS folder
    and relative paths must be resolved from there.

    Args:
        data_cfg (DictConfig): data config; its ``data_dir`` and ``video_dir`` are read.
        n_dirs_back (int): forwarded to ``return_absolute_path`` when resolving
            ``data_dir``. It is the number of directory levels above the working
            directory that relative paths are resolved against.

    Returns:
        Tuple[str, str]: the data directory, and the video directory (or video file).
        A relative ``video_dir`` is interpreted relative to the data directory.
    """
    data_dir = return_absolute_path(data_cfg.data_dir, n_dirs_back=n_dirs_back)
    video_dir = (
        data_cfg.video_dir
        if os.path.isabs(data_cfg.video_dir)
        else os.path.join(data_dir, data_cfg.video_dir)
    )
    # Sanity checks: the data dir must exist; the video path may be a dir or one file.
    assert os.path.isdir(data_dir)
    assert os.path.isdir(video_dir) or os.path.isfile(video_dir)
    return data_dir, video_dir
[docs]
@typechecked
def get_videos_in_dir(
    video_dir: str, view_names: list[str] | None = None, return_mp4_only: bool = True
) -> list[str] | list[list[str]]:
    """Gather videos to process from a single directory.

    Args:
        video_dir: directory to search (not recursive).
        view_names: if given (and non-empty), group the videos by view. A file belongs to
            a view when its name without extension ends with the view name, or when it
            contains ``_<view>_``.
        return_mp4_only: if True only ``.mp4`` files are returned; otherwise ``.avi``
            and ``.mov`` files are included as well.

    Returns:
        sorted list of video paths, or, when ``view_names`` is given, one sorted list
        per view (outer list over views, inner lists over videos/sessions).

    Raises:
        RuntimeError: if the video names differ across views (expected naming
            convention is ``<vid_name>_<view_name>``).
        OSError: if no valid video file is found.
    """
    assert os.path.isdir(video_dir)
    # get all video files in directory, from allowed formats
    allowed_formats = ".mp4" if return_mp4_only else (".mp4", ".avi", ".mov")
    # sorted for a deterministic order (os.listdir order is arbitrary)
    video_names = sorted(f for f in os.listdir(video_dir) if f.endswith(allowed_formats))

    if not view_names:
        video_files = [os.path.join(video_dir, f) for f in video_names]
        if not video_files:
            raise OSError(f"Did not find any valid video files in {video_dir}")
        return video_files

    names_by_view = [
        [f for f in video_names if f.split(".")[-2].endswith(view) or f"_{view}_" in f]
        for view in view_names
    ]
    # A list of (possibly empty) per-view lists is never empty itself, so check the
    # contents.
    if not any(names_by_view):
        raise OSError(f"Did not find any valid video files in {video_dir}")

    # check to make sure we have the same set of videos for each view. Compare file
    # names, not full paths, so a view name appearing in video_dir cannot skew the split.
    vid_names = [
        {name.split(f"_{view}")[0] for name in names}
        for view, names in zip(view_names, names_by_view)
    ]
    if any(vids_view != vid_names[0] for vids_view in vid_names):
        raise RuntimeError(
            "Mismatched video names across views! "
            "Please check your videos are in the format "
            "<vid_name>_<view_name[0]>, <vid_name>_<view_name[1]>, etc., "
            "where the `view_name` variable is defined in the config file."
        )
    return [[os.path.join(video_dir, f) for f in names] for names in names_by_view]
[docs]
@typechecked
def check_video_paths(
    video_paths: list[str] | str,
    view_names: list[str] | None = None,
) -> list[str] | list[list[str]]:
    """Resolve ``video_paths`` to video files and check they are all mp4.

    Args:
        video_paths: a list of video files, a single video file, or a directory of
            videos.
        view_names: only used for a directory; forwarded to ``get_videos_in_dir`` to
            group the videos by view.

    Returns:
        list of video files; for a directory with ``view_names``, one list per view.

    Raises:
        ValueError: if ``video_paths`` is a string naming neither an existing file nor
            a directory.
        AssertionError: if any video file does not end in ".mp4".
    """
    if isinstance(video_paths, list):
        # presumably a list of files
        filenames = video_paths
    elif isinstance(video_paths, str) and os.path.isfile(video_paths):
        # single video file
        filenames = [video_paths]
    elif isinstance(video_paths, str) and os.path.isdir(video_paths):
        # directory of videos
        filenames = get_videos_in_dir(video_paths, view_names=view_names)
    else:
        raise ValueError(
            "`video_paths` must be a list of files, a single file, or a directory name"
        )
    for entry in filenames:
        # entries are lists (one per view) when get_videos_in_dir grouped by view
        group = [entry] if isinstance(entry, str) else entry
        for f in group:
            assert f.endswith(".mp4"), "video files must be mp4 format!"
    return filenames
def collect_video_files_by_view(
    video_files: list[Path], view_names: list[str]
) -> dict[str, Path]:
    """Given a list of video files, matches them to views based on their filenames.

    Filenames must contain their corresponding view's name, separated from the rest of
    the filename by some non-alphanumeric delimiter (or the start/end of the stem). For
    example, mouse_top_3.mp4 is allowed, but mousetop3.mp4 and mousetop_3.mp4 are not.

    Args:
        video_files: one video file per view.
        view_names: names of the views.

    Returns:
        dict mapping each view name to its video file.

    Raises:
        AssertionError: if the numbers of files and views differ.
        ValueError: if a view matches no file or several files, or if one file matches
            several views.
    """
    assert len(video_files) == len(view_names), (
        f"{len(video_files)} != {len(view_names)}"
    )
    video_files_by_view: dict[str, Path] = {}
    claimed_files: set[Path] = set()
    for view_name in view_names:
        # The view name must not be glued to a letter or digit on either side; both
        # lookarounds use a character class.
        pattern = re.compile(
            rf"(?<![0-9a-zA-Z]){re.escape(view_name)}(?![0-9a-zA-Z])"
        )
        # Search all the video_files for a match.
        for video_file in video_files:
            if not pattern.search(video_file.stem):
                continue
            if view_name in video_files_by_view:
                raise ValueError(
                    f"Multiple files match view {view_name}: "
                    f"{video_files_by_view[view_name]}, {video_file}"
                )
            if video_file in claimed_files:
                raise ValueError(f"File matches multiple views: {video_file}")
            video_files_by_view[view_name] = video_file
            claimed_files.add(video_file)
        # After the search if nothing was added to dict, there is a problem.
        if view_name not in video_files_by_view:
            raise ValueError(f"File not found for view: {view_name}")
    return video_files_by_view
[docs]
@typechecked
def get_context_img_paths(center_img_path: Path) -> list[Path]:
    """Given the path to a center image frame, return paths of 5 context frames
    (n-2, n-1, n, n+1, n+2).

    The frame index n is the first run of digits in the filename stem. Only that run is
    replaced when building the context paths, zero-padded to its original width.
    Negative indices are floored at 0, so the first frames repeat frame 0.

    Raises:
        AssertionError: if the filename stem contains no digits.
    """
    stem = center_img_path.stem
    match = re.search(r"(\d+)", stem)
    assert match is not None, (
        f"No frame index in filename, can't get context frames: {center_img_path.name}"
    )
    center_index_string = match.group()
    center_index = int(center_index_string)
    width = len(center_index_string)
    # Splice the new index into the matched span only. str.replace would also rewrite
    # any other occurrence of the same digits elsewhere in the stem (e.g. "img12_v12").
    prefix, suffix = stem[: match.start()], stem[match.end():]
    context_img_paths = []
    for index in range(center_index - 2, center_index + 3):  # n-2 .. n+2 inclusive
        # Negative indices are floored at 0; zfill keeps the leading zeros.
        index_string = str(max(index, 0)).zfill(width)
        context_img_paths.append(
            center_img_path.with_stem(f"{prefix}{index_string}{suffix}")
        )
    return context_img_paths
def fix_empty_first_row(df: pd.DataFrame) -> pd.DataFrame:
    """Restore a leading all-NaN row that `pd.read_csv` mistook for the index name.

    After a multiline header, pandas reads the next row as `df.index.name` when it looks
    like an index-name row. An all-NaN data row looks exactly like that, so pandas cannot
    tell the two apart and the data row is lost.
    Pandas gh issue: https://github.com/pandas-dev/pandas/issues/21995

    If `df.index.name` is set, a float64 all-NaN row labelled with that name is prepended
    and the resulting index has no name. Otherwise `df` itself is returned unchanged.
    """
    index_name = df.index.name
    if index_name is None:
        return df
    nan_row = dict.fromkeys(df.columns, np.nan)
    restored_row = pd.DataFrame(
        nan_row, index=[index_name], columns=df.columns, dtype="float64"
    )
    fixed = pd.concat([restored_row, df])
    assert fixed.index.name is None
    return fixed
def extract_view_name_from_video(
    video_filename: str, view_names: list[str]
) -> str | None:
    """Return the first entry of `view_names` that occurs in the filename's stem.

    Matching is a plain substring test on the stem (extension removed), checked in
    `view_names` order. Returns None when no view name occurs.
    """
    stem = Path(video_filename).stem
    return next((name for name in view_names if name in stem), None)
[docs]
def split_video_files_by_view(
    video_paths: list[Path],
    view_names: list[str],
) -> list[list[Path]]:
    """Group per-view video files into per-session lists.

    Each path gets a view from `extract_view_name_from_video`; paths with no view are
    ignored. Each remaining path gets a session from `extract_session_name_from_video`.
    Sessions are returned in order of first appearance, and only sessions with a video
    for every view are kept, e.g.
    `[[sess0_view0.mp4, sess0_view1.mp4, ...], [sess1_view0.mp4, sess1_view1.mp4, ...], ...]`

    Args:
        video_paths: List of paths to video files to split
        view_names: List of view names to find videos for

    Returns:
        One list per complete session, holding that session's videos ordered like
        `view_names`.
    """
    # session -> {view name -> video path}
    views_by_session = collections.defaultdict(dict)
    for path in video_paths:
        view = extract_view_name_from_video(path.name, view_names)
        if view is None:
            continue
        session = extract_session_name_from_video(path.name, view_names)
        views_by_session[session][view] = path

    return [
        [videos[view_name] for view_name in view_names]
        for videos in views_by_session.values()
        # sessions missing any view are skipped
        if all(view_name in videos for view_name in view_names)
    ]
[docs]
def find_video_files_for_views(
    video_dir: str, view_names: list[str]
) -> list[list[Path]]:
    """
    Search inside a folder to find a list of videos from different sessions and views, split them
    up and return a list of lists like
    `[[sess0_view0.mp4, sess0_view1.mp4, ...], [sess1_view0.mp4, sess1_view1.mp4, ...], ...]`

    Only ``*.mp4`` files directly inside ``video_dir`` are considered. They are sorted,
    so the session order of the result is deterministic.

    Args:
        video_dir: Directory containing video files
        view_names: List of view names to find videos for

    Returns:
        List for each session, each containing a sub-list with videos for each view for
        that session

    Raises:
        FileNotFoundError: if ``video_dir`` is not an existing directory, or contains no
            mp4 files.
    """
    video_dir_path = Path(video_dir)
    if not video_dir_path.is_dir():
        raise FileNotFoundError(f"Video directory not found: {video_dir}")
    # Sorted because glob order is filesystem-dependent; it determines session order.
    all_video_files = sorted(video_dir_path.glob("*.mp4"))
    if not all_video_files:
        raise FileNotFoundError(f"No video files found in {video_dir}")
    return split_video_files_by_view(all_video_files, view_names)