improve: Changed the way we are adding the query and document markers in colbert (#391)

* improve: Changed the way we are adding the query and document markers in colbert

* fix: Truncate the input_ids and attention_mask when adding query and document markers to the original input length

* fix: Fix broadcast issue

* chore: Remove redundant if condition

* nit

* refactor (#397)

---------

Co-authored-by: George <[email protected]>
hh-space-invader and joein authored Nov 13, 2024
1 parent 8413066 commit 7c93571
Showing 2 changed files with 11 additions and 17 deletions.
16 changes: 5 additions & 11 deletions fastembed/late_interaction/colbert.py
@@ -12,6 +12,7 @@
 )
 from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
 
+
 supported_colbert_models = [
     {
         "model": "colbert-ir/colbertv2.0",
@@ -41,7 +42,7 @@
 class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]):
     QUERY_MARKER_TOKEN_ID = 1
     DOCUMENT_MARKER_TOKEN_ID = 2
-    MIN_QUERY_LENGTH = 32
+    MIN_QUERY_LENGTH = 31  # it's 32, we add one additional special token in the beginning
     MASK_TOKEN = "[MASK]"
 
     def _post_process_onnx_output(
@@ -69,10 +70,9 @@ def _post_process_onnx_output(
     def _preprocess_onnx_input(
         self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
     ) -> Dict[str, np.ndarray]:
-        if is_doc:
-            onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID
-        else:
-            onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID
+        marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID
+        onnx_input["input_ids"] = np.insert(onnx_input["input_ids"], 1, marker_token, axis=1)
+        onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1)
         return onnx_input
 
     def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
@@ -83,9 +83,6 @@ def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
         )
 
     def _tokenize_query(self, query: str) -> List[Encoding]:
-        # "@ " is added to a query to be replaced with a special query token
-        # make sure that "@ " is considered as a single token
-        query = f"@ {query}"
         encoded = self.tokenizer.encode_batch([query])
         # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
         if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
@@ -105,9 +102,6 @@ def _tokenize_query(self, query: str) -> List[Encoding]:
         return encoded
 
     def _tokenize_documents(self, documents: List[str]) -> List[Encoding]:
-        # "@ " is added to a document to be replaced with a special document token
-        # make sure that "@ " is considered as a single token
-        documents = ["@ " + doc for doc in documents]
         encoded = self.tokenizer.encode_batch(documents)
         return encoded
 
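Note on the colbert.py change above: instead of tokenizing a placeholder "@ " and overwriting the token at position 1, the marker token is now inserted as a new column during ONNX preprocessing, with a matching 1 added to the attention mask. Below is a minimal illustrative sketch of that insertion; the token IDs and the short document are invented for the example, not real tokenizer output.

import numpy as np

# Hypothetical tokenizer output for one document (IDs are invented):
# position 0 is the [CLS]-style token, positions 1+ are the text tokens.
input_ids = np.array([[101, 7592, 2088, 102]])
attention_mask = np.array([[1, 1, 1, 1]])

DOCUMENT_MARKER_TOKEN_ID = 2  # value from colbert.py

# np.insert adds a new column at index 1 in every row,
# shifting the original tokens right instead of overwriting one of them.
input_ids = np.insert(input_ids, 1, DOCUMENT_MARKER_TOKEN_ID, axis=1)
attention_mask = np.insert(attention_mask, 1, 1, axis=1)

print(input_ids)       # [[ 101    2 7592 2088  102]]
print(attention_mask)  # [[1 1 1 1 1]]

This also explains the MIN_QUERY_LENGTH change from 32 to 31: per the inline comment in the diff, queries are padded with [MASK] tokens only up to 31 during tokenization, and the marker inserted during preprocessing brings them back to the usual length of 32.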
12 changes: 6 additions & 6 deletions fastembed/late_interaction/jina_colbert.py
@@ -5,6 +5,7 @@
 from fastembed.late_interaction.colbert import Colbert
 from fastembed.text.onnx_text_model import TextEmbeddingWorker
 
+
 supported_jina_colbert_models = [
     {
         "model": "jinaai/jina-colbert-v2",
Expand All @@ -24,7 +25,7 @@
class JinaColbert(Colbert):
QUERY_MARKER_TOKEN_ID = 250002
DOCUMENT_MARKER_TOKEN_ID = 250003
MIN_QUERY_LENGTH = 32
MIN_QUERY_LENGTH = 31 # it's 32, we add one additional special token in the beginning
MASK_TOKEN = "<mask>"

@classmethod
@@ -43,11 +44,10 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
     def _preprocess_onnx_input(
         self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
     ) -> Dict[str, np.ndarray]:
-        if is_doc:
-            onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID
-        else:
-            onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID
-            # the attention mask for jina-colbert-v2 is always 1 in queries
+        onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc)
+
+        # the attention mask for jina-colbert-v2 is always 1 in queries
+        if not is_doc:
             onnx_input["attention_mask"][:] = 1
         return onnx_input

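A similar illustrative sketch for the jina_colbert.py override: marker insertion is now delegated to the parent class, and for queries the whole attention mask is then forced to 1 so the <mask> padding tokens used for query augmentation are attended to. Token IDs and the query contents below are invented for the example.

import numpy as np

QUERY_MARKER_TOKEN_ID = 250002  # value from jina_colbert.py

# Hypothetical padded query: a couple of real tokens followed by <mask>
# padding that the tokenizer left masked out (IDs are invented).
input_ids = np.array([[0, 4865, 250001, 250001, 2]])
attention_mask = np.array([[1, 1, 0, 0, 1]])

# Step inherited from Colbert._preprocess_onnx_input: insert the query
# marker at position 1 together with a matching 1 in the attention mask.
input_ids = np.insert(input_ids, 1, QUERY_MARKER_TOKEN_ID, axis=1)
attention_mask = np.insert(attention_mask, 1, 1, axis=1)

# JinaColbert-specific step for queries: attend to every position,
# including the <mask> padding tokens.
attention_mask[:] = 1

print(attention_mask)  # [[1 1 1 1 1 1]]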
