Attention sparse embeddings (#235)
* WIP: sparse embeddings using attention

* support for stopwords

* apply stopwords

* proceed with implementation of sparse attention embeddings (#234)

* complete inference

* query embed + comment

* use a simpler weights formula instead of sorting words

* update tests

* fix: fix bm42 usage, add query_embed to SparseTextEmbedding, update tests

---------

Co-authored-by: George <[email protected]>
generall and joein authored May 24, 2024
1 parent 316c336 commit dfd25d4
Showing 15 changed files with 486 additions and 66 deletions.
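The commit message above describes building sparse embeddings from transformer attention instead of term statistics, with stopwords and special tokens filtered out and a simple additive weight per token. The sketch below only illustrates that idea and is not code from this commit; the stopword list, tokens, and attention values are invented for the example.

```python
import numpy as np

# Illustrative only: stopwords, tokens, and attention values are made up.
STOPWORDS = {"the", "of", "a", "using"}
SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]"}


def attention_sparse_embedding(tokens, cls_attention):
    """Aggregate [CLS] attention into {token: weight} pairs, skipping noise tokens."""
    weights = {}
    for token, weight in zip(tokens, cls_attention):
        if token in SPECIAL_TOKENS or token in STOPWORDS:
            continue  # stopwords and special tokens carry no useful signal
        # Add weights of repeated tokens instead of ranking/sorting them,
        # in the spirit of the "simpler weights formula" bullet above.
        weights[token] = weights.get(token, 0.0) + float(weight)
    return weights


tokens = ["[CLS]", "sparse", "embeddings", "using", "attention", "[SEP]"]
cls_attention = np.array([0.30, 0.25, 0.20, 0.05, 0.15, 0.05])
print(attention_sparse_embedding(tokens, cls_attention))
# {'sparse': 0.25, 'embeddings': 0.2, 'attention': 0.15}
```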
13 changes: 10 additions & 3 deletions fastembed/common/onnx_model.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Sequence
import warnings
@@ -13,13 +14,19 @@
T = TypeVar("T")


+@dataclass
+class OnnxOutputContext:
+    model_output: np.ndarray
+    attention_mask: Optional[np.ndarray] = None
+    input_ids: Optional[np.ndarray] = None


class OnnxModel(Generic[T]):
@classmethod
def _get_worker_class(cls) -> Type["EmbeddingWorker"]:
raise NotImplementedError("Subclasses must implement this method")

-    @classmethod
-    def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> Iterable[T]:
+    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
raise NotImplementedError("Subclasses must implement this method")

def __init__(self) -> None:
@@ -74,7 +81,7 @@ def load_onnx_model(
RuntimeWarning,
)

-    def onnx_embed(self, *args, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+    def onnx_embed(self, *args, **kwargs) -> OnnxOutputContext:
raise NotImplementedError("Subclasses must implement this method")


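The OnnxOutputContext dataclass introduced above bundles the raw model output with the attention mask and input ids, so sparse post-processing can reach token-level information while dense models keep reading only model_output. A minimal usage sketch, assuming fastembed is importable and using random arrays in place of a real ONNX run:

```python
import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext

# Fake arrays stand in for a real ONNX run: batch of 1 text, 6 tokens, 384-dim states.
context = OnnxOutputContext(
    model_output=np.random.rand(1, 6, 384).astype(np.float32),
    attention_mask=np.ones((1, 6), dtype=np.int64),
    input_ids=np.array([[101, 7108, 7861, 2478, 3086, 102]]),
)

# Dense post-processing only needs model_output; sparse post-processing can also
# consult the mask and input ids, e.g. to zero out padded positions.
masked = context.model_output * context.attention_mask[..., None]
print(masked.shape)  # (1, 6, 384)
```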
32 changes: 24 additions & 8 deletions fastembed/common/preprocessor_utils.py
@@ -1,12 +1,24 @@
import json
from pathlib import Path
+from typing import Tuple

from tokenizers import Tokenizer, AddedToken

from fastembed.image.transform.operators import Compose


-def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
+def load_special_tokens(model_dir: Path) -> dict:
+    tokens_map_path = model_dir / "special_tokens_map.json"
+    if not tokens_map_path.exists():
+        raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")
+
+    with open(str(tokens_map_path)) as tokens_map_file:
+        tokens_map = json.load(tokens_map_file)
+
+    return tokens_map
+
+
+def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tuple[Tokenizer, dict]:
config_path = model_dir / "config.json"
if not config_path.exists():
raise ValueError(f"Could not find config.json in {model_dir}")
@@ -19,18 +31,13 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
if not tokenizer_config_path.exists():
raise ValueError(f"Could not find tokenizer_config.json in {model_dir}")

-    tokens_map_path = model_dir / "special_tokens_map.json"
-    if not tokens_map_path.exists():
-        raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")

with open(str(config_path)) as config_file:
config = json.load(config_file)

with open(str(tokenizer_config_path)) as tokenizer_config_file:
tokenizer_config = json.load(tokenizer_config_file)

-    with open(str(tokens_map_path)) as tokens_map_file:
-        tokens_map = json.load(tokens_map_file)
+    tokens_map = load_special_tokens(model_dir)

tokenizer = Tokenizer.from_file(str(tokenizer_path))
tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
@@ -44,7 +51,16 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
elif isinstance(token, dict):
tokenizer.add_special_tokens([AddedToken(**token)])

-    return tokenizer
+    special_token_to_id = {}
+
+    for token in tokens_map.values():
+        if isinstance(token, str):
+            special_token_to_id[token] = tokenizer.token_to_id(token)
+        elif isinstance(token, dict):
+            token_str = token.get("content", "")
+            special_token_to_id[token_str] = tokenizer.token_to_id(token_str)
+
+    return tokenizer, special_token_to_id


def load_preprocessor(model_dir: Path) -> Compose:
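With this change, load_tokenizer returns the tokenizer together with the special-token-to-id mapping produced by load_special_tokens. A hypothetical call site might look like the following; the model directory is a placeholder and the snippet assumes it contains the usual config and tokenizer files:

```python
from pathlib import Path

from fastembed.common.preprocessor_utils import load_tokenizer

model_dir = Path("path/to/downloaded/model")  # placeholder directory

# load_tokenizer now returns the tokenizer plus {special token: id}.
tokenizer, special_token_to_id = load_tokenizer(model_dir, max_length=512)

# The id set makes it easy to drop [CLS]/[SEP]-style tokens downstream,
# e.g. when turning attention weights into sparse embeddings.
special_ids = set(special_token_to_id.values())
encoding = tokenizer.encode("sparse embeddings using attention")
content_ids = [token_id for token_id in encoding.ids if token_id not in special_ids]
print(content_ids)
```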
6 changes: 3 additions & 3 deletions fastembed/image/onnx_embedding.py
@@ -2,6 +2,7 @@

import numpy as np

+from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import normalize, define_cache_dir
from fastembed.common import ImageInput, OnnxProvider
from fastembed.image.image_embedding_base import ImageEmbeddingBase
@@ -108,9 +109,8 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str,

return onnx_input

-    @classmethod
-    def _post_process_onnx_output(cls, output: np.ndarray) -> Iterable[np.ndarray]:
-        return normalize(output).astype(np.float32)
+    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
+        return normalize(output.model_output).astype(np.float32)


class OnnxImageEmbeddingWorker(ImageEmbeddingWorker):
11 changes: 6 additions & 5 deletions fastembed/image/onnx_image_model.py
@@ -8,7 +8,7 @@
import numpy as np

from fastembed.common.preprocessor_utils import load_preprocessor
-from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker, T
+from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker, T, OnnxOutputContext
from fastembed.common import PathInput, ImageInput, OnnxProvider
from fastembed.common.utils import iter_batch
from fastembed.parallel_processor import ParallelWorkerPool
@@ -21,8 +21,7 @@ class OnnxImageModel(OnnxModel[T]):
def _get_worker_class(cls) -> Type["ImageEmbeddingWorker"]:
raise NotImplementedError("Subclasses must implement this method")

-    @classmethod
-    def _post_process_onnx_output(cls, output: np.ndarray) -> Iterable[T]:
+    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
raise NotImplementedError("Subclasses must implement this method")

def __init__(self) -> None:
@@ -47,7 +46,7 @@ def load_onnx_model(
)
self.processor = load_preprocessor(model_dir=model_dir)

-    def onnx_embed(self, images: List[PathInput]) -> np.ndarray:
+    def onnx_embed(self, images: List[PathInput]) -> OnnxOutputContext:
with contextlib.ExitStack():
image_files = [Image.open(image) for image in images]
encoded = self.processor(image_files)
@@ -56,7 +55,9 @@ def onnx_embed(self, images: List[PathInput]) -> np.ndarray:

model_output = self.model.run(None, onnx_input)
embeddings = model_output[0]
-        return embeddings
+        return OnnxOutputContext(
+            model_output=embeddings
+        )

def _embed_images(
self,
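Combined with the onnx_embedding.py change above, the image pipeline now wraps the raw ONNX output in OnnxOutputContext inside onnx_embed and unwraps it during post-processing. A simplified sketch of that flow, with a random array standing in for the real model output:

```python
import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import normalize

# A random (batch, dim) array stands in for self.model.run(None, onnx_input)[0].
embeddings = np.random.rand(2, 512).astype(np.float32)

# onnx_embed now returns the wrapped output ...
context = OnnxOutputContext(model_output=embeddings)

# ... and _post_process_onnx_output unwraps and normalizes it.
result = normalize(context.model_output).astype(np.float32)
print(result.shape, np.linalg.norm(result, axis=1))  # rows are unit length
```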