Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factor out get_cached_array_path #202

Merged
merged 9 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions tests/test_npy_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from pathlib import Path
import yaml

from zamba.data.video import (
VideoLoaderConfig,
npy_cache,
get_cached_array_path,
load_video_frames,
)

# Baseline VideoLoaderConfig settings shared by the tests in this module.
# NOTE: `cache_dir` must appear only once. YAML loaders that tolerate duplicate
# mapping keys silently keep the last value, which hides the earlier entry.
# `cache_dir` and `cleanup_cache` are excluded from the cache hash; every other
# value here feeds the expected hashes asserted below, so changing one means
# regenerating those hashes.
config_yaml = """
crop_bottom_pixels: 50
early_bias: false
ensure_total_frames: true
evenly_sample_total_frames: false
fps: 4.0
frame_indices: null
frame_selection_height: null
frame_selection_width: null
i_frames: false
megadetector_lite_config:
  confidence: 0.25
  fill_mode: score_sorted
  image_height: 640
  image_width: 640
  n_frames: 16
  nms_threshold: 0.45
  seed: 55
  sort_by_time: true
model_input_height: 240
model_input_width: 426
pix_fmt: rgb24
scene_threshold: null
total_frames: 16
cleanup_cache: false
cache_dir: data/cache
"""


def test_get_cached_array_path():
    """Check which config changes move the cached-array path and which do not.

    The expected layout is ``cache_dir / <config hash> / <video path>.npy``.
    The hash must ignore `cache_dir` and `cleanup_cache` and react to any
    other setting.
    """

    def make_config(**overrides):
        # re-parse the YAML for every case so no state leaks between cases
        settings = {**yaml.safe_load(config_yaml), **overrides}
        return VideoLoaderConfig(**settings)

    base_config = make_config()

    # NOTE: VideoLoaderConfig validation rewrites some fields, so the
    # validated config does not round-trip to the raw YAML dictionary

    # wrapping load_video_frames in the cache decorator must still yield a function
    decorator = npy_cache(
        cache_path=base_config.cache_dir, cleanup=base_config.cleanup_cache
    )
    cached_load_video_frames = decorator(load_video_frames)
    assert isinstance(cached_load_video_frames, type(load_video_frames))

    vid_path_str = "data/raw/noemie/Taï_cam197_683044_652175_20161223/01090065.AVI"
    vid_path = Path(vid_path_str)
    relative_npy = vid_path.with_suffix(".npy")

    baseline_hash = "2d1fee2b1e1f78d06aa08bdea88e7661f927bd81"
    baseline_path = base_config.cache_dir / baseline_hash / relative_npy

    # the video path is accepted both as a string and as a Path
    assert get_cached_array_path(vid_path_str, base_config) == baseline_path
    assert get_cached_array_path(vid_path, base_config) == baseline_path

    # supplying cache_dir as a Path instead of a string gives the same result
    raw_cache_dir = yaml.safe_load(config_yaml)["cache_dir"]
    path_dir_config = make_config(cache_dir=Path(raw_cache_dir))
    assert get_cached_array_path(vid_path, path_dir_config) == baseline_path

    # toggling cleanup_cache must not affect the hash
    cleanup_config = make_config(cleanup_cache=True)
    assert get_cached_array_path(vid_path, cleanup_config) == baseline_path

    # a different cache_dir moves the path but keeps the hash
    moved_config = make_config(cache_dir="something/else")
    moved_path = moved_config.cache_dir / baseline_hash / relative_npy
    assert get_cached_array_path(vid_path, moved_config) == moved_path

    # any other setting changes the hash while cache_dir stays the same
    fewer_frames_config = make_config(total_frames=8)
    fewer_frames_hash = "9becb6d6dfe6b9970afe05af06ef49af4881bd73"
    rehashed_path = fewer_frames_config.cache_dir / fewer_frames_hash / relative_npy
    assert get_cached_array_path(vid_path, fewer_frames_config) == rehashed_path
57 changes: 39 additions & 18 deletions zamba/data/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,39 @@ def validate_total_frames(cls, values):
return values


def get_cached_array_path(vid_path, config):
    """Get the path to where the cached array would be, if it exists.

    The result has the form ``<cache_dir>/<config hash>/<video path>.npy``.
    The hash is the SHA-1 of the stringified config with `cleanup_cache` and
    `cache_dir` removed, so those two settings never cause a cache miss.

    vid_path: string path to the video, or Path
    config: VideoLoaderConfig

    returns: Path object to the cached data

    raises: AssertionError if `config` is not a VideoLoaderConfig
    """
    assert isinstance(config, VideoLoaderConfig)

    # don't include `cleanup_cache` or `cache_dir` in the hashed config
    # NOTE: sorting the keys avoids a cache miss if we see the same config in a different order;
    # might not be necessary with a VideoLoaderConfig
    config_dict = config.dict()
    keys = config_dict.keys() - {"cleanup_cache", "cache_dir"}
    hashed_part = {k: config_dict[k] for k in sorted(keys)}

    # hash config for inclusion in path
    # NOTE: the hash input is exactly str(hashed_part); changing how it is
    # built invalidates every existing cache entry
    hash_str = hashlib.sha1(str(hashed_part).encode("utf-8")).hexdigest()

    # `opt(lazy=True)` only defers work when the values are passed as
    # callables; an f-string would be rendered eagerly (including the dict
    # repr) even when debug logging is disabled
    logger.opt(lazy=True).debug(
        "Generated hash {hash_str} from {hashed_part}",
        hash_str=lambda: hash_str,
        hashed_part=lambda: str(hashed_part),
    )

    # strip leading "/" in absolute path so the video path can be joined
    # underneath the cache directory instead of replacing it
    vid_path = AnyPath(str(vid_path).lstrip("/"))

    # if the video is in S3, drop the prefix and bucket name
    if isinstance(vid_path, S3Path):
        vid_path = AnyPath(vid_path.key)

    cache_dir = config.cache_dir
    npy_path = AnyPath(cache_dir) / hash_str / vid_path.with_suffix(".npy")
    return npy_path


class npy_cache:
def __init__(self, cache_path: Optional[Path] = None, cleanup: bool = False):
self.cache_path = cache_path
Expand All @@ -337,28 +370,16 @@ def _wrapped(*args, **kwargs):
except Exception:
vid_path = args[0]
try:
config = kwargs["config"].dict()
config = kwargs["config"]
except Exception:
config = kwargs

# don't include cleanup in the hashed config
config.pop("cleanup_cache")

# hash config for inclusion in filename
hash_str = hashlib.sha1(str(config).encode("utf-8")).hexdigest()
logger.opt(lazy=True).debug(
"Generated hash {hash_str} from {config}",
hash_str=lambda: hash_str,
config=lambda: str(config),
)
config = VideoLoaderConfig(**kwargs)

# strip leading "/" in absolute path
vid_path = AnyPath(str(vid_path).lstrip("/"))
# NOTE: what should we do if this assert fails?
assert config.cache_dir == self.cache_path

if isinstance(vid_path, S3Path):
vid_path = AnyPath(vid_path.key)
# get the path for the cached data
npy_path = get_cached_array_path(vid_path, config)

npy_path = self.cache_path / hash_str / vid_path.with_suffix(".npy")
# make parent directories since we're using absolute paths
npy_path.parent.mkdir(parents=True, exist_ok=True)

Expand Down