Skip to content

Commit

Permalink
fix: fix bm42 usage, add query_embed to SparseTextEmbedding, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joein authored and generall committed May 22, 2024
1 parent 6e5eafc commit ad80e59
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
14 changes: 14 additions & 0 deletions fastembed/sparse/sparse_text_embedding.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import List, Type, Dict, Any, Union, Iterable, Optional, Sequence

from fastembed.common import OnnxProvider
from fastembed.sparse.bm42 import Bm42
from fastembed.sparse.sparse_embedding_base import SparseTextEmbeddingBase, SparseEmbedding
from fastembed.sparse.splade_pp import SpladePP


class SparseTextEmbedding(SparseTextEmbeddingBase):
EMBEDDINGS_REGISTRY: List[Type[SparseTextEmbeddingBase]] = [
SpladePP,
Bm42,
]

@classmethod
Expand Down Expand Up @@ -84,3 +86,15 @@ def embed(
List of embeddings, one per document
"""
yield from self.model.embed(documents, batch_size, parallel, **kwargs)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[SparseEmbedding]:
    """
    Produce sparse embeddings for one or more queries.

    The call is forwarded unchanged to the wrapped model's ``query_embed``.
    This method is a generator: nothing is forwarded until the caller starts
    iterating the returned object.

    Args:
        query (Union[str, Iterable[str]]): A single query string, or an iterable
            (e.g. a list) of query strings.
        **kwargs: Extra keyword arguments passed through as-is to the wrapped model.

    Yields:
        SparseEmbedding: Each item produced by the wrapped model, in the order
            the model yields it.
    """
    # Delegate to the concrete model selected for this instance.
    model_embeddings = self.model.query_embed(query, **kwargs)
    yield from model_embeddings
27 changes: 19 additions & 8 deletions tests/test_attention_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from fastembed.sparse.bm42 import Bm42
import numpy as np

from fastembed import SparseTextEmbedding


def test_attention_embeddings():
model = Bm42(model_name="Qdrant/bm42-all-minilm-l6-v2-attentions")
model = SparseTextEmbedding(model_name="Qdrant/bm42-all-minilm-l6-v2-attentions")

output = list(model.query_embed([
"I must not fear. Fear is the mind-killer.",
]))
output = list(
model.query_embed(
[
"I must not fear. Fear is the mind-killer.",
]
)
)

assert len(output) == 1

for result in output:
assert len(result.indices) == len(result.values)
assert np.allclose(result.values, np.ones(len(result.values)))

quotes = [
"I must not fear. Fear is the mind-killer.",
Expand All @@ -36,9 +43,13 @@ def test_attention_embeddings():
assert len(result.indices) > 0

# Test support for unknown languages
output = list(model.query_embed([
"привет мир!",
]))
output = list(
model.query_embed(
[
"привет мир!",
]
)
)

assert len(output) == 1

Expand Down
18 changes: 7 additions & 11 deletions tests/test_sparse_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest

from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding

CANONICAL_COLUMN_VALUES = {
Expand Down Expand Up @@ -47,30 +48,25 @@ def test_batch_embedding():
docs_to_embed = docs * 10

for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
print("evaluating", model_name)
model = SparseTextEmbedding(model_name=model_name)
result = next(iter(model.embed(docs_to_embed, batch_size=6)))
print(result.indices)

assert result.indices.tolist() == expected_result["indices"]

for i, value in enumerate(result.values):
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]


def test_single_embedding():
docs_to_embed = docs

for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
print("evaluating", model_name)
model = SparseTextEmbedding(model_name=model_name)
result = next(iter(model.embed(docs_to_embed, batch_size=6)))
print(result.indices)

assert result.indices.tolist() == expected_result["indices"]
passage_result = next(iter(model.embed(docs, batch_size=6)))
query_result = next(iter(model.query_embed(docs)))
for result in [passage_result, query_result]:
assert result.indices.tolist() == expected_result["indices"]

for i, value in enumerate(result.values):
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
for i, value in enumerate(result.values):
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]


def test_parallel_processing():
Expand Down

0 comments on commit ad80e59

Please sign in to comment.