From 9e81ad9883c5d63242714899daa9a46a5d04711d Mon Sep 17 00:00:00 2001 From: Stepan Anokhin Date: Wed, 30 Mar 2022 16:01:35 +0700 Subject: [PATCH 01/11] Delete legacy code --- cli/cli/handlers/finder.py | 33 ++-- cli/cli/handlers/pipeline.py | 35 ++-- process_video_url.py | 3 - winnow/pipeline/extract_exif.py | 65 -------- .../pipeline/extract_video_level_features.py | 64 -------- winnow/pipeline/extract_video_signatures.py | 65 -------- winnow/pipeline/find_frame.py | 60 ------- winnow/pipeline/generate_local_matches.py | 151 ------------------ winnow/pipeline/generate_remote_matches.py | 81 ---------- winnow/pipeline/match_templates.py | 124 -------------- winnow/pipeline/prepare_text_search.py | 63 -------- winnow/pipeline/process_urls.py | 37 ----- winnow/pipeline/pull_fingerprints.py | 36 ----- winnow/pipeline/push_fingerprints.py | 36 ----- winnow/pipeline/store_database_signatures.py | 64 -------- 15 files changed, 30 insertions(+), 887 deletions(-) delete mode 100644 winnow/pipeline/extract_exif.py delete mode 100644 winnow/pipeline/extract_video_level_features.py delete mode 100644 winnow/pipeline/extract_video_signatures.py delete mode 100644 winnow/pipeline/find_frame.py delete mode 100644 winnow/pipeline/generate_local_matches.py delete mode 100644 winnow/pipeline/generate_remote_matches.py delete mode 100644 winnow/pipeline/match_templates.py delete mode 100644 winnow/pipeline/prepare_text_search.py delete mode 100644 winnow/pipeline/process_urls.py delete mode 100644 winnow/pipeline/pull_fingerprints.py delete mode 100644 winnow/pipeline/push_fingerprints.py delete mode 100644 winnow/pipeline/store_database_signatures.py diff --git a/cli/cli/handlers/finder.py b/cli/cli/handlers/finder.py index 2790410e..1ef4aa7f 100644 --- a/cli/cli/handlers/finder.py +++ b/cli/cli/handlers/finder.py @@ -1,7 +1,6 @@ from typing import Optional from cli.handlers.errors import handle_errors -from winnow.utils.logging import configure_logging_cli class FinderCli: @@ -13,30 +12,22 @@ def __init__(self, pipeline): @handle_errors def local_matches(self): """Find matches between local videos.""" - from winnow.pipeline.generate_local_matches import generate_local_matches - from winnow.utils.files import scan_videos + import luigi - config = self._pipeline.config - configure_logging_cli(config.logging) + from winnow.pipeline.luigi.matches import MatchesReportTask - videos = scan_videos(config.sources.root, "**", extensions=config.sources.extensions) - generate_local_matches(files=videos, pipeline=self._pipeline) + luigi.build([MatchesReportTask(config=self._pipeline.config)], local_scheduler=True, workers=1) - def remote_matches(self, repo: Optional[str] = None, contributor: Optional[str] = None): + def remote_matches(self, repo: Optional[str] = None): """Find matches between local files and remote fingerprints.""" - from winnow.pipeline.generate_remote_matches import generate_remote_matches + import logging.config + import luigi - config = self._pipeline.config - configure_logging_cli(config.logging) + from winnow.pipeline.luigi.matches import RemoteMatchesTask - if repo is not None: - repo = str(repo) - - if contributor is not None: - contributor = str(contributor) - - generate_remote_matches( - pipeline=self._pipeline, - repository_name=repo, - contributor_name=contributor, + logging.config.fileConfig("./logging.conf") + luigi.build( + [RemoteMatchesTask(config=self._pipeline.config, repository_name=repo)], + local_scheduler=True, + workers=1, ) diff --git a/cli/cli/handlers/pipeline.py 
b/cli/cli/handlers/pipeline.py index 7ca68003..f2206529 100644 --- a/cli/cli/handlers/pipeline.py +++ b/cli/cli/handlers/pipeline.py @@ -1,6 +1,3 @@ -import os - - class PipelineCli: """Process video files.""" @@ -9,20 +6,24 @@ def __init__(self, config): def all(self): """Process all video files.""" - from winnow.utils.logging import configure_logging_cli - from winnow.pipeline.detect_scenes import detect_scenes - from winnow.pipeline.generate_local_matches import generate_local_matches - from winnow.utils.files import scan_videos - from winnow.pipeline.extract_exif import extract_exif - from winnow.pipeline.pipeline_context import PipelineContext - configure_logging_cli(self._config.logging) + import luigi - # Resolve list of video files from the directory - absolute_root = os.path.abspath(self._config.sources.root) - videos = scan_videos(absolute_root, "**", extensions=self._config.sources.extensions) + from winnow.pipeline.luigi.exif import ExifTask + from winnow.pipeline.luigi.signatures import ( + SignaturesTask, + DBSignaturesTask, + ) + from winnow.pipeline.luigi.matches import MatchesReportTask, DBMatchesTask - pipeline_context = PipelineContext(self._config) - generate_local_matches(files=videos, pipeline=pipeline_context) - detect_scenes(files=videos, pipeline=pipeline_context) - extract_exif(videos, pipeline=pipeline_context) + luigi.build( + [ + ExifTask(config=self._config), + SignaturesTask(config=self._config), + DBSignaturesTask(config=self._config), + MatchesReportTask(config=self._config), + DBMatchesTask(config=self._config), + ], + local_scheduler=True, + workers=1, + ) diff --git a/process_video_url.py b/process_video_url.py index e20a880f..c832e104 100644 --- a/process_video_url.py +++ b/process_video_url.py @@ -5,10 +5,7 @@ import luigi from winnow.pipeline.luigi.download import DownloadFilesTask -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.process_urls import process_urls from winnow.utils.config import resolve_config -from winnow.utils.logging import configure_logging_cli @click.command() diff --git a/winnow/pipeline/extract_exif.py b/winnow/pipeline/extract_exif.py deleted file mode 100644 index 29745f45..00000000 --- a/winnow/pipeline/extract_exif.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging -from os.path import join -from typing import Iterable, Optional - -from db.schema import Files -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor -from winnow.storage.db_result_storage import DBResultStorage -from winnow.storage.repr_utils import path_resolver -from winnow.utils.files import scan_videos, get_hash -from winnow.utils.metadata_extraction import ( - extract_from_list_of_videos, - convert_to_df, - parse_and_filter_metadata_df, -) - - -def extract_exif( - videos: Optional[Iterable[str]], - pipeline: PipelineContext, - progress_monitor=ProgressMonitor.NULL, -): - """Extract EXIF metadata from video files.""" - - logger = logging.getLogger(__name__) - config = pipeline.config - storepath = path_resolver(config.sources.root) - - if videos is not None: - hashes = [get_hash(video, config.repr.hash_mode) for video in videos] - elif config.database.use: - with pipeline.database.session_scope() as session: - video_records = session.query(Files).filter(Files.contributor == None).yield_per(10 ** 4) # noqa: E711 - path_hash_pairs = [(join(config.sources.root, record.file_path), record.sha256) for record in video_records] - videos, hashes = 
zip(*path_hash_pairs) - else: - videos = scan_videos(config.sources.root, "**", extensions=config.sources.extensions) - hashes = [get_hash(video, config.repr.hash_mode) for video in videos] - - assert len(videos) > 0, "No videos found" - - logger.info(f"{len(videos)} videos found") - - metadata = extract_from_list_of_videos(videos) - - df = convert_to_df(metadata) - - df_parsed = parse_and_filter_metadata_df(df, metadata) - - assert len(metadata) == len(df_parsed) - - if config.save_files: - - EXIF_REPORT_PATH = join(config.repr.directory, "exif_metadata.csv") - - df_parsed.to_csv(EXIF_REPORT_PATH) - - logger.info(f"Exif Metadata report exported to:{EXIF_REPORT_PATH}") - - if config.database.use: - result_store = DBResultStorage(pipeline.database) - exif_entries = zip(map(storepath, videos), hashes, df_parsed.to_dict("records")) - result_store.add_exifs(exif_entries) - - progress_monitor.complete() diff --git a/winnow/pipeline/extract_video_level_features.py b/winnow/pipeline/extract_video_level_features.py deleted file mode 100644 index adb95971..00000000 --- a/winnow/pipeline/extract_video_level_features.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging -from typing import Collection - -from winnow.feature_extraction.loading_utils import global_vector -from winnow.pipeline.extract_frame_level_features import ( - extract_frame_level_features, - frame_features_exist, -) -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor - -# Default module logger -logger = logging.getLogger(__name__) - - -def extract_video_level_features(files: Collection[str], pipeline: PipelineContext, progress=ProgressMonitor.NULL): - """Extract video-level features from the dataset videos.""" - - files = tuple(files) - remaining_video_paths = [*missing_video_features(files, pipeline)] - - # Ensure dependencies are satisfied - if not frame_features_exist(remaining_video_paths, pipeline): - extract_frame_level_features(remaining_video_paths, pipeline, progress=progress.subtask(0.9)) - progress = progress.subtask(0.1) - - # Skip step if required results already exist - if not remaining_video_paths: - logger.info("All required video-level features already exist. Skipping...") - progress.complete() - return - - # Do convert frame-level features into video-level features. 
- logger.info("Starting video-level feature extraction for %s of %s files", len(remaining_video_paths), len(files)) - frame_to_global(remaining_video_paths, pipeline, progress) - logger.info("Done video-level feature extraction.") - - -def missing_video_features(files, pipeline: PipelineContext): - """Get file paths with missing video-level features.""" - video_features = pipeline.repr_storage.video_level - for i, file_path in enumerate(files): - if not video_features.exists(pipeline.filekey(file_path)): - yield file_path - - -def video_features_exist(files, pipeline: PipelineContext): - """Check if all required video-level features do exist.""" - return not any(missing_video_features(files, pipeline)) - - -def frame_to_global(files, pipeline: PipelineContext, progress=ProgressMonitor.NULL): - """Calculate and save video-level feature vectors based on frame-level representation.""" - progress.scale(len(files)) - for key in map(pipeline.filekey, files): - try: - frame_features = pipeline.repr_storage.frame_level.read(key) - video_representation = global_vector(frame_features) - pipeline.repr_storage.video_level.write(key, video_representation) - except Exception: - logger.exception("Error computing video-level features for file: %s", key) - finally: - progress.increase(1) - progress.complete() diff --git a/winnow/pipeline/extract_video_signatures.py b/winnow/pipeline/extract_video_signatures.py deleted file mode 100644 index c99bfb22..00000000 --- a/winnow/pipeline/extract_video_signatures.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging -from typing import Collection, Dict - -from winnow.feature_extraction import SimilarityModel -from winnow.pipeline.extract_video_level_features import video_features_exist, extract_video_level_features -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor -from winnow.storage.file_key import FileKey -from winnow.storage.repr_utils import bulk_read, bulk_write - -# Default module logger -logger = logging.getLogger(__name__) - - -def extract_video_signatures( - files: Collection[str], pipeline: PipelineContext, hashes=None, progress=ProgressMonitor.NULL -): - """Calculate and save signatures for the given files to repr-storage.""" - - files = tuple(files) - - remaining_video_paths = list(missing_video_signatures(files, pipeline)) - - # Ensure dependencies are satisfied - - if not video_features_exist(remaining_video_paths, pipeline): - extract_video_level_features(remaining_video_paths, pipeline, progress=progress.subtask(0.9)) - progress = progress.subtask(0.1) - # Skip step if required results already exist - if not remaining_video_paths: - logger.info("Representation storage contains all required signatures. 
Skipping...") - progress.complete() - return - - # Do calculate signatures - logger.info("Starting signature extraction for %s of %s files", len(remaining_video_paths), len(files)) - signatures = extract_signatures(remaining_video_paths, pipeline) - bulk_write(pipeline.repr_storage.signature, signatures) - - logger.info("Done signature extraction.") - progress.complete() - - -def missing_video_signatures(files, pipeline: PipelineContext): - """Get file paths with missing signatures.""" - signatures = pipeline.repr_storage.signature - - for i, file_path in enumerate(files): - if not signatures.exists(pipeline.filekey(file_path)): - yield file_path - - -def video_signatures_exist(files, pipeline: PipelineContext): - """Check if all required signatures do exist.""" - - return not any(missing_video_signatures(files, pipeline)) - - -def extract_signatures(files, pipeline: PipelineContext) -> Dict[FileKey, Collection[float]]: - """Do extract signatures for the given video-files.""" - similarity_model = SimilarityModel() - file_keys = [pipeline.filekey(file) for i, file in enumerate(files)] - - video_features = bulk_read(pipeline.repr_storage.video_level, select=file_keys) - return similarity_model.predict(video_features) diff --git a/winnow/pipeline/find_frame.py b/winnow/pipeline/find_frame.py deleted file mode 100644 index 6504b1f6..00000000 --- a/winnow/pipeline/find_frame.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -from typing import Collection - -from winnow.pipeline.extract_frame_level_features import frame_features_exist, extract_frame_level_features -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor -from winnow.search_engine import SearchEngine, Template -from winnow.search_engine.black_list import BlackList -from winnow.search_engine.model import Frame -from winnow.utils.files import get_hash -from winnow.config import Config - -# Default module logger -logger = logging.getLogger(__name__) - - -def find_frame(frame: Frame, files: Collection[str], pipeline: PipelineContext, progress=ProgressMonitor.NULL): - """Find frame among other videos.""" - - config = pipeline.config - - # We don't check for pre-existing templates so far... - # So we always perform search for all videos. 
- remaining_files = tuple(files) - - # Ensure dependencies are satisfied - if not frame_features_exist(remaining_files, pipeline): - extract_frame_level_features(remaining_files, pipeline, progress=progress.subtask(0.8)) - progress = progress.subtask(0.2) - - template = pipeline.template_loader.load_template_from_frame(frame) - logger.info("Loaded temporary template: %s", template.name) - - black_list = make_black_list(template, frame, config) - logger.info("Frame source file is excluded from the search scope.") - - se = SearchEngine(frame_features=pipeline.repr_storage.frame_level, black_list=black_list) - template_matches = se.create_annotation_report( - templates=[template], - threshold=config.templates.distance, - frame_sampling=config.proc.frame_sampling, - distance_min=config.templates.distance_min, - ) - - tm_entries = template_matches[["path", "hash"]] - tm_entries["template_matches"] = template_matches.drop(columns=["path", "hash"]).to_dict("records") - - logger.info("Found %s frame matches", len(tm_entries)) - progress.complete() - - return tm_entries - - -def make_black_list(template: Template, frame: Frame, config: Config) -> BlackList: - """Exclude the frame source from the template scope.""" - black_list = BlackList() - black_list.exclude_file( - template_name=template.name, file_path=frame.path, file_hash=get_hash(frame.path, config.repr.hash_mode) - ) - return black_list diff --git a/winnow/pipeline/generate_local_matches.py b/winnow/pipeline/generate_local_matches.py deleted file mode 100644 index 94b3188c..00000000 --- a/winnow/pipeline/generate_local_matches.py +++ /dev/null @@ -1,151 +0,0 @@ -import logging -import os -from time import time -from typing import Collection, Dict, Iterable, Set - -import pandas as pd -from dataclasses import asdict, replace -from tqdm import tqdm - -from winnow.duplicate_detection.neighbors import NeighborMatcher, DetectedMatch -from winnow.pipeline.extract_video_level_features import video_features_exist, extract_video_level_features -from winnow.pipeline.extract_video_signatures import video_signatures_exist, extract_video_signatures -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor -from winnow.pipeline.store_database_signatures import database_signatures_exist, store_database_signatures -from winnow.storage.file_key import FileKey -from winnow.storage.repr_utils import bulk_read -from winnow.utils.brightness import get_brightness_estimation -from winnow.utils.files import get_hash -from winnow.utils.neighbors import as_vectors - -# Default module logger -logger = logging.getLogger(__name__) - - -def generate_local_matches( - files: Collection[str], pipeline: PipelineContext, hashes=None, progress=ProgressMonitor.NULL -): - """Find matches between video files.""" - - files = tuple(files) - config = pipeline.config - if hashes is None: - hashes = [get_hash(file, config.repr.hash_mode) for file in files] - - # There is no way to check if matches are already generated. - # Hence we must always attempt to generate matches. 
- - # Ensure dependencies are satisfied - if not video_features_exist(files, pipeline) and config.proc.filter_dark_videos: - extract_video_level_features(files, pipeline, progress=progress.subtask(0.9)) - progress = progress.subtask(0.1) - if not video_signatures_exist(files, pipeline): - extract_video_signatures(files, pipeline, progress=progress.subtask(0.7)) - progress = progress.subtask(0.3) - if not database_signatures_exist(files, pipeline): - store_database_signatures(files, pipeline, progress=progress.subtask(0.2)) - progress = progress.subtask(0.8) - - logger.info("Starting match detection for %s files", len(files)) - - # Load signatures - all_signatures = bulk_read(pipeline.repr_storage.signature) - req_signatures = bulk_read(pipeline.repr_storage.signature, select=map(pipeline.filekey, files)) - - # Do find matches - start_time = time() - neighbor_matcher = NeighborMatcher(haystack=as_vectors(all_signatures)) - matches = neighbor_matcher.find_matches(needles=as_vectors(req_signatures), max_distance=config.proc.match_distance) - logger.info(f"Match detection took {time() - start_time:.3f} seconds") - progress.increase(amount=0.5) - - # Save unfiltered report - unfiltered_report_name = f"matches_at_{config.proc.match_distance}_distance.csv" - unfiltered_report_path = os.path.join(config.repr.directory, unfiltered_report_name) - logger.info("Saving unfiltered report to %s", unfiltered_report_path) - _save_matches_csv(matches, unfiltered_report_path) - - # Filter dark videos - if config.proc.filter_dark_videos: - logger.info("Filtering dark and/or short videos") - - video_features = pipeline.repr_storage.video_level - file_keys = tuple(map(pipeline.filekey, files)) - brightness = {key: get_brightness_estimation(video_features.read(key)) for key in tqdm(file_keys)} - - threshold = config.proc.filter_dark_videos_thr - metadata = {key: _metadata(gray_max, threshold) for key, gray_max in brightness.items()} - - discarded = {key for key, meta in metadata.items() if meta["flagged"]} - matches = list(_reject(matches, discarded)) - - if config.database.use: - result_storage = pipeline.result_storage - result_storage.add_metadata((key.path, key.hash, meta) for key, meta in metadata.items()) - - if config.save_files: - filtered_report_name = f"matches_at_{config.proc.match_distance}_distance_filtered.csv" - filtered_report_path = os.path.join(config.repr.directory, filtered_report_name) - logger.info("Saving Filtered Matches report to %s", filtered_report_path) - _save_matches_csv(matches, filtered_report_path) - - metadata_report_path = os.path.join(config.repr.directory, "metadata_signatures.csv") - logger.info("Saving metadata to %s", metadata_report_path) - _save_metadata_csv(metadata, metadata_report_path) - - if config.database.use: - result_storage = pipeline.result_storage - result_storage.add_matches(_entry(match) for match in matches) - - progress.complete() - - -def _metadata(gray_max, threshold) -> Dict: - """Create metadata dict.""" - video_dark_flag = gray_max < threshold - return {"gray_max": gray_max, "video_dark_flag": video_dark_flag, "flagged": video_dark_flag} - - -def _reject(detected_matches: Iterable[DetectedMatch], discarded: Set[FileKey]): - """Reject discarded matches.""" - for match in detected_matches: - if match.needle_key not in discarded and match.haystack_key not in discarded: - yield _order_match(match) - - -def _order_match(match: DetectedMatch): - """Order match and query file keys""" - if match.haystack_key.path <= match.needle_key.path: - return 
replace(match, haystack_key=match.needle_key, needle_key=match.haystack_key) - return match - - -def _entry(detected_match: DetectedMatch): - """Flatten (query_key, match_key, dist) match entry.""" - query, match = detected_match.needle_key, detected_match.haystack_key - return query.path, query.hash, match.path, match.hash, detected_match.distance - - -def _save_matches_csv(matches: Iterable[DetectedMatch], path): - """Save matches to csv file.""" - dataframe = pd.DataFrame( - tuple(_entry(match) for match in matches), - columns=[ - "query_video", - "query_sha256", - "match_video", - "match_sha256", - "distance", - ], - ) - dataframe.to_csv(path) - - -def _save_metadata_csv(metadata: Dict[FileKey, Dict], path): - """Save metadata to csv file.""" - keys, metas = map(tuple, zip(*metadata.items())) - keys = tuple(map(asdict, keys)) - dataframe = pd.DataFrame(metas) - dataframe = dataframe.merge(pd.DataFrame(keys), left_index=True, right_index=True) - dataframe.to_csv(path) diff --git a/winnow/pipeline/generate_remote_matches.py b/winnow/pipeline/generate_remote_matches.py deleted file mode 100644 index dc56fe5c..00000000 --- a/winnow/pipeline/generate_remote_matches.py +++ /dev/null @@ -1,81 +0,0 @@ -import logging -from math import ceil -from typing import Iterable, List, Dict - -from winnow.duplicate_detection.neighbors import NeighborMatcher, FeatureVector, DetectedMatch -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor, ProgressBar -from remote.model import RemoteFingerprint -from winnow.storage.remote_signatures_dao import RemoteMatch -from winnow.storage.repr_utils import bulk_read -from winnow.utils.iterators import chunks - -# Default module logger -from winnow.utils.neighbors import as_vectors - -logger = logging.getLogger(__name__) - - -def generate_remote_matches( - pipeline: PipelineContext, - repository_name: str = None, - contributor_name: str = None, - progress=ProgressMonitor.NULL, -): - """Find matches between local and remote files.""" - - chunk_size = 1000 - config = pipeline.config - - logger.info("Starting remote match detection.") - - # Prepare index of local signatures to detect matches - local_signatures = bulk_read(pipeline.repr_storage.signature) - neighbor_matcher = NeighborMatcher(haystack=as_vectors(local_signatures)) - - # Acquire remote signature storage - storage = pipeline.remote_signature_dao - - # Configure progress monitor - total_work, step_work = _progress(storage, repository_name, contributor_name, chunk_size) - progress.scale(total_work=total_work) - progress = ProgressBar(progress) - - # Load remote matches by chunks and do find matches - for remote_signatures in chunks(storage.query_signatures(repository_name, contributor_name), size=chunk_size): - remote_index = {remote.id: remote for remote in remote_signatures} - needles = (FeatureVector(key=remote.id, features=remote.fingerprint) for remote in remote_signatures) - detected_matches = neighbor_matcher.find_matches(needles=needles, max_distance=config.proc.match_distance) - storage.save_matches(remote_matches(detected_matches, remote_index)) - progress.increase(amount=step_work) - - logger.info("Done remote match detection.") - progress.complete() - - -def _progress(remote_signatures_storage, repository_name, contributor_name, chunk_size): - """Get total work and step work to account all available remote fingerprints.""" - total_count = remote_signatures_storage.count( - repository_name=repository_name, - 
contributor_name=contributor_name, - ) - if total_count is None: - return 1.0, 0 - iterations = ceil(max(total_count, 1) / float(chunk_size)) - return iterations, 1 - - -def remote_matches( - detected_matches: Iterable[DetectedMatch], - remote_sigs: Dict[int, RemoteFingerprint], -) -> List[RemoteMatch]: - """Convert detected feature-vector matches to remote matches.""" - results = [] - for detected_match in detected_matches: - remote_match = RemoteMatch( - remote=remote_sigs[detected_match.needle_key], - local=detected_match.haystack_key, - distance=detected_match.distance, - ) - results.append(remote_match) - return results diff --git a/winnow/pipeline/match_templates.py b/winnow/pipeline/match_templates.py deleted file mode 100644 index 17c1503f..00000000 --- a/winnow/pipeline/match_templates.py +++ /dev/null @@ -1,124 +0,0 @@ -import logging -import os -from typing import Collection, List - -from sqlalchemy.orm import joinedload - -from db.schema import TemplateFileExclusion, TemplateMatches -from winnow.pipeline.extract_frame_level_features import frame_features_exist, extract_frame_level_features -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor -from winnow.search_engine import Template -from winnow.search_engine.black_list import BlackList -from winnow.search_engine.template_matching import SearchEngine - -# Default module logger -logger = logging.getLogger(__name__) - - -def match_templates(files: Collection[str], pipeline: PipelineContext, progress=ProgressMonitor.NULL): - """Match existing templates with dataset videos.""" - - config = pipeline.config - - # We don't check for pre-existing templates so far... - # So we always perform search for all videos. - remaining_files = tuple(files) - - # Ensure dependencies are satisfied - if not frame_features_exist(remaining_files, pipeline): - extract_frame_level_features(remaining_files, pipeline, progress=progress.subtask(0.7)) - progress = progress.subtask(0.3) - - # Load templates - templates = load_templates(pipeline) - logger.info("Loaded %s templates", len(templates)) - if len(templates) == 0: - logger.info("No templates found. 
Skipping template matching step...") - progress.complete() - return - - # Load file exclusions - black_list = load_black_list(pipeline) - logger.info( - "Found %s file exclusions and %s time exclusions", - black_list.file_exclusions_count, - black_list.time_exclusions_count, - ) - - se = SearchEngine(frame_features=pipeline.repr_storage.frame_level, black_list=black_list) - template_matches = se.create_annotation_report( - templates=templates, - threshold=config.templates.distance, - frame_sampling=config.proc.frame_sampling, - distance_min=config.templates.distance_min, - ) - - tm_entries = template_matches[["path", "hash"]] - tm_entries["template_matches"] = template_matches.drop(columns=["path", "hash"]).to_dict("records") - - if config.database.use: - # Save Template Matches - result_storage = pipeline.result_storage - template_names = {template.name for template in templates} - result_storage.add_template_matches(template_names, tm_entries.to_numpy()) - - if config.save_files: - template_matches_report_path = os.path.join(config.repr.directory, "template_matches.csv") - template_matches.to_csv(template_matches_report_path) - - logger.info("Template Matches report exported to: %s", template_matches_report_path) - - template_test_output = os.path.join(pipeline.config.repr.directory, "template_test.csv") - logger.info("Report saved to %s", template_test_output) - progress.complete() - - -def load_templates(pipeline: PipelineContext) -> List[Template]: - """Load templates according to the pipeline config.""" - config = pipeline.config - templates_source = config.templates.source_path - if templates_source: - logger.info("Loading templates from: %s", templates_source) - templates = pipeline.template_loader.load_templates_from_folder(templates_source) - if config.database.use: - return pipeline.template_loader.store_templates(templates, pipeline.database, pipeline.file_storage) - return templates - elif config.database.use: - logger.info("Loading templates from the database") - return pipeline.template_loader.load_templates_from_database(pipeline.database, pipeline.file_storage) - else: - logger.error("Neither database nor template source directory are not available") - return [] - - -def load_black_list(pipeline: PipelineContext) -> BlackList: - """Get template file exclusions.""" - - # Load file exclusions - config = pipeline.config - file_exclusions = () - time_exclusions = () - if config.database.use: - with pipeline.database.session_scope(expunge=True) as session: - file_exclusions = ( - session.query(TemplateFileExclusion) - .options(joinedload(TemplateFileExclusion.file)) - .options(joinedload(TemplateFileExclusion.template)) - .all() - ) - time_exclusions = ( - session.query(TemplateMatches) - .options(joinedload(TemplateMatches.file)) - .options(joinedload(TemplateMatches.template)) - .filter(TemplateMatches.false_positive == True) # noqa: E712 - .all() - ) - - # Populate black list - black_list = BlackList() - for file_exclusion in file_exclusions: - black_list.exclude_file_entity(file_exclusion) - for time_exclusion in time_exclusions: - black_list.exclude_time_range(time_exclusion) - return black_list diff --git a/winnow/pipeline/prepare_text_search.py b/winnow/pipeline/prepare_text_search.py deleted file mode 100644 index 42d5f55f..00000000 --- a/winnow/pipeline/prepare_text_search.py +++ /dev/null @@ -1,63 +0,0 @@ -import logging - -import numpy as np -from dataclasses import astuple - -from db.access.files import FilesDAO -from winnow.pipeline.pipeline_context import 
PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor -from winnow.text_search.similarity_index import AnnoySimilarityIndex -from winnow.utils.iterators import chunks - - -def prepare_text_search( - pipeline: PipelineContext, - force: bool = True, - progress: ProgressMonitor = ProgressMonitor.NULL, -): - """Build text similarity index for existing fingerprints.""" - logger = logging.getLogger(__name__) - progress.scale(1.0) - - # Skip index building if not required - if not force and pipeline.text_search_id_index_exists(): - logger.info("Text search index already exists.") - progress.complete() - return - elif force and pipeline.text_search_id_index_exists(): - logger.info("Rebuilding text search index because force=True") - elif not pipeline.text_search_id_index_exists(): - logger.info("Text-search is missing and will be created.") - - logger.info("Loading fingerprints.") - fingerprint_storage = pipeline.repr_storage.signature - file_keys, fingerprints = [], [] - for file_key in fingerprint_storage.list(): - fingerprint = fingerprint_storage.read(file_key) - file_keys.append(file_key) - fingerprints.append(fingerprint) - progress.increase(0.3) - logger.info("Loaded %s fingerprints", len(fingerprints)) - - logger.info("Getting database ids") - file_ids = [] - for chunk in chunks(file_keys, size=10000): - with pipeline.database.session_scope() as session: - ids_chunk = FilesDAO.query_local_file_ids(session, map(astuple, chunk)).all() - file_ids.extend(ids_chunk) - logger.info("Loaded %s database ids", len(file_ids)) - progress.increase(0.1) - - logger.info("Loading the text search model") - model = pipeline.text_search_model - - logger.info("Converting fingerprints using the text-search model") - vectors = np.array([model.embed_vis(fingerprint)[0].numpy() for fingerprint in fingerprints]) - - logger.info("Building annoy text search index.") - index = AnnoySimilarityIndex() - index.fit(file_ids, np.array(vectors)) - - logger.info("Saving text-search index.") - index.save(pipeline.config.repr.directory, pipeline.TEXT_SEARCH_INDEX_NAME, pipeline.TEXT_SEARCH_DATABASE_IDS_NAME) - progress.complete() diff --git a/winnow/pipeline/process_urls.py b/winnow/pipeline/process_urls.py deleted file mode 100644 index 034eb433..00000000 --- a/winnow/pipeline/process_urls.py +++ /dev/null @@ -1,37 +0,0 @@ -import logging -from typing import Collection - -from winnow.pipeline.detect_scenes import detect_scenes -from winnow.pipeline.extract_exif import extract_exif -from winnow.pipeline.generate_local_matches import generate_local_matches -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor -from winnow.utils.download import download_videos - - -def process_urls( - urls: Collection[str], - pipeline: PipelineContext, - destination_template: str = "%(title)s.%(ext)s", - progress=ProgressMonitor.NULL, -): - """Process single file url.""" - - logger = logging.getLogger(__name__) - logger.info("Processing %s video URLs", len(urls)) - - progress.scale(1.0) - file_paths = download_videos( - urls=urls, - root_directory=pipeline.config.sources.root, - output_template=destination_template, - progress=progress.subtask(0.2), - logger=logger, - suppress_errors=True, - ) - - generate_local_matches(files=file_paths, pipeline=pipeline, progress=progress.subtask(0.7)) - detect_scenes(files=file_paths, pipeline=pipeline, progress=progress.subtask(0.05)) - extract_exif(videos=file_paths, pipeline=pipeline, 
progress_monitor=progress.subtask(0.05)) - - return file_paths diff --git a/winnow/pipeline/pull_fingerprints.py b/winnow/pipeline/pull_fingerprints.py deleted file mode 100644 index 4decfec2..00000000 --- a/winnow/pipeline/pull_fingerprints.py +++ /dev/null @@ -1,36 +0,0 @@ -import logging - -from remote import make_client -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor - -# Default module logger -logger = logging.getLogger(__name__) - - -def pull_fingerprints(repository_name: str, pipeline: PipelineContext, progress=ProgressMonitor.NULL): - """Pull fingerprints to the remote repository.""" - - # Get repository - repo = pipeline.repository_dao.get(repository_name) - if repo is None: - logger.error("Unknown repository name: %s", repository_name) - progress.complete() - return - - progress.scale(total_work=1.0) - connector = pipeline.make_connector(repo) - logger.info("Pulling fingerprints from %s", repository_name) - try: - connector.pull_all(chunk_size=10000, progress=progress.subtask(1.0)) - logger.info("Finished pulling fingerprints from %s", repository_name) - except Exception: - logger.exception("Error pulling fingerprints from %s", repository_name) - raise - finally: - # Update repo metadata - logger.info("Updating '%s' repository metadata", repository_name) - client = make_client(repo) - repo_stats = client.get_stats() - pipeline.repository_dao.update_stats(repo_stats) - progress.complete() diff --git a/winnow/pipeline/push_fingerprints.py b/winnow/pipeline/push_fingerprints.py deleted file mode 100644 index 33a0642b..00000000 --- a/winnow/pipeline/push_fingerprints.py +++ /dev/null @@ -1,36 +0,0 @@ -import logging - -from remote import make_client -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor - -# Default module logger -logger = logging.getLogger(__name__) - - -def push_fingerprints(repository_name: str, pipeline: PipelineContext, progress=ProgressMonitor.NULL): - """Push fingerprints to the remote repository.""" - - # Get repository - repo = pipeline.repository_dao.get(repository_name) - if repo is None: - logger.error("Unknown repository name: %s", repository_name) - progress.complete() - return - - progress.scale(total_work=1.0) - connector = pipeline.make_connector(repo) - logger.info("Pushing fingerprints to %s", repository_name) - try: - connector.push_all(chunk_size=10000, progress=progress.subtask(1.0)) - logger.info("Finished pushing fingerprints to %s", repository_name) - except Exception: - logger.exception("Error pushing fingerprints to %s", repository_name) - raise - finally: - # Update repo metadata - logger.info("Updating '%s' repository metadata", repository_name) - client = make_client(repo) - repo_stats = client.get_stats() - pipeline.repository_dao.update_stats(repo_stats) - progress.complete() diff --git a/winnow/pipeline/store_database_signatures.py b/winnow/pipeline/store_database_signatures.py deleted file mode 100644 index 2a3bb6e4..00000000 --- a/winnow/pipeline/store_database_signatures.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging -import os -from pickle import dumps -from typing import Collection - -from dataclasses import astuple - -from db.access.files import FilesDAO -from winnow.pipeline.extract_video_signatures import extract_video_signatures, video_signatures_exist -from winnow.pipeline.pipeline_context import PipelineContext -from winnow.pipeline.progress_monitor import ProgressMonitor -from 
winnow.storage.repr_utils import bulk_read - -logger = logging.getLogger(__name__) - - -def store_database_signatures(files: Collection[str], pipeline: PipelineContext, progress=ProgressMonitor.NULL): - """Ensure result database contains signatures for all given files.""" - - if not pipeline.config.database.use: - logger.info("Database is disabled. Skipping database signatures update...") - progress.complete() - return - - files = tuple(files) - remaining_video_paths = tuple(missing_database_signatures(files, pipeline)) - - # Ensure dependencies are satisfied - if not video_signatures_exist(remaining_video_paths, pipeline): - extract_video_signatures(remaining_video_paths, pipeline, progress=progress.subtask(0.9)) - progress = progress.subtask(0.1) - - # Skip step if required results already exist - if not remaining_video_paths: - logger.info("Database contains all required signatures. Skipping...") - progress.complete() - return - - # Save signatures to database if needed - logger.info("Saving signatures to the database for %s of %s files", len(remaining_video_paths), len(files)) - file_keys = map(pipeline.filekey, remaining_video_paths) - signatures = bulk_read(pipeline.repr_storage.signature, select=file_keys) - pipeline.result_storage.add_signatures((key.path, key.hash, dumps(sig)) for key, sig in signatures.items()) - - logger.info("Done saving %s signatures to database.", len(remaining_video_paths)) - progress.complete() - - -def missing_database_signatures(files, pipeline: PipelineContext): - """Get file paths with missing signatures.""" - - if not pipeline.config.database.use: - return - - with pipeline.database.session_scope() as session: - file_keys = map(pipeline.filekey, files) - path_hash_pairs = map(astuple, file_keys) - for (path, _) in FilesDAO.select_missing_signatures(path_hash_pairs, session): - yield os.path.join(pipeline.config.sources.root, path) - - -def database_signatures_exist(files, pipeline: PipelineContext): - """Check if all required signatures do exist.""" - return not any(missing_database_signatures(files, pipeline)) From 0fd9185616833be50aba4d7c4ccc40b4ba568751 Mon Sep 17 00:00:00 2001 From: Stepan Anokhin Date: Wed, 20 Apr 2022 21:22:05 +0700 Subject: [PATCH 02/11] Implement tiles generation task --- server/server/{queue => }/time_utils.py | 0 winnow/pipeline/luigi/embeddings_tiles.py | 453 ++++++++++++++++++++++ 2 files changed, 453 insertions(+) rename server/server/{queue => }/time_utils.py (100%) create mode 100644 winnow/pipeline/luigi/embeddings_tiles.py diff --git a/server/server/queue/time_utils.py b/server/server/time_utils.py similarity index 100% rename from server/server/queue/time_utils.py rename to server/server/time_utils.py diff --git a/winnow/pipeline/luigi/embeddings_tiles.py b/winnow/pipeline/luigi/embeddings_tiles.py new file mode 100644 index 00000000..5a079155 --- /dev/null +++ b/winnow/pipeline/luigi/embeddings_tiles.py @@ -0,0 +1,453 @@ +import abc +import json +import math +import os +import shutil +from typing import Union, Tuple, List + +import luigi +import matplotlib.pyplot as plt +import numpy as np +from dataclasses import dataclass +from matplotlib.pyplot import Figure, Axes + +from winnow.pipeline.luigi.condense import CondensedFingerprintsTarget +from winnow.pipeline.luigi.embeddings import ( + PaCMAPEmbeddingsTask, + TSNEEmbeddingsTask, + UmapEmbeddingsTask, + TriMapEmbeddingsTask, +) +from winnow.pipeline.luigi.platform import PipelineTask +from winnow.pipeline.luigi.targets import FileWithTimestampTarget +from 
winnow.pipeline.progress_monitor import BaseProgressMonitor, ProgressMonitor
+
+
+@dataclass
+class BBox:
+    """A two-dimensional bounding box of a collection of points."""
+
+    min_x: float
+    max_x: float
+    min_y: float
+    max_y: float
+
+    @property
+    def y_lim(self) -> Tuple[float, float]:
+        """Y-axis limits."""
+        return self.min_y, self.max_y
+
+    @property
+    def x_lim(self) -> Tuple[float, float]:
+        """X-axis limits."""
+        return self.min_x, self.max_x
+
+    @property
+    def width(self) -> float:
+        """Bounding box width."""
+        return abs(self.max_x - self.min_x)
+
+    @property
+    def height(self) -> float:
+        """Bounding box height."""
+        return abs(self.max_y - self.min_y)
+
+    @property
+    def center(self) -> Tuple[float, float]:
+        """Get the bounding box center."""
+        center_x = (self.max_x + self.min_x) / 2.0
+        center_y = (self.max_y + self.min_y) / 2.0
+        return center_x, center_y
+
+    def squared(self) -> "BBox":
+        """Make the bounding box equal-sided."""
+        center_x, center_y = self.center
+        delta = max(self.width, self.height) / 2.0
+        return BBox.make(
+            min_x=center_x - delta,
+            max_x=center_x + delta,
+            min_y=center_y - delta,
+            max_y=center_y + delta,
+        )
+
+    def select(self, points: np.ndarray, rel_margin: float = 0.0) -> np.ndarray:
+        """Select points from this bounding box."""
+        margin_x = self.width * rel_margin
+        margin_y = self.height * rel_margin
+        selection_x = np.logical_and(
+            points[:, 0] >= self.min_x - margin_x,
+            points[:, 0] <= self.max_x + margin_x,
+        )
+        selection_y = np.logical_and(
+            points[:, 1] >= self.min_y - margin_y,
+            points[:, 1] <= self.max_y + margin_y,
+        )
+        selection = np.logical_and(selection_x, selection_y)
+        return points[selection]
+
+    def subbox(self, divide, x_index, y_index) -> "BBox":
+        """Get the sub-bounding box obtained by dividing the current box `divide` times along x and y."""
+        width = self.width / divide
+        height = self.height / divide
+        x_offset = self.min_x + width * x_index
+        # y indexing is in the top-down direction
+        y_offset = self.max_y - height * (y_index + 1)
+        return BBox.make(
+            min_x=x_offset,
+            max_x=x_offset + width,
+            min_y=y_offset,
+            max_y=y_offset + height,
+        )
+
+    @staticmethod
+    def make(
+        min_x: float,
+        max_x: float,
+        min_y: float,
+        max_y: float,
+    ) -> "BBox":
+        """Create a BBox ensuring min/max ordering."""
+        return BBox(
+            min_x=min(min_x, max_x),
+            max_x=max(min_x, max_x),
+            min_y=min(min_y, max_y),
+            max_y=max(min_y, max_y),
+        )
+
+    @staticmethod
+    def calculate(
+        points: np.ndarray,
+        rel_ignored_outliers: float = None,
+    ) -> "BBox":
+        """Calculate the bounding box of a collection of 2D points."""
+        sorted_coordinates = np.sort(points, kind="heapsort", axis=0)
+        if rel_ignored_outliers is None or not (0 <= rel_ignored_outliers <= 1.0):
+            return BBox.make(
+                min_x=sorted_coordinates[0][0],
+                max_x=sorted_coordinates[-1][0],
+                min_y=sorted_coordinates[0][1],
+                max_y=sorted_coordinates[-1][1],
+            )
+        n_ignored = math.floor(len(points) * rel_ignored_outliers)
+        return BBox.make(
+            min_x=sorted_coordinates[n_ignored][0],
+            max_x=sorted_coordinates[-n_ignored - 1][0],
+            min_y=sorted_coordinates[n_ignored][1],
+            max_y=sorted_coordinates[-n_ignored - 1][1],
+        )
+
+
+@dataclass
+class Tile:
+    """A tile descriptor.
+
+    The tile descriptor includes:
+    * bbox - bounding box of the points in embedding-space coordinates.
+    * x, y - index of the tile along the X and Y axes respectively for the given zoom level.
+    * zoom - zoom level at which the tile is supposed to be displayed.
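+
+    For example (illustration only): at zoom=1 the root bounding box is
+    split into a 2x2 grid, so x and y each range over {0, 1}; since y is
+    indexed top-down, the tile (x=1, y=0) covers the top-right quadrant.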
+ """ + + bbox: BBox + x: int = 0 + y: int = 0 + zoom: int = 0 + + def subtiles(self) -> List["Tile"]: + """Get the subtiles of the given tile (by increasing the zoom and splitting the tile 2x2).""" + x_offset = self.x * 2 + y_offset = self.y * 2 + results: List[Tile] = [] + for x_index in (0, 1): + for y_index in (0, 1): + tile = Tile( + x=x_offset + x_index, + y=y_offset + y_index, + zoom=self.zoom + 1, + bbox=self.bbox.subbox(divide=2, x_index=x_index, y_index=y_index), + ) + results.append(tile) + return results + + @staticmethod + def make_root(points: np.ndarray, rel_ignored_outliers: float = 0.0) -> "Tile": + """Create a root tile for a collection of points.""" + bbox = BBox.calculate(points, rel_ignored_outliers).squared() + return Tile(bbox=bbox, x=0, y=0, zoom=0) + + +class PointStyle(abc.ABC): + """A point style strategy.""" + + @abc.abstractmethod + def size(self, zoom: int) -> Union[float, int]: + """Get the point size for the given zoom.""" + + @abc.abstractmethod + def alpha(self, zoom: int) -> float: + """Get the point opacity.""" + + +class SimplePointStyle(PointStyle): + """A point style strategy.""" + + def __init__( + self, + min_zoom: int = 0, + max_zoom: int = 10, + max_size: int = 64, + min_size: int = 1, + min_alpha: float = 0.5, + ): + self.min_zoom: int = min_zoom + self.max_zoom: int = max_zoom + self.max_size: int = max_size + self.min_size: int = min_size + self.min_alpha: float = min_alpha + + def size(self, zoom: int) -> Union[float, int]: + """Get the point size for the given zoom.""" + start_zoom = max(self.min_zoom, self.max_zoom - int(math.log2(float(self.max_size) / float(self.min_size)))) + if zoom < start_zoom: + return self.min_size + return self.min_size * (2 ** zoom - start_zoom) + + def alpha(self, zoom: int) -> float: + """Get the point opacity.""" + start_zoom = max(self.min_zoom, self.max_zoom - int(math.log2(float(self.max_size) / float(self.min_size)))) + if zoom >= start_zoom: + return 1.0 + return 0.9 ** (start_zoom - zoom) + + +class TileGenerator: + """Recursive tiles generator.""" + + # Tiles detail file + DETAILS_FILE = "details.json" + + def __init__( + self, + point_style: PointStyle = None, + min_zoom: int = 0, + max_zoom: int = 10, + rel_ignored_outliers: float = 0.001, + ): + self.min_zoom: int = min_zoom + self.max_zoom: int = max_zoom + self.rel_ignored: float = rel_ignored_outliers + self.point_style: PointStyle = point_style or SimplePointStyle(min_zoom=min_zoom, max_zoom=max_zoom) + + def generate_tiles( + self, + points: np.ndarray, + output_directory: str, + progress: BaseProgressMonitor = ProgressMonitor.NULL, + ): + """Generate tile images for a collection of points.""" + progress.scale(1.0) + self.save_blank(output_directory) + bbox = BBox.calculate(points, self.rel_ignored).squared() + root_tile = Tile(x=0, y=0, zoom=0, bbox=bbox) + progress.increase(0.01) + tiles_count = self.tiles_count(zoom=root_tile.zoom, max_zoom=self.max_zoom) + progress = progress.bar(scale=tiles_count, unit="tiles") + self.make_tiles_recursive(points, root_tile, output_directory, progress.remaining()) + self._write_details(output_directory, root_tile.bbox) + + def make_tiles_recursive( + self, + points: np.ndarray, + tile: Tile, + output_directory: str, + progress: BaseProgressMonitor = ProgressMonitor.NULL, + ): + tiles_count = self.tiles_count(zoom=tile.zoom, max_zoom=self.max_zoom) + progress.scale(tiles_count) + if tile.zoom > self.max_zoom: + progress.complete() + return + point_size = self.point_style.size(tile.zoom) + tile_size = 256.0 + 
points = tile.bbox.select(points, rel_margin=point_size / tile_size) + if len(points) == 0: + progress.complete() + return + figure = self.draw_tile(points, tile) + progress.increase(1) + self.save_tile(figure, tile, output_directory) + plt.close(figure) + for subtile in tile.subtiles(): + self.make_tiles_recursive(points, subtile, output_directory, progress.subtask((tiles_count - 1) / 4)) + progress.complete() + + def draw_tile(self, points: np.ndarray, tile: Tile) -> Figure: + figure, axes = self._make_figure() + axes.set_xlim(*tile.bbox.x_lim) + axes.set_ylim(*tile.bbox.y_lim) + axes.scatter( + points[:, 0], + points[:, 1], + s=math.pi * ((self.point_style.size(tile.zoom) / 2.0) ** 2), + alpha=self.point_style.alpha(tile.zoom), + ) + return figure + + @staticmethod + def save_tile(figure: Figure, tile, output_root_dir, format="png"): + """Save tile file.""" + directory = os.path.join(output_root_dir, f"zoom_{tile.zoom}") + os.makedirs(directory, exist_ok=True) + figure_path = os.path.join(directory, f"tile_{tile.x}_{tile.y}.{format}") + figure.savefig(figure_path, format=format, bbox_inches="tight", pad_inches=0) + + @staticmethod + def _make_figure() -> Tuple[Figure, Axes]: + dpi = 96 + # MAGIC! The size is not 256.0/dpi because it is extremely hard to + # get rid of matplotlib margins, paddings and whitespaces... + size = (285.0 / dpi, 285.0 / dpi) + figure: Figure = plt.figure(figsize=size, dpi=dpi, tight_layout=True) + axes: Axes = figure.gca() + axes.axis("off") + return figure, axes + + @staticmethod + def tiles_count(zoom: int, max_zoom: int) -> int: + """Calculate number of tiles to be generated.""" + # This is simply a sum of geometric progression, + # as each tile is split into 4 smaller tiles on + # the next zoom level. + return (4 ** (max_zoom - zoom + 1) - 1) / 3 + + def save_blank(self, output_directory: str, format="png"): + """Save default blank tile.""" + figure, _ = self._make_figure() + os.makedirs(output_directory, exist_ok=True) + figure_path = os.path.join(output_directory, f"blank.{format}") + figure.savefig(figure_path, format=format, bbox_inches="tight", pad_inches=0) + + def _point_size_embedding(self, bbox: BBox): + """Get the point size in embeddings space.""" + max_width = bbox.width + min_width = bbox.width / (2 ** (self.max_zoom - self.min_zoom)) + return max( + float(self.point_style.size(self.min_zoom)) / 256 * max_width, + float(self.point_style.size(self.max_zoom)) / 256 * min_width, + ) + + def _write_details(self, output_directory: str, bbox: BBox): + """Create a file indicating successful completion.""" + done_path = os.path.join(output_directory, self.DETAILS_FILE) + with luigi.LocalTarget(done_path).open("w") as done: + details = { + "bbox": { + "x": [bbox.min_x, bbox.max_x], + "y": [bbox.min_y, bbox.max_y], + }, + "point_size": self._point_size_embedding(bbox), + } + json.dump(details, done, indent=4) + + +class EmbeddingsTilesTarget(FileWithTimestampTarget): + def exists(self): + if not super().exists(): + return False + details_file = os.path.join(self.latest_result_path, TileGenerator.DETAILS_FILE) + return os.path.isfile(details_file) + + +class EmbeddingsTilesTask(PipelineTask, abc.ABC): + """Base task for tiles generation.""" + + prefix: str = luigi.Parameter(default=".") + max_zoom: int = luigi.IntParameter(default=8) + clean_existing: bool = luigi.BoolParameter(default=True, significant=False) + + @abc.abstractmethod + def requires(self): + """Read condensed embeddings.""" + + def output(self) -> FileWithTimestampTarget: + coll = 
self.pipeline.coll + return EmbeddingsTilesTarget( + path_prefix=self.result_directory, + name_suffix=".d", + need_updates=lambda time: coll.any(prefix=self.prefix, min_mtime=time), + ) + + def run(self): + self.logger.info("Reading %s embeddings", self.algorithm_name) + embeddings_input: CondensedFingerprintsTarget = self.input() + embeddings = embeddings_input.read(self.progress.subtask(0.1)) + self.logger.info("Loaded %s %s embeddings", len(embeddings), self.algorithm_name) + + target = self.output() + previous_results_path = target.latest_result_path + new_result_time = self.pipeline.coll.max_mtime(prefix=self.prefix) + new_result_dir = target.suggest_path(new_result_time) + + self.logger.info("Generating tiles for %s embeddings into %s", self.algorithm_name, new_result_dir) + style = SimplePointStyle(max_zoom=self.max_zoom) + generator = TileGenerator(max_zoom=self.max_zoom, point_style=style) + generator.generate_tiles(embeddings.fingerprints, new_result_dir, self.progress.subtask(0.9)) + self.logger.info("All tiles are saved into %s", new_result_time) + + if previous_results_path is not None and self.clean_existing: + self.logger.info("Removing previous results: %s", previous_results_path) + shutil.rmtree(previous_results_path, ignore_errors=False, onerror=None) + + @property + @abc.abstractmethod + def algorithm_name(self) -> str: + """Embedding algorithm name.""" + + @property + def result_directory(self) -> str: + """Result directory path.""" + dir_name = f"tiles_{self.max_zoom}zoom" + return os.path.join(self.output_directory, "embeddings", self.algorithm_name.lower(), self.prefix, dir_name) + + +class PaCMAPTilesTask(EmbeddingsTilesTask): + """Generate tiles for PaCMAP embeddings.""" + + def requires(self): + return PaCMAPEmbeddingsTask(config=self.config, prefix=self.prefix) + + @property + def algorithm_name(self) -> str: + return "PaCMAP" + + +class TriMAPTilesTask(EmbeddingsTilesTask): + """Generate tiles for TriMAP embeddings.""" + + def requires(self): + return TriMapEmbeddingsTask(config=self.config, prefix=self.prefix) + + @property + def algorithm_name(self) -> str: + return "TriMAP" + + +class TSNETilesTask(EmbeddingsTilesTask): + """Generate tiles for t-SNE embeddings.""" + + def requires(self): + return TSNEEmbeddingsTask(config=self.config, prefix=self.prefix) + + @property + def algorithm_name(self) -> str: + return "t-SNE" + + +class UMAPTilesTask(EmbeddingsTilesTask): + """Generate tiles for UMAP embeddings.""" + + def requires(self): + return UmapEmbeddingsTask(config=self.config, prefix=self.prefix) + + @property + def algorithm_name(self) -> str: + return "UMAP" From bd971c463679a7b468015fc12cfc9d55909e7adb Mon Sep 17 00:00:00 2001 From: Stepan Anokhin Date: Wed, 20 Apr 2022 21:22:22 +0700 Subject: [PATCH 03/11] Implement annoy-index task for embeddings --- .../pipeline/luigi/embeddings_annoy_index.py | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 winnow/pipeline/luigi/embeddings_annoy_index.py diff --git a/winnow/pipeline/luigi/embeddings_annoy_index.py b/winnow/pipeline/luigi/embeddings_annoy_index.py new file mode 100644 index 00000000..be3bfc32 --- /dev/null +++ b/winnow/pipeline/luigi/embeddings_annoy_index.py @@ -0,0 +1,120 @@ +import abc +import os + +import luigi +from annoy import AnnoyIndex + +from winnow.pipeline.luigi.condense import CondensedFingerprints +from winnow.pipeline.luigi.embeddings import UmapEmbeddingsTask, TSNEEmbeddingsTask, TriMapEmbeddingsTask, \ + PaCMAPEmbeddingsTask +from 
winnow.pipeline.luigi.platform import PipelineTask +from winnow.pipeline.luigi.targets import FileGroupTarget + + +class EmbeddingsAnnoyIndexTask(PipelineTask, abc.ABC): + """Build Annoy index for 2D fingerprint embeddings.""" + + prefix: str = luigi.Parameter(default=".") + n_trees: int = luigi.IntParameter(default=10) + clean_existing: bool = luigi.BoolParameter(default=True, significant=False) + + def run(self): + target = self.output() + new_results_time = self.pipeline.coll.max_mtime(prefix=self.prefix) + previous_results_paths, _ = target.latest_result + + self.logger.info("Loading condensed %s embeddings", self.algorithm_name) + embeddings: CondensedFingerprints = self.input().read(self.progress.subtask(0.1)) + self.logger.info("Loaded %s embeddings", len(embeddings)) + + self.logger.info("Building Annoy index for %s embeddings", self.algorithm_name) + annoy_index = AnnoyIndex(2, "euclidean") + fitting_progress = self.progress.bar(0.6, scale=len(embeddings), unit="sigs") + for i, fingerprint in enumerate(embeddings.fingerprints): + annoy_index.add_item(i, fingerprint) + fitting_progress.increase(1) + fitting_progress.complete() + self.logger.info("Added %s fingerprints to the index", len(embeddings)) + + self.logger.info("Building annoy index.") + annoy_index.build(self.n_trees) + self.progress.increase(0.25) + self.logger.info("Annoy index is prepared.") + + self.logger.info("Saving annoy index.") + index_path, keys_path = target.suggest_paths(new_results_time) + os.makedirs(os.path.dirname(index_path), exist_ok=True) + annoy_index.save(index_path) + embeddings.file_keys_df.to_csv(keys_path) + + if self.clean_existing and previous_results_paths is not None: + for path in previous_results_paths: + self.logger.info("Removing previous results: %s", path) + os.remove(path) + + def output(self) -> FileGroupTarget: + coll = self.pipeline.coll + return FileGroupTarget( + common_prefix=os.path.join( + self.output_directory, + "embeddings", + self.algorithm_name.lower(), + self.prefix, + "annoy_index", + ), + suffixes=(".annoy", ".files.csv"), + need_updates=lambda time: coll.any(prefix=self.prefix, min_mtime=time), + ) + + @abc.abstractmethod + def requires(self): + """Read condensed embeddings.""" + + @property + @abc.abstractmethod + def algorithm_name(self) -> str: + """Embedding algorithm name.""" + + +class PaCMAPAnnoyIndexTask(EmbeddingsAnnoyIndexTask): + """Generate tiles for PaCMAP embeddings.""" + + def requires(self): + return PaCMAPEmbeddingsTask(config=self.config, prefix=self.prefix) + + @property + def algorithm_name(self) -> str: + return "PaCMAP" + + +class TriMAPAnnoyIndexTask(EmbeddingsAnnoyIndexTask): + """Generate tiles for TriMAP embeddings.""" + + def requires(self): + return TriMapEmbeddingsTask(config=self.config, prefix=self.prefix) + + @property + def algorithm_name(self) -> str: + return "TriMAP" + + +class TSNEAnnoyIndexTask(EmbeddingsAnnoyIndexTask): + """Generate tiles for t-SNE embeddings.""" + + def requires(self): + return TSNEEmbeddingsTask(config=self.config, prefix=self.prefix) + + @property + def algorithm_name(self) -> str: + return "t-SNE" + + +class UMAPAnnoyIndexTask(EmbeddingsAnnoyIndexTask): + """Generate tiles for UMAP embeddings.""" + + def requires(self): + return UmapEmbeddingsTask(config=self.config, prefix=self.prefix) + + @property + def algorithm_name(self) -> str: + return "UMAP" From a408b9da7c2747097862f717d11d396b4f008f6b Mon Sep 17 00:00:00 2001 From: Stepan Anokhin Date: Wed, 20 Apr 2022 21:23:15 +0700 Subject: [PATCH 04/11] 
From a408b9da7c2747097862f717d11d396b4f008f6b Mon Sep 17 00:00:00 2001
From: Stepan Anokhin
Date: Wed, 20 Apr 2022 21:23:15 +0700
Subject: [PATCH 04/11] Implement tile generation background job

---
 task_queue/luigi_support.py | 21 ++++++++++++------
 task_queue/tasks.py         | 43 +++++++++++++++++++++++++++++++++++++
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/task_queue/luigi_support.py b/task_queue/luigi_support.py
index 552f25b2..21776e7c 100644
--- a/task_queue/luigi_support.py
+++ b/task_queue/luigi_support.py
@@ -5,19 +5,28 @@
 import luigi
 
 from task_queue.metadata import TaskRuntimeMetadata
+from task_queue.winnow_task import WinnowTask
 from winnow.pipeline.luigi.platform import JusticeAITask
 from winnow.pipeline.progress_monitor import ProgressMonitor
 
 
+class ProgressObserver:
+    def __init__(self, celery_task: WinnowTask, resolution: float = 0.001):
+        self.celery_task: WinnowTask = celery_task
+        self.resolution: float = resolution
+        self.last_update: float = -resolution
+
+    def __call__(self, progress: float, change: float):
+        if progress - self.last_update >= self.resolution:
+            self.celery_task.update_metadata(TaskRuntimeMetadata(progress=progress))
+            self.last_update = progress
+
+
 class LuigiRootProgressMonitor(ProgressMonitor):
     """Root progress monitor to track progress made by entire multitask run."""
 
-    def __init__(self, celery_task):
-        def update_progress(progress, _):
-            """Send a metadata update."""
-            celery_task.update_metadata(TaskRuntimeMetadata(progress=progress))
-
-        super().__init__(update_progress)
+    def __init__(self, celery_task: WinnowTask):
+        super().__init__(ProgressObserver(celery_task))
         self._seen_tasks = set()
         JusticeAITask.event_handler(luigi.Event.DEPENDENCY_DISCOVERED)(self._update_total_work)
         JusticeAITask.event_handler(luigi.Event.PROGRESS)(self._handle_task_progress)
diff --git a/task_queue/tasks.py b/task_queue/tasks.py
index 86ba8acf..ceabfd55 100644
--- a/task_queue/tasks.py
+++ b/task_queue/tasks.py
@@ -1,5 +1,6 @@
 import json
 import os
+import shutil
 import tempfile
 import time
 from typing import List, Dict
@@ -151,6 +152,48 @@ def prepare_semantic_search(self, **_):
         run_luigi(PrepareTextSearchTask(config=config))
 
 
+@winnow_task(bind=True)
+def generate_tiles(self, algorithm: str, max_zoom: int = 8, force: bool = False, **_):
+    """Generate embedding tiles and Annoy index for the requested algorithm."""
+    from .luigi_support import luigi_config, run_luigi
+    from winnow.pipeline.luigi.embeddings_tiles import (
+        TriMAPTilesTask,
+        PaCMAPTilesTask,
+        UMAPTilesTask,
+        TSNETilesTask,
+    )
+    from winnow.pipeline.luigi.embeddings_annoy_index import (
+        PaCMAPAnnoyIndexTask,
+        TriMAPAnnoyIndexTask,
+        TSNEAnnoyIndexTask,
+        UMAPAnnoyIndexTask,
+    )
+
+    if max_zoom < 0:
+        raise ValueError(f"Negative max_zoom: {max_zoom}")
+
+    with luigi_config(celery_task=self) as config:
+        if algorithm == "pacmap":
+            tiles_task = PaCMAPTilesTask(config=config, max_zoom=max_zoom, clean_existing=True)
+            index_task = PaCMAPAnnoyIndexTask(config=config)
+        elif algorithm == "trimap":
+            tiles_task = TriMAPTilesTask(config=config, max_zoom=max_zoom, clean_existing=True)
+            index_task = TriMAPAnnoyIndexTask(config=config)
+        elif algorithm == "umap":
+            tiles_task = UMAPTilesTask(config=config, max_zoom=max_zoom, clean_existing=True)
+            index_task = UMAPAnnoyIndexTask(config=config)
+        elif algorithm == "t-sne":
+            tiles_task = TSNETilesTask(config=config, max_zoom=max_zoom, clean_existing=True)
+            index_task = TSNEAnnoyIndexTask(config=config)
+        else:
+            raise ValueError(f"Unknown embeddings algorithm: {algorithm}")
+
+        existing_tiles_path = tiles_task.output().latest_result_path
+        if force and existing_tiles_path is not None:
+            shutil.rmtree(existing_tiles_path)
+
+        run_luigi(tiles_task, index_task)
+
+
 def fibo(n):
     """A very inefficient Fibonacci numbers generator."""
     if n <= 2:
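Usage note: once a Celery worker picks up this module, the job can be enqueued like any other background task. A sketch, assuming `winnow_task` registers a standard Celery task (so `delay()` is available) and a worker and broker are running:

    from task_queue.tasks import generate_tiles

    # Rebuild PaCMAP tiles from scratch; max_zoom=8 is an example value.
    result = generate_tiles.delay(algorithm="pacmap", max_zoom=8, force=True)
    print("queued task:", result.id)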
From 0102bd3b3770118c713a1a607c1377c9c1e0ee9a Mon Sep 17 00:00:00 2001
From: Stepan Anokhin
Date: Wed, 20 Apr 2022 21:23:52 +0700
Subject: [PATCH 05/11] Augment RPC server to provide embeddings neighbors

---
 rpc/embeddings.py   | 110 ++++++
 rpc/rpc.proto       |  31 ++
 rpc/rpc_pb2.py      | 916 +++++++++++++++++++++++++-------------------
 rpc/rpc_pb2_grpc.py | 231 +++++++----
 rpc/server.py       |  29 ++
 5 files changed, 837 insertions(+), 480 deletions(-)
 create mode 100644 rpc/embeddings.py

diff --git a/rpc/embeddings.py b/rpc/embeddings.py
new file mode 100644
index 00000000..6e5a1b97
--- /dev/null
+++ b/rpc/embeddings.py
@@ -0,0 +1,110 @@
+from threading import Lock
+from typing import List, Dict, Optional, Tuple
+
+import numpy as np
+from annoy import AnnoyIndex
+
+import rpc.rpc_pb2 as proto
+from winnow.pipeline.luigi.condense import CondensedFingerprints
+from winnow.pipeline.luigi.embeddings import (
+    EmbeddingsTask,
+    UmapEmbeddingsTask,
+    TSNEEmbeddingsTask,
+    TriMapEmbeddingsTask,
+    PaCMAPEmbeddingsTask,
+)
+from winnow.pipeline.luigi.embeddings_annoy_index import (
+    EmbeddingsAnnoyIndexTask,
+    PaCMAPAnnoyIndexTask,
+    TriMAPAnnoyIndexTask,
+    UMAPAnnoyIndexTask,
+    TSNEAnnoyIndexTask,
+)
+from winnow.pipeline.luigi.utils import FileKeyDF
+from winnow.pipeline.pipeline_context import PipelineContext
+from winnow.storage.file_key import FileKey
+
+
+class EmbeddingsIndex:
+    def __init__(self, annoy_index: AnnoyIndex, files: List[FileKey], positions: Dict[FileKey, np.ndarray]):
+        self._annoy_index: AnnoyIndex = annoy_index
+        self._files: List[FileKey] = files
+        self._positions: Dict[FileKey, np.ndarray] = positions
+
+    def query(
+        self,
+        x: float,
+        y: float,
+        max_count: int = 10,
+        max_distance: Optional[float] = None,
+    ) -> List[proto.FoundNeighbor]:
+        if max_distance is not None and max_distance <= 0:
+            max_distance = None
+        indices, distances = self._annoy_index.get_nns_by_vector([x, y], max_count, include_distances=True)
+        files = [self._files[i] for i in indices]
+        results: List[proto.FoundNeighbor] = []
+        for file, distance in zip(files, distances):
+            if max_distance is not None and distance > max_distance:
+                break
+            x, y = self._positions[file]
+            results.append(
+                proto.FoundNeighbor(
+                    file_path=file.path,
+                    file_hash=file.hash,
+                    distance=distance,
+                    x=x,
+                    y=y,
+                )
+            )
+        return results
+
+
+class EmbeddingLoader:
+    def __init__(self, pipeline: PipelineContext):
+        self._pipeline: PipelineContext = pipeline
+        self._cache: Dict[str, EmbeddingsIndex] = {}
+        self._lock = Lock()
+
+    def load(self, algorithm: str) -> Optional[EmbeddingsIndex]:
+        with self._lock:
+            if algorithm not in self._cache:
+                index = self._do_load(algorithm)
+                if index is not None:
+                    self._cache[algorithm] = index
+            return self._cache.get(algorithm)
+
+    def _do_load(self, algorithm: str) -> Optional[EmbeddingsIndex]:
+        """Do load embeddings index."""
+        embeddings_task, annoy_task = self._task(algorithm)
+        if embeddings_task is None or annoy_task is None:
+            return None
+        embeddings: CondensedFingerprints = embeddings_task.output().read()
+        if embeddings is None:
+            return None
+
+        annoy_output = annoy_task.output()
+        annoy_paths, _ = annoy_output.latest_result
+        if annoy_paths is None:
+            return None
+
+        annoy_index_path, annoy_files_path = annoy_paths
+        annoy_index = AnnoyIndex(2, "euclidean")
+        annoy_index.load(annoy_index_path)
+        annoy_files_df = FileKeyDF.read_csv(annoy_files_path)
+        positions: Dict[FileKey, np.ndarray] = {}
+        for i, 
file_key in enumerate(embeddings.to_file_keys()): + positions[file_key] = embeddings.fingerprints[i] + return EmbeddingsIndex(annoy_index, FileKeyDF.to_file_keys(annoy_files_df), positions) + + def _task(self, algorithm: str) -> Tuple[Optional[EmbeddingsTask], Optional[EmbeddingsAnnoyIndexTask]]: + config = self._pipeline.config + if algorithm == "pacmap": + return PaCMAPEmbeddingsTask(config=config), PaCMAPAnnoyIndexTask(config=config) + elif algorithm == "trimap": + return TriMapEmbeddingsTask(config=config), TriMAPAnnoyIndexTask(config=config) + elif algorithm == "umap": + return UmapEmbeddingsTask(config=config), UMAPAnnoyIndexTask(config=config) + elif algorithm == "t-sne": + return TSNEEmbeddingsTask(config=config), TSNEAnnoyIndexTask(config=config) + else: + return None, None diff --git a/rpc/rpc.proto b/rpc/rpc.proto index 88b30b8c..d074de07 100644 --- a/rpc/rpc.proto +++ b/rpc/rpc.proto @@ -40,4 +40,35 @@ message StatusRequest { // Service status response message StatusResponse { bool status = 1; +} + +// Embeddings service +service Embeddings { + // Get nearest neighbors + rpc query_nearest_neighbors (NearestNeighborsRequest) returns (NearestNeighborsResults) {} + rpc get_status (EmbeddingsStatusRequest) returns (StatusResponse) {} +} + +message NearestNeighborsRequest { + string algorithm = 1; + float x = 2; + float y = 3; + float max_distance = 4; + int32 max_count = 5; +} + +message NearestNeighborsResults { + repeated FoundNeighbor neighbors = 1; +} + +message FoundNeighbor { + string file_path = 1; + string file_hash = 2; + float x = 3; + float y = 4; + float distance = 5; +} + +message EmbeddingsStatusRequest { + string algorithm = 1; } \ No newline at end of file diff --git a/rpc/rpc_pb2.py b/rpc/rpc_pb2.py index a7e46ffb..2c88492b 100644 --- a/rpc/rpc_pb2.py +++ b/rpc/rpc_pb2.py @@ -6,456 +6,564 @@ from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() + + DESCRIPTOR = _descriptor.FileDescriptor( - name="rpc/rpc.proto", - package="rpc.proto", - syntax="proto3", - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\rrpc/rpc.proto\x12\trpc.proto"M\n\x11TextSearchRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x16\n\x0emin_similarity\x18\x02 \x01(\x02\x12\x11\n\tmax_count\x18\x03 \x01(\x05"\'\n\nFoundVideo\x12\n\n\x02id\x18\x01 \x01(\x04\x12\r\n\x05score\x18\x02 \x01(\x02"\x9f\x01\n\x11TextSearchResults\x12%\n\x06videos\x18\x01 \x03(\x0b\x32\x15.rpc.proto.FoundVideo\x12\x16\n\x0eoriginal_query\x18\x02 \x01(\t\x12\x0e\n\x06tokens\x18\x03 \x03(\t\x12\x14\n\x0c\x63lean_tokens\x18\x04 \x03(\t\x12\x16\n\x0ehuman_readable\x18\x05 \x01(\t\x12\r\n\x05score\x18\x06 \x01(\x02"\x0f\n\rStatusRequest" \n\x0eStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\x08\x32\xa3\x01\n\x0eSemanticSearch\x12L\n\x0cquery_videos\x12\x1c.rpc.proto.TextSearchRequest\x1a\x1c.rpc.proto.TextSearchResults"\x00\x12\x43\n\nget_status\x12\x18.rpc.proto.StatusRequest\x1a\x19.rpc.proto.StatusResponse"\x00\x62\x06proto3', + name='rpc/rpc.proto', + package='rpc.proto', + syntax='proto3', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\rrpc/rpc.proto\x12\trpc.proto\"M\n\x11TextSearchRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x16\n\x0emin_similarity\x18\x02 \x01(\x02\x12\x11\n\tmax_count\x18\x03 \x01(\x05\"\'\n\nFoundVideo\x12\n\n\x02id\x18\x01 
\x01(\x04\x12\r\n\x05score\x18\x02 \x01(\x02\"\x9f\x01\n\x11TextSearchResults\x12%\n\x06videos\x18\x01 \x03(\x0b\x32\x15.rpc.proto.FoundVideo\x12\x16\n\x0eoriginal_query\x18\x02 \x01(\t\x12\x0e\n\x06tokens\x18\x03 \x03(\t\x12\x14\n\x0c\x63lean_tokens\x18\x04 \x03(\t\x12\x16\n\x0ehuman_readable\x18\x05 \x01(\t\x12\r\n\x05score\x18\x06 \x01(\x02\"\x0f\n\rStatusRequest\" \n\x0eStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\x08\"k\n\x17NearestNeighborsRequest\x12\x11\n\talgorithm\x18\x01 \x01(\t\x12\t\n\x01x\x18\x02 \x01(\x02\x12\t\n\x01y\x18\x03 \x01(\x02\x12\x14\n\x0cmax_distance\x18\x04 \x01(\x02\x12\x11\n\tmax_count\x18\x05 \x01(\x05\"F\n\x17NearestNeighborsResults\x12+\n\tneighbors\x18\x01 \x03(\x0b\x32\x18.rpc.proto.FoundNeighbor\"]\n\rFoundNeighbor\x12\x11\n\tfile_path\x18\x01 \x01(\t\x12\x11\n\tfile_hash\x18\x02 \x01(\t\x12\t\n\x01x\x18\x03 \x01(\x02\x12\t\n\x01y\x18\x04 \x01(\x02\x12\x10\n\x08\x64istance\x18\x05 \x01(\x02\",\n\x17\x45mbeddingsStatusRequest\x12\x11\n\talgorithm\x18\x01 \x01(\t2\xa3\x01\n\x0eSemanticSearch\x12L\n\x0cquery_videos\x12\x1c.rpc.proto.TextSearchRequest\x1a\x1c.rpc.proto.TextSearchResults\"\x00\x12\x43\n\nget_status\x12\x18.rpc.proto.StatusRequest\x1a\x19.rpc.proto.StatusResponse\"\x00\x32\xc0\x01\n\nEmbeddings\x12\x63\n\x17query_nearest_neighbors\x12\".rpc.proto.NearestNeighborsRequest\x1a\".rpc.proto.NearestNeighborsResults\"\x00\x12M\n\nget_status\x12\".rpc.proto.EmbeddingsStatusRequest\x1a\x19.rpc.proto.StatusResponse\"\x00\x62\x06proto3' ) + + _TEXTSEARCHREQUEST = _descriptor.Descriptor( - name="TextSearchRequest", - full_name="rpc.proto.TextSearchRequest", - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name="query", - full_name="rpc.proto.TextSearchRequest.query", - index=0, - number=1, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=b"".decode("utf-8"), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - _descriptor.FieldDescriptor( - name="min_similarity", - full_name="rpc.proto.TextSearchRequest.min_similarity", - index=1, - number=2, - type=2, - cpp_type=6, - label=1, - has_default_value=False, - default_value=float(0), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - _descriptor.FieldDescriptor( - name="max_count", - full_name="rpc.proto.TextSearchRequest.max_count", - index=2, - number=3, - type=5, - cpp_type=1, - label=1, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - serialized_options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=28, - serialized_end=105, + name='TextSearchRequest', + full_name='rpc.proto.TextSearchRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='query', full_name='rpc.proto.TextSearchRequest.query', index=0, + number=1, type=9, cpp_type=9, label=1, + 
has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='min_similarity', full_name='rpc.proto.TextSearchRequest.min_similarity', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='max_count', full_name='rpc.proto.TextSearchRequest.max_count', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=28, + serialized_end=105, ) _FOUNDVIDEO = _descriptor.Descriptor( - name="FoundVideo", - full_name="rpc.proto.FoundVideo", - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name="id", - full_name="rpc.proto.FoundVideo.id", - index=0, - number=1, - type=4, - cpp_type=4, - label=1, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - _descriptor.FieldDescriptor( - name="score", - full_name="rpc.proto.FoundVideo.score", - index=1, - number=2, - type=2, - cpp_type=6, - label=1, - has_default_value=False, - default_value=float(0), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - serialized_options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=107, - serialized_end=146, + name='FoundVideo', + full_name='rpc.proto.FoundVideo', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='id', full_name='rpc.proto.FoundVideo.id', index=0, + number=1, type=4, cpp_type=4, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='score', full_name='rpc.proto.FoundVideo.score', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + 
extension_ranges=[], + oneofs=[ + ], + serialized_start=107, + serialized_end=146, ) _TEXTSEARCHRESULTS = _descriptor.Descriptor( - name="TextSearchResults", - full_name="rpc.proto.TextSearchResults", - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name="videos", - full_name="rpc.proto.TextSearchResults.videos", - index=0, - number=1, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - _descriptor.FieldDescriptor( - name="original_query", - full_name="rpc.proto.TextSearchResults.original_query", - index=1, - number=2, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=b"".decode("utf-8"), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - _descriptor.FieldDescriptor( - name="tokens", - full_name="rpc.proto.TextSearchResults.tokens", - index=2, - number=3, - type=9, - cpp_type=9, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - _descriptor.FieldDescriptor( - name="clean_tokens", - full_name="rpc.proto.TextSearchResults.clean_tokens", - index=3, - number=4, - type=9, - cpp_type=9, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - _descriptor.FieldDescriptor( - name="human_readable", - full_name="rpc.proto.TextSearchResults.human_readable", - index=4, - number=5, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=b"".decode("utf-8"), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - _descriptor.FieldDescriptor( - name="score", - full_name="rpc.proto.TextSearchResults.score", - index=5, - number=6, - type=2, - cpp_type=6, - label=1, - has_default_value=False, - default_value=float(0), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - serialized_options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=149, - serialized_end=308, + name='TextSearchResults', + full_name='rpc.proto.TextSearchResults', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='videos', full_name='rpc.proto.TextSearchResults.videos', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='original_query', full_name='rpc.proto.TextSearchResults.original_query', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='tokens', full_name='rpc.proto.TextSearchResults.tokens', index=2, + number=3, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='clean_tokens', full_name='rpc.proto.TextSearchResults.clean_tokens', index=3, + number=4, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='human_readable', full_name='rpc.proto.TextSearchResults.human_readable', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='score', full_name='rpc.proto.TextSearchResults.score', index=5, + number=6, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=149, + serialized_end=308, ) _STATUSREQUEST = _descriptor.Descriptor( - name="StatusRequest", - full_name="rpc.proto.StatusRequest", - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[], - extensions=[], - nested_types=[], - enum_types=[], - serialized_options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=310, - serialized_end=325, + name='StatusRequest', + full_name='rpc.proto.StatusRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=310, + serialized_end=325, ) _STATUSRESPONSE = _descriptor.Descriptor( - name="StatusResponse", - full_name="rpc.proto.StatusResponse", - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name="status", - full_name="rpc.proto.StatusResponse.status", - index=0, - number=1, - type=8, - cpp_type=7, - label=1, - 
has_default_value=False, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - serialized_options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=327, - serialized_end=359, + name='StatusResponse', + full_name='rpc.proto.StatusResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='rpc.proto.StatusResponse.status', index=0, + number=1, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=327, + serialized_end=359, ) -_TEXTSEARCHRESULTS.fields_by_name["videos"].message_type = _FOUNDVIDEO -DESCRIPTOR.message_types_by_name["TextSearchRequest"] = _TEXTSEARCHREQUEST -DESCRIPTOR.message_types_by_name["FoundVideo"] = _FOUNDVIDEO -DESCRIPTOR.message_types_by_name["TextSearchResults"] = _TEXTSEARCHRESULTS -DESCRIPTOR.message_types_by_name["StatusRequest"] = _STATUSREQUEST -DESCRIPTOR.message_types_by_name["StatusResponse"] = _STATUSRESPONSE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) -TextSearchRequest = _reflection.GeneratedProtocolMessageType( - "TextSearchRequest", - (_message.Message,), - { - "DESCRIPTOR": _TEXTSEARCHREQUEST, - "__module__": "rpc.rpc_pb2" - # @@protoc_insertion_point(class_scope:rpc.proto.TextSearchRequest) - }, +_NEARESTNEIGHBORSREQUEST = _descriptor.Descriptor( + name='NearestNeighborsRequest', + full_name='rpc.proto.NearestNeighborsRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='algorithm', full_name='rpc.proto.NearestNeighborsRequest.algorithm', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='x', full_name='rpc.proto.NearestNeighborsRequest.x', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='y', full_name='rpc.proto.NearestNeighborsRequest.y', index=2, + number=3, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='max_distance', full_name='rpc.proto.NearestNeighborsRequest.max_distance', index=3, + number=4, type=2, 
cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='max_count', full_name='rpc.proto.NearestNeighborsRequest.max_count', index=4, + number=5, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=361, + serialized_end=468, ) -_sym_db.RegisterMessage(TextSearchRequest) -FoundVideo = _reflection.GeneratedProtocolMessageType( - "FoundVideo", - (_message.Message,), - { - "DESCRIPTOR": _FOUNDVIDEO, - "__module__": "rpc.rpc_pb2" - # @@protoc_insertion_point(class_scope:rpc.proto.FoundVideo) - }, + +_NEARESTNEIGHBORSRESULTS = _descriptor.Descriptor( + name='NearestNeighborsResults', + full_name='rpc.proto.NearestNeighborsResults', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='neighbors', full_name='rpc.proto.NearestNeighborsResults.neighbors', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=470, + serialized_end=540, ) -_sym_db.RegisterMessage(FoundVideo) -TextSearchResults = _reflection.GeneratedProtocolMessageType( - "TextSearchResults", - (_message.Message,), - { - "DESCRIPTOR": _TEXTSEARCHRESULTS, - "__module__": "rpc.rpc_pb2" - # @@protoc_insertion_point(class_scope:rpc.proto.TextSearchResults) - }, + +_FOUNDNEIGHBOR = _descriptor.Descriptor( + name='FoundNeighbor', + full_name='rpc.proto.FoundNeighbor', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='file_path', full_name='rpc.proto.FoundNeighbor.file_path', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='file_hash', full_name='rpc.proto.FoundNeighbor.file_hash', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='x', full_name='rpc.proto.FoundNeighbor.x', index=2, + number=3, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, 
enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='y', full_name='rpc.proto.FoundNeighbor.y', index=3, + number=4, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='distance', full_name='rpc.proto.FoundNeighbor.distance', index=4, + number=5, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=542, + serialized_end=635, ) -_sym_db.RegisterMessage(TextSearchResults) -StatusRequest = _reflection.GeneratedProtocolMessageType( - "StatusRequest", - (_message.Message,), - { - "DESCRIPTOR": _STATUSREQUEST, - "__module__": "rpc.rpc_pb2" - # @@protoc_insertion_point(class_scope:rpc.proto.StatusRequest) - }, + +_EMBEDDINGSSTATUSREQUEST = _descriptor.Descriptor( + name='EmbeddingsStatusRequest', + full_name='rpc.proto.EmbeddingsStatusRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='algorithm', full_name='rpc.proto.EmbeddingsStatusRequest.algorithm', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=637, + serialized_end=681, ) + +_TEXTSEARCHRESULTS.fields_by_name['videos'].message_type = _FOUNDVIDEO +_NEARESTNEIGHBORSRESULTS.fields_by_name['neighbors'].message_type = _FOUNDNEIGHBOR +DESCRIPTOR.message_types_by_name['TextSearchRequest'] = _TEXTSEARCHREQUEST +DESCRIPTOR.message_types_by_name['FoundVideo'] = _FOUNDVIDEO +DESCRIPTOR.message_types_by_name['TextSearchResults'] = _TEXTSEARCHRESULTS +DESCRIPTOR.message_types_by_name['StatusRequest'] = _STATUSREQUEST +DESCRIPTOR.message_types_by_name['StatusResponse'] = _STATUSRESPONSE +DESCRIPTOR.message_types_by_name['NearestNeighborsRequest'] = _NEARESTNEIGHBORSREQUEST +DESCRIPTOR.message_types_by_name['NearestNeighborsResults'] = _NEARESTNEIGHBORSRESULTS +DESCRIPTOR.message_types_by_name['FoundNeighbor'] = _FOUNDNEIGHBOR +DESCRIPTOR.message_types_by_name['EmbeddingsStatusRequest'] = _EMBEDDINGSSTATUSREQUEST +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +TextSearchRequest = _reflection.GeneratedProtocolMessageType('TextSearchRequest', (_message.Message,), { + 'DESCRIPTOR' : _TEXTSEARCHREQUEST, + '__module__' : 'rpc.rpc_pb2' + # @@protoc_insertion_point(class_scope:rpc.proto.TextSearchRequest) + }) +_sym_db.RegisterMessage(TextSearchRequest) + +FoundVideo = 
_reflection.GeneratedProtocolMessageType('FoundVideo', (_message.Message,), { + 'DESCRIPTOR' : _FOUNDVIDEO, + '__module__' : 'rpc.rpc_pb2' + # @@protoc_insertion_point(class_scope:rpc.proto.FoundVideo) + }) +_sym_db.RegisterMessage(FoundVideo) + +TextSearchResults = _reflection.GeneratedProtocolMessageType('TextSearchResults', (_message.Message,), { + 'DESCRIPTOR' : _TEXTSEARCHRESULTS, + '__module__' : 'rpc.rpc_pb2' + # @@protoc_insertion_point(class_scope:rpc.proto.TextSearchResults) + }) +_sym_db.RegisterMessage(TextSearchResults) + +StatusRequest = _reflection.GeneratedProtocolMessageType('StatusRequest', (_message.Message,), { + 'DESCRIPTOR' : _STATUSREQUEST, + '__module__' : 'rpc.rpc_pb2' + # @@protoc_insertion_point(class_scope:rpc.proto.StatusRequest) + }) _sym_db.RegisterMessage(StatusRequest) -StatusResponse = _reflection.GeneratedProtocolMessageType( - "StatusResponse", - (_message.Message,), - { - "DESCRIPTOR": _STATUSRESPONSE, - "__module__": "rpc.rpc_pb2" - # @@protoc_insertion_point(class_scope:rpc.proto.StatusResponse) - }, -) +StatusResponse = _reflection.GeneratedProtocolMessageType('StatusResponse', (_message.Message,), { + 'DESCRIPTOR' : _STATUSRESPONSE, + '__module__' : 'rpc.rpc_pb2' + # @@protoc_insertion_point(class_scope:rpc.proto.StatusResponse) + }) _sym_db.RegisterMessage(StatusResponse) +NearestNeighborsRequest = _reflection.GeneratedProtocolMessageType('NearestNeighborsRequest', (_message.Message,), { + 'DESCRIPTOR' : _NEARESTNEIGHBORSREQUEST, + '__module__' : 'rpc.rpc_pb2' + # @@protoc_insertion_point(class_scope:rpc.proto.NearestNeighborsRequest) + }) +_sym_db.RegisterMessage(NearestNeighborsRequest) + +NearestNeighborsResults = _reflection.GeneratedProtocolMessageType('NearestNeighborsResults', (_message.Message,), { + 'DESCRIPTOR' : _NEARESTNEIGHBORSRESULTS, + '__module__' : 'rpc.rpc_pb2' + # @@protoc_insertion_point(class_scope:rpc.proto.NearestNeighborsResults) + }) +_sym_db.RegisterMessage(NearestNeighborsResults) + +FoundNeighbor = _reflection.GeneratedProtocolMessageType('FoundNeighbor', (_message.Message,), { + 'DESCRIPTOR' : _FOUNDNEIGHBOR, + '__module__' : 'rpc.rpc_pb2' + # @@protoc_insertion_point(class_scope:rpc.proto.FoundNeighbor) + }) +_sym_db.RegisterMessage(FoundNeighbor) + +EmbeddingsStatusRequest = _reflection.GeneratedProtocolMessageType('EmbeddingsStatusRequest', (_message.Message,), { + 'DESCRIPTOR' : _EMBEDDINGSSTATUSREQUEST, + '__module__' : 'rpc.rpc_pb2' + # @@protoc_insertion_point(class_scope:rpc.proto.EmbeddingsStatusRequest) + }) +_sym_db.RegisterMessage(EmbeddingsStatusRequest) + + _SEMANTICSEARCH = _descriptor.ServiceDescriptor( - name="SemanticSearch", - full_name="rpc.proto.SemanticSearch", - file=DESCRIPTOR, + name='SemanticSearch', + full_name='rpc.proto.SemanticSearch', + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=684, + serialized_end=847, + methods=[ + _descriptor.MethodDescriptor( + name='query_videos', + full_name='rpc.proto.SemanticSearch.query_videos', index=0, + containing_service=None, + input_type=_TEXTSEARCHREQUEST, + output_type=_TEXTSEARCHRESULTS, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=362, - serialized_end=525, - methods=[ - _descriptor.MethodDescriptor( - name="query_videos", - full_name="rpc.proto.SemanticSearch.query_videos", - index=0, - containing_service=None, - input_type=_TEXTSEARCHREQUEST, - output_type=_TEXTSEARCHRESULTS, - serialized_options=None, - 
create_key=_descriptor._internal_create_key, - ), - _descriptor.MethodDescriptor( - name="get_status", - full_name="rpc.proto.SemanticSearch.get_status", - index=1, - containing_service=None, - input_type=_STATUSREQUEST, - output_type=_STATUSRESPONSE, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), - ], -) + ), + _descriptor.MethodDescriptor( + name='get_status', + full_name='rpc.proto.SemanticSearch.get_status', + index=1, + containing_service=None, + input_type=_STATUSREQUEST, + output_type=_STATUSRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), +]) _sym_db.RegisterServiceDescriptor(_SEMANTICSEARCH) -DESCRIPTOR.services_by_name["SemanticSearch"] = _SEMANTICSEARCH +DESCRIPTOR.services_by_name['SemanticSearch'] = _SEMANTICSEARCH + + +_EMBEDDINGS = _descriptor.ServiceDescriptor( + name='Embeddings', + full_name='rpc.proto.Embeddings', + file=DESCRIPTOR, + index=1, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=850, + serialized_end=1042, + methods=[ + _descriptor.MethodDescriptor( + name='query_nearest_neighbors', + full_name='rpc.proto.Embeddings.query_nearest_neighbors', + index=0, + containing_service=None, + input_type=_NEARESTNEIGHBORSREQUEST, + output_type=_NEARESTNEIGHBORSRESULTS, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.MethodDescriptor( + name='get_status', + full_name='rpc.proto.Embeddings.get_status', + index=1, + containing_service=None, + input_type=_EMBEDDINGSSTATUSREQUEST, + output_type=_STATUSRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), +]) +_sym_db.RegisterServiceDescriptor(_EMBEDDINGS) + +DESCRIPTOR.services_by_name['Embeddings'] = _EMBEDDINGS # @@protoc_insertion_point(module_scope) diff --git a/rpc/rpc_pb2_grpc.py b/rpc/rpc_pb2_grpc.py index 07aff59c..9eb118ff 100644 --- a/rpc/rpc_pb2_grpc.py +++ b/rpc/rpc_pb2_grpc.py @@ -6,7 +6,8 @@ class SemanticSearchStub(object): - """Semantic search service""" + """Semantic search service + """ def __init__(self, channel): """Constructor. @@ -15,108 +16,186 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.query_videos = channel.unary_unary( - "/rpc.proto.SemanticSearch/query_videos", - request_serializer=rpc_dot_rpc__pb2.TextSearchRequest.SerializeToString, - response_deserializer=rpc_dot_rpc__pb2.TextSearchResults.FromString, - ) + '/rpc.proto.SemanticSearch/query_videos', + request_serializer=rpc_dot_rpc__pb2.TextSearchRequest.SerializeToString, + response_deserializer=rpc_dot_rpc__pb2.TextSearchResults.FromString, + ) self.get_status = channel.unary_unary( - "/rpc.proto.SemanticSearch/get_status", - request_serializer=rpc_dot_rpc__pb2.StatusRequest.SerializeToString, - response_deserializer=rpc_dot_rpc__pb2.StatusResponse.FromString, - ) + '/rpc.proto.SemanticSearch/get_status', + request_serializer=rpc_dot_rpc__pb2.StatusRequest.SerializeToString, + response_deserializer=rpc_dot_rpc__pb2.StatusResponse.FromString, + ) class SemanticSearchServicer(object): - """Semantic search service""" + """Semantic search service + """ def query_videos(self, request, context): - """Perform semantic search by text description""" + """Perform semantic search by text description + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def get_status(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_SemanticSearchServicer_to_server(servicer, server): rpc_method_handlers = { - "query_videos": grpc.unary_unary_rpc_method_handler( - servicer.query_videos, - request_deserializer=rpc_dot_rpc__pb2.TextSearchRequest.FromString, - response_serializer=rpc_dot_rpc__pb2.TextSearchResults.SerializeToString, - ), - "get_status": grpc.unary_unary_rpc_method_handler( - servicer.get_status, - request_deserializer=rpc_dot_rpc__pb2.StatusRequest.FromString, - response_serializer=rpc_dot_rpc__pb2.StatusResponse.SerializeToString, - ), + 'query_videos': grpc.unary_unary_rpc_method_handler( + servicer.query_videos, + request_deserializer=rpc_dot_rpc__pb2.TextSearchRequest.FromString, + response_serializer=rpc_dot_rpc__pb2.TextSearchResults.SerializeToString, + ), + 'get_status': grpc.unary_unary_rpc_method_handler( + servicer.get_status, + request_deserializer=rpc_dot_rpc__pb2.StatusRequest.FromString, + response_serializer=rpc_dot_rpc__pb2.StatusResponse.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler("rpc.proto.SemanticSearch", rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler( + 'rpc.proto.SemanticSearch', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. 
class SemanticSearch(object): - """Semantic search service""" + """Semantic search service + """ @staticmethod - def query_videos( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def query_videos(request, target, - "/rpc.proto.SemanticSearch/query_videos", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/rpc.proto.SemanticSearch/query_videos', rpc_dot_rpc__pb2.TextSearchRequest.SerializeToString, rpc_dot_rpc__pb2.TextSearchResults.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def get_status( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def get_status(request, target, - "/rpc.proto.SemanticSearch/get_status", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/rpc.proto.SemanticSearch/get_status', rpc_dot_rpc__pb2.StatusRequest.SerializeToString, rpc_dot_rpc__pb2.StatusResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + +class EmbeddingsStub(object): + """Embeddings service + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.query_nearest_neighbors = channel.unary_unary( + '/rpc.proto.Embeddings/query_nearest_neighbors', + request_serializer=rpc_dot_rpc__pb2.NearestNeighborsRequest.SerializeToString, + response_deserializer=rpc_dot_rpc__pb2.NearestNeighborsResults.FromString, + ) + self.get_status = channel.unary_unary( + '/rpc.proto.Embeddings/get_status', + request_serializer=rpc_dot_rpc__pb2.EmbeddingsStatusRequest.SerializeToString, + response_deserializer=rpc_dot_rpc__pb2.StatusResponse.FromString, + ) + + +class EmbeddingsServicer(object): + """Embeddings service + """ + + def query_nearest_neighbors(self, request, context): + """Get nearest neighbors + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_status(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_EmbeddingsServicer_to_server(servicer, server): + rpc_method_handlers = { + 'query_nearest_neighbors': grpc.unary_unary_rpc_method_handler( + servicer.query_nearest_neighbors, + request_deserializer=rpc_dot_rpc__pb2.NearestNeighborsRequest.FromString, + response_serializer=rpc_dot_rpc__pb2.NearestNeighborsResults.SerializeToString, + ), + 'get_status': grpc.unary_unary_rpc_method_handler( + servicer.get_status, + request_deserializer=rpc_dot_rpc__pb2.EmbeddingsStatusRequest.FromString, + response_serializer=rpc_dot_rpc__pb2.StatusResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'rpc.proto.Embeddings', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class Embeddings(object):
+    """Embeddings service
+    """
+
+    @staticmethod
+    def query_nearest_neighbors(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/rpc.proto.Embeddings/query_nearest_neighbors',
+            rpc_dot_rpc__pb2.NearestNeighborsRequest.SerializeToString,
+            rpc_dot_rpc__pb2.NearestNeighborsResults.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def get_status(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/rpc.proto.Embeddings/get_status',
+            rpc_dot_rpc__pb2.EmbeddingsStatusRequest.SerializeToString,
+            rpc_dot_rpc__pb2.StatusResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/rpc/server.py b/rpc/server.py
index 4dd6197b..b369b62b 100644
--- a/rpc/server.py
+++ b/rpc/server.py
@@ -7,6 +7,7 @@
 
 import rpc.rpc_pb2 as proto
 import rpc.rpc_pb2_grpc as services
+from rpc.embeddings import EmbeddingLoader
 from rpc.errors import unavailable
 from rpc.logging import configure_logging
 from winnow.config.path import resolve_config_path
@@ -72,6 +73,27 @@ def _get_search_engine(self, context: grpc.ServicerContext) -> VideoSearch:
             raise unavailable(context, str(error))
 
 
+class EmbeddingsService(services.EmbeddingsServicer):
+    def __init__(self, loader: EmbeddingLoader):
+        self._loader = loader
+
+    def query_nearest_neighbors(
+        self,
+        request: proto.NearestNeighborsRequest,
+        context: grpc.ServicerContext,
+    ) -> proto.NearestNeighborsResults:
+        index = self._loader.load(request.algorithm)
+        if index is None:
+            return proto.NearestNeighborsResults(neighbors=[])
+        found = index.query(x=request.x, y=request.y, max_distance=request.max_distance, max_count=request.max_count)
+        return proto.NearestNeighborsResults(neighbors=found)
+
+    def get_status(self, request: proto.EmbeddingsStatusRequest, context: grpc.ServicerContext) -> proto.StatusResponse:
+        index = self._loader.load(request.algorithm)
+        available = index is not None
+        return proto.StatusResponse(status=available)
+
+
 def initialize_search_engine(pipeline: PipelineContext):
     """Try to eagerly initialize semantic search engine."""
     logger.info("Trying to initialize semantic search engine.")
@@ -84,13 +106,20 @@ def initialize_search_engine(pipeline: PipelineContext):
 
 
 def serve(host: str, port: int, pipeline: PipelineContext, eager: bool = False):
+    embeddings_loader = EmbeddingLoader(pipeline)
+    embeddings_service = EmbeddingsService(loader=embeddings_loader)
     semantic_search = SemanticSearch(pipeline)
     if eager:
         initialize_search_engine(pipeline)
+        embeddings_loader.load("pacmap")
+        embeddings_loader.load("t-sne")
+        embeddings_loader.load("trimap")
+        embeddings_loader.load("umap")
 
     listen_address = f"{host}:{port}"
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
     services.add_SemanticSearchServicer_to_server(semantic_search, server)
+    services.add_EmbeddingsServicer_to_server(embeddings_service, server)
     server.add_insecure_port(listen_address)
 
     logger.info("JusticeAI RPC server is initialized.")
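Usage note: clients reach the new service through the generated stub. A minimal sketch; the address is illustrative (port 50051 matches the compose configuration below), and `max_distance=0.0` is treated by the server as 'no distance limit':

    import grpc

    import rpc.rpc_pb2 as proto
    import rpc.rpc_pb2_grpc as services

    with grpc.insecure_channel("localhost:50051") as channel:
        stub = services.EmbeddingsStub(channel)
        # Check whether the PaCMAP index is available before querying it.
        status = stub.get_status(proto.EmbeddingsStatusRequest(algorithm="pacmap"))
        if status.status:
            results = stub.query_nearest_neighbors(
                proto.NearestNeighborsRequest(algorithm="pacmap", x=0.0, y=0.0, max_count=5, max_distance=0.0)
            )
            for neighbor in results.neighbors:
                print(neighbor.file_path, neighbor.distance)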
From 086cd3abced56ed92b47c75f6752521cedcc2fba Mon Sep 17 00:00:00 2001
From: Stepan Anokhin
Date: Wed, 20 Apr 2022 21:24:36 +0700
Subject: [PATCH 06/11] Add REST API for embeddings

---
 docker-compose.yml                          |   1 +
 server/README.md                            |   1 +
 server/server/api/__init__.py               |   1 +
 server/server/api/embeddings.py             | 224 ++++++++++++++++++++
 server/server/api/helpers.py                |   8 +
 server/server/config.py                     |   1 +
 server/server/main.py                       |   5 +
 server/server/queue/celery/__init__.py      |   3 +
 server/server/queue/celery/task_metadata.py |   2 +-
 server/server/queue/instance.py             |   2 +
 server/server/queue/model.py                |  11 +-
 11 files changed, 257 insertions(+), 2 deletions(-)
 create mode 100644 server/server/api/embeddings.py

diff --git a/docker-compose.yml b/docker-compose.yml
index 431b5275..b664be54 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -153,6 +153,7 @@ services:
       SECURITY_STORAGE_PATH: "/project/data/representations"
       RPC_SERVER_HOST: "rpc"
       RPC_SERVER_PORT: 50051
+      EMBEDDINGS_FOLDER: "/project/data/representations/embeddings"
     volumes:
       # Set the BENETECH_DATA_LOCATION environment variable to the path
       # on your host machine where you placed your video files
diff --git a/server/README.md b/server/README.md
index 38034e1e..c0c4bf9d 100644
--- a/server/README.md
+++ b/server/README.md
@@ -56,6 +56,7 @@ Server honors the following environment variables:
  * `REDIS_CACHE_HOST` - redis cache host (default is `redis`)
  * `REDIS_CACHE_PORT` - redis cache port (default `6379`)
  * `REDIS_CACHE_DB` - redis cache db (default is `0`)
+ * `EMBEDDINGS_FOLDER` - folder with embeddings tiles (default is `./embeddings`)
 
 Server accepts the following command-line arguments:
 
diff --git a/server/server/api/__init__.py b/server/server/api/__init__.py
index 72c7828c..15dcb797 100644
--- a/server/server/api/__init__.py
+++ b/server/server/api/__init__.py
@@ -20,6 +20,7 @@
     online,
     processing,
     health,
+    embeddings,
 )
 
 from .blueprint import api
diff --git a/server/server/api/embeddings.py b/server/server/api/embeddings.py
new file mode 100644
index 00000000..532659fb
--- /dev/null
+++ b/server/server/api/embeddings.py
@@ -0,0 +1,224 @@
+import json
+import os
+import re
+from datetime import datetime
+from functools import lru_cache
+from http import HTTPStatus
+from json import JSONDecodeError
+from os.path import dirname, basename
+from typing import Optional, Tuple, Dict, List
+
+from dataclasses import dataclass, asdict
+from flask import abort, send_from_directory, jsonify, request
+
+import rpc.rpc_pb2 as proto
+from db.access.files import FilesDAO
+from db.schema import Files
+from server import time_utils
+from .blueprint import api
+from .helpers import (
+    get_config,
+    parse_positive_int,
+    embeddings_rpc,
+)
+from ..model import database, Transform
+
+
+@dataclass
+class TilesBBox:
+    """Bounding box of embedding space covered by tiles collection."""
+
+    x: Tuple[float, float]
+    y: Tuple[float, float]
+
+
+@dataclass
+class TilesInfo:
+    """Tiles collection descriptor."""
+
+    algorithm: str
+    available: bool
+    max_zoom: Optional[int] = None
+    last_update: Optional[datetime] = None
+    bbox: Optional[TilesBBox] = None
+    point_size: Optional[float] = None
+
+    def json_data(self) -> Dict:
+        """Convert to json data."""
+        data = asdict(self)
+        if self.last_update is not None:
+            data["last_update"] = time_utils.dumps(self.last_update)
+        return data
+
+
+class TilesStorage:
+    """Class to provide access to tile collections for different embeddings algorithms."""
+
+    DATE_FORMAT = "%Y_%m_%d_%H%M%S%f"
+    COLL_FORMAT = r"^tiles_(?P<max_zoom>\d+)zoom__(?P