Skip to content

Commit

Permalink
Support config tags (#167)
Browse files Browse the repository at this point in the history
* Support tags by PathReprStorage

* Support tags by SQLiteReprStorage

* Update extract_features.py to support tags

* Add missing dependency

* Make config tag opaque

* Update extract_features.py sript

* Delete obsolete code

* Update repr storage tests

* Update generate_matches.py script

* Update template_matching.py script

* Update general_tests

* Add missing unit-test dependency

* Optimize module dependencies
  • Loading branch information
stepan-anokhin authored Oct 29, 2020
1 parent 8651922 commit f995086
Show file tree
Hide file tree
Showing 18 changed files with 406 additions and 231 deletions.
1 change: 1 addition & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- jupyter
- jupyterlab
- pip:
- lmdb
- image
- imageio
- moviepy
Expand Down
18 changes: 8 additions & 10 deletions extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import click

from db import Database
from db.utils import *
from winnow.feature_extraction import IntermediateCnnExtractor, FrameToVideoRepresentation, SimilarityModel, \
load_featurizer
from winnow.feature_extraction.model import default_model_path
from winnow.storage.db_result_storage import DBResultStorage
from winnow.storage.repr_storage import ReprStorage
from winnow.storage.repr_utils import bulk_read, bulk_write, path_resolver
from winnow.utils import scan_videos, create_video_list, scan_videos_from_txt, resolve_config
from winnow.storage.repr_utils import bulk_read, bulk_write
from winnow.utils import scan_videos, create_video_list, scan_videos_from_txt, resolve_config, reprkey_resolver

logging.getLogger().setLevel(logging.ERROR)
logging.getLogger("winnow").setLevel(logging.INFO)
Expand Down Expand Up @@ -43,11 +42,10 @@



def main(config,list_of_files,frame_sampling,save_frames):

def main(config, list_of_files, frame_sampling, save_frames):
config = resolve_config(config_path=config, frame_sampling=frame_sampling, save_frames=save_frames)
reps = ReprStorage(os.path.join(config.repr.directory))
storepath = path_resolver(source_root=config.sources.root)
reprkey = reprkey_resolver(config)

print('Searching for Dataset Video Files')

Expand All @@ -60,7 +58,7 @@ def main(config,list_of_files,frame_sampling,save_frames):

print('Number of files found: {}'.format(len(videos)))

remaining_videos_path = [path for path in videos if not reps.frame_level.exists(storepath(path), get_hash(path))]
remaining_videos_path = [path for path in videos if not reps.frame_level.exists(reprkey(path))]

print('There are {} videos left'.format(len(remaining_videos_path)))

Expand All @@ -71,7 +69,7 @@ def main(config,list_of_files,frame_sampling,save_frames):
if len(remaining_videos_path) > 0:
# Instantiates the extractor
model_path = default_model_path(config.proc.pretrained_model_local_path)
extractor = IntermediateCnnExtractor(video_src=VIDEOS_LIST, reprs=reps, storepath=storepath,
extractor = IntermediateCnnExtractor(video_src=VIDEOS_LIST, reprs=reps, reprkey=reprkey,
frame_sampling=config.proc.frame_sampling,
save_frames=config.proc.save_frames,
model=(load_featurizer(model_path)))
Expand All @@ -87,13 +85,13 @@ def main(config,list_of_files,frame_sampling,save_frames):
print('Extracting Signatures from Video representations')

sm = SimilarityModel()
signatures = sm.predict(bulk_read(reps.video_level)) # Get dict (path,hash) => signature
signatures = sm.predict(bulk_read(reps.video_level)) # Get {ReprKey => signature} dict

print('Saving Video Signatures on :{}'.format(reps.signature.directory))

if config.database.use:
# Convert dict to list of (path, sha256, signature) tuples
entries = [(path, sha256, sig) for (path, sha256), sig in signatures.items()]
entries = [(key.path, key.hash, sig) for key, sig in signatures.items()]

# Connect to database
database = Database(uri=config.database.uri)
Expand Down
12 changes: 7 additions & 5 deletions generate_matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ def main(config):
signatures_dict = sm.predict(bulk_read(reps.video_level))

# Unpack paths, hashes and signatures as separate np.arrays
path_hash_pairs, video_signatures = zip(*signatures_dict.items())
paths, hashes = map(np.array, zip(*path_hash_pairs))
repr_keys, video_signatures = zip(*signatures_dict.items())
paths = np.array([key.path for key in repr_keys])
hashes = np.array([key.hash for key in repr_keys])
video_signatures = np.array(video_signatures)


Expand Down Expand Up @@ -117,11 +118,12 @@ def main(config):
print('Filtering dark and/or short videos')

# Get original files for which we have both frames and frame-level features
path_hash_pairs = list(set(reps.video_level.list()))
paths, hashes = zip(*path_hash_pairs)
repr_keys = list(set(reps.video_level.list()))
paths = [key.path for key in repr_keys]
hashes = [key.hash for key in repr_keys]

print('Extracting additional information from video files')
brightness_estimation = np.array([get_brightness_estimation(reps, *path_hash) for path_hash in tqdm(path_hash_pairs)])
brightness_estimation = np.array([get_brightness_estimation(reps, key) for key in tqdm(repr_keys)])
print(brightness_estimation.shape)
metadata_df = pd.DataFrame({"fn": paths,
"sha256": hashes,
Expand Down
3 changes: 2 additions & 1 deletion requirements-winnow-unit-tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ sqlalchemy
pyyaml
requests
dataclasses
psycopg2
psycopg2
lmdb
38 changes: 19 additions & 19 deletions tests/general_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from winnow.feature_extraction import IntermediateCnnExtractor, FrameToVideoRepresentation, SimilarityModel
from winnow.storage.repr_storage import ReprStorage
from winnow.storage.repr_utils import path_resolver, bulk_read
from winnow.utils import scan_videos, create_video_list, get_hash, resolve_config
from winnow.storage.repr_utils import bulk_read
from winnow.utils import scan_videos, create_video_list, resolve_config, reprkey_resolver

NUMBER_OF_TEST_VIDEOS = 40

Expand Down Expand Up @@ -51,10 +51,10 @@ def videos():


@pytest.fixture(scope="module")
def dataset_path_hash_pairs(videos):
def repr_keys(videos):
"""(path_inside_storage,sha256) pairs for test dataset videos."""
storepath = path_resolver(source_root=DATASET_DIR)
return [(storepath(path), get_hash(path)) for path in videos]
reprkey = reprkey_resolver(cfg)
return [reprkey(path) for path in videos]


@pytest.fixture(scope="module")
Expand All @@ -67,9 +67,9 @@ def intermediate_cnn_results(videos, reprs):
Returns:
ReprStorage with populated with intermediate CNN results.
"""
storepath = path_resolver(source_root=DATASET_DIR)
reprkey = reprkey_resolver(cfg)
videos_list = create_video_list(videos, VIDEO_LIST_TXT)
extractor = IntermediateCnnExtractor(videos_list, reprs, storepath)
extractor = IntermediateCnnExtractor(video_src=videos_list, reprs=reprs, reprkey=reprkey)
extractor.start(batch_size=16, cores=4)
return reprs

Expand Down Expand Up @@ -103,8 +103,8 @@ def signatures(frame_to_video_results):
reprs = frame_to_video_results
sm = SimilarityModel()
signatures = sm.predict(bulk_read(reprs.video_level))
for (path, sha256), sig_value in signatures.items():
reprs.signature.write(path, sha256, sig_value)
for repr_key, sig_value in signatures.items():
reprs.signature.write(repr_key, sig_value)
return signatures


Expand All @@ -116,35 +116,35 @@ def test_video_extension_filter(videos):
assert not_videos == 0


def test_intermediate_cnn_extractor(intermediate_cnn_results, dataset_path_hash_pairs):
assert set(intermediate_cnn_results.frame_level.list()) == set(dataset_path_hash_pairs)
def test_intermediate_cnn_extractor(intermediate_cnn_results, repr_keys):
assert set(intermediate_cnn_results.frame_level.list()) == set(repr_keys)

frame_level_features = list(bulk_read(intermediate_cnn_results.frame_level).values())

shapes_correct = sum(features.shape[1] == 4096 for features in frame_level_features)

assert shapes_correct == len(dataset_path_hash_pairs)
assert shapes_correct == len(repr_keys)


def test_frame_to_video_converter(frame_to_video_results, dataset_path_hash_pairs):
assert set(frame_to_video_results.video_level.list()) == set(dataset_path_hash_pairs)
def test_frame_to_video_converter(frame_to_video_results, repr_keys):
assert set(frame_to_video_results.video_level.list()) == set(repr_keys)

video_level_features = np.array(list(bulk_read(frame_to_video_results.video_level).values()))

assert video_level_features.shape == (len(dataset_path_hash_pairs), 1, 4096)
assert video_level_features.shape == (len(repr_keys), 1, 4096)


def test_signatures_shape(signatures, dataset_path_hash_pairs):
assert set(signatures.keys()) == set(dataset_path_hash_pairs)
def test_signatures_shape(signatures, repr_keys):
assert set(signatures.keys()) == set(repr_keys)

signatures_array = np.array(list(signatures.values()))
assert signatures_array.shape == (NUMBER_OF_TEST_VIDEOS, 500)


@pytest.mark.usefixtures("signatures")
def test_saved_signatures(reprs, dataset_path_hash_pairs):
def test_saved_signatures(reprs, repr_keys):
signatures = bulk_read(reprs.signature)
assert set(signatures.keys()) == set(dataset_path_hash_pairs)
assert set(signatures.keys()) == set(repr_keys)

signatures_array = np.array(list(signatures.values()))
assert signatures_array.shape == (NUMBER_OF_TEST_VIDEOS, 500)
93 changes: 58 additions & 35 deletions tests/winnow/storage/test_repr_storage.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import tempfile
from itertools import islice
from uuid import uuid4 as uuid

import numpy as np
import pytest
from dataclasses import asdict

from winnow.storage.path_repr_storage import PathReprStorage
from winnow.storage.lmdb_repr_storage import LMDBReprStorage
from winnow.storage.repr_key import ReprKey
from winnow.storage.repr_utils import bulk_read, bulk_write
from winnow.storage.sqlite_repr_storage import SQLiteReprStorage

Expand All @@ -27,7 +30,24 @@ def store(request):

# Shortcut for pytest parametrize decorator.
# Decorated test will be executed for all existing representation store types.
use_store = pytest.mark.parametrize('store', [PathReprStorage, SQLiteReprStorage], indirect=True)
use_store = pytest.mark.parametrize('store', [LMDBReprStorage, SQLiteReprStorage], indirect=True)


def make_key():
"""Make some repr storage key."""
unique = uuid()
return ReprKey(path=f"some/path-{unique}", hash=f"some-hash-{unique}", tag=f"some-tag-{unique}")


def make_entry():
"""Make some repr storage entry."""
return make_key(), np.array([str(uuid())])


def copy(key, **kwargs):
args = asdict(key)
args.update(kwargs)
return ReprKey(**args)


@use_store
Expand All @@ -37,76 +57,79 @@ def test_empty(store):

@use_store
def test_exists(store):
path, sha256, value = "some/path", "some-hash", np.array(["some-value"])
key, value = make_entry()

# Doesn't exist before write
assert not store.exists(path, sha256)
assert not store.exists(key)

# Exists when written
store.write(path, sha256, value)
assert store.exists(path, sha256)
store.write(key, value)
assert store.exists(key)

# Doesn't exist after deletion
store.delete(path, sha256)
assert not store.exists(path, sha256)
store.delete(key.path)
assert not store.exists(key)


@use_store
def test_read_write(store):
path, sha256, value, another_value = "some/path", "some-hash", np.array(["some-value"]), np.array(["another-value"])
key, value, another_value = make_key(), np.array(["some-value"]), np.array(["another-value"])

store.write(path, sha256, value)
assert store.read(path, sha256) == value
store.write(key, value)
assert store.read(key) == value

store.write(path, sha256, another_value)
assert store.read(path, sha256) == another_value
store.write(key, another_value)
assert store.read(key) == another_value

# Repeat write
store.write(path, sha256, another_value)
assert store.read(path, sha256) == another_value
store.write(key, another_value)
assert store.read(key) == another_value

# Repeat read
assert store.read(path, sha256) == another_value
assert store.read(key) == another_value


@use_store
def test_read_write_multiple(store):
path_1, sha256_1, value_1 = "some/path", "some-hash", np.array(["some-value"])
path_2, sha256_2, value_2 = "other/path", "other-hash", np.array(["other-value"])
key_1, value_1 = make_entry()
key_2, value_2 = make_entry()

store.write(path_1, sha256_1, value_1)
store.write(path_2, sha256_2, value_2)
store.write(key_1, value_1)
store.write(key_2, value_2)

assert store.exists(path_1, sha256_1)
assert store.exists(path_2, sha256_2)
assert store.read(path_1, sha256_1) == value_1
assert store.read(path_2, sha256_2) == value_2
assert store.exists(key_1)
assert store.exists(key_2)
assert store.read(key_1) == value_1
assert store.read(key_2) == value_2

# Mix up path and hash
assert not store.exists(path_1, sha256_2)
assert not store.exists(path_2, sha256_1)
unknown = make_key()
assert not store.exists(unknown)
assert not store.exists(copy(key_1, hash=key_2.hash))
assert not store.exists(copy(key_1, tag=key_2.tag))
assert not store.exists(copy(key_2, hash=key_1.hash))
assert not store.exists(copy(key_2, tag=key_1.tag))


@use_store
def test_list(store):
assert list(store.list()) == []

path_1, sha256_1 = "some/path", "some-hash"
path_2, sha256_2 = "other/path", "other-hash"
key_1, key_2 = make_key(), make_key()

store.write(path_1, sha256_1, np.array(["some-value"]))
assert set(store.list()) == {(path_1, sha256_1)}
store.write(key_1, np.array(["some-value"]))
assert set(store.list()) == {key_1}

store.write(path_2, sha256_2, np.array(["some-value"]))
assert set(store.list()) == {(path_1, sha256_1), (path_2, sha256_2)}
store.write(key_2, np.array(["some-value"]))
assert set(store.list()) == {key_1, key_2}

store.delete(path_1, sha256_1)
assert set(store.list()) == {(path_2, sha256_2)}
store.delete(key_1.path)
assert set(store.list()) == {key_2}


@use_store
def test_bulk_read_write(store):
data_as_dict = {(f"some/path{i}", f"some-hash{i}"): np.array([f"some-value{i}"]) for i in range(100)}
data_as_dict = dict(make_entry() for _ in range(100))

bulk_write(store, data_as_dict)
assert bulk_read(store) == data_as_dict
Expand Down
Loading

0 comments on commit f995086

Please sign in to comment.