Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: add sentence trimming to OpenAIWrapper #1526

Merged
merged 18 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mteb/leaderboard/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path

import gradio as gr
import pandas as pd
from gradio_rangeslider import RangeSlider

import mteb
Expand Down
98 changes: 83 additions & 15 deletions mteb/models/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,91 @@


class OpenAIWrapper(Wrapper):
def __init__(
    self,
    model_name: str,
    max_tokens: int,
    tokenizer_name: str = "cl100k_base",  # since all models use this tokenizer now
    embed_dim: int | None = None,
    **kwargs,
) -> None:
    """Wrapper for OpenAI's embedding API.

    To handle documents larger than the model's context window, texts are
    truncated to `max_tokens` tokens (measured with the tiktoken encoding
    named by `tokenizer_name`) before being sent to the API.

    Args:
        model_name: OpenAI embedding model name, e.g. "text-embedding-3-small".
        max_tokens: Maximum number of tokens per document; longer texts are truncated.
        tokenizer_name: tiktoken encoding used for counting and truncation.
        embed_dim: Optional reduced embedding dimensionality
            (supported by text-embedding-3-* models only).
        **kwargs: Accepted for interface compatibility; not used here.
    """
    requires_package(self, "openai", "Openai text embedding")
    from openai import OpenAI

    requires_package(self, "tiktoken", "Tiktoken package")
    import tiktoken

    self._client = OpenAI()  # reads OPENAI_API_KEY from the environment
    self._model_name = model_name
    self._embed_dim = embed_dim
    self._max_tokens = max_tokens
    self._encoding = tiktoken.get_encoding(tokenizer_name)

def truncate_text_tokens(self, text):
"""Truncate a string to have `max_tokens` according to the given encoding."""
truncated_sentence = self._encoding.encode(text)[: self._max_tokens]
return self._encoding.decode(truncated_sentence)

def encode(self, sentences: list[str], **kwargs: Any) -> np.ndarray:
    """Embed `sentences` with the OpenAI embeddings API.

    Inputs longer than `self._max_tokens` tokens are truncated first, then
    sent to the API in batches. Each failed batch request is retried twice
    (after sleeping 10 s, then 60 s) to ride out transient errors such as
    rate limiting; the final failure propagates to the caller.

    Args:
        sentences: Texts to embed.
        **kwargs: Optional `batch_size` (default 2048) controlling how many
            texts are sent per API request; other keys are ignored.

    Returns:
        Array with one embedding vector per input sentence, in input order.
    """
    requires_package(self, "openai", "Openai text embedding")
    import time

    from openai import NotGiven

    if self._model_name == "text-embedding-ada-002" and self._embed_dim is not None:
        logger.warning(
            "Reducing embedding size available only for text-embedding-3-* models"
        )

    # Truncate any input that exceeds the model's token limit.
    trimmed_sentences = []
    for sentence in sentences:
        if len(self._encoding.encode(sentence)) > self._max_tokens:
            trimmed_sentences.append(self.truncate_text_tokens(sentence))
        else:
            trimmed_sentences.append(sentence)

    max_batch_size = kwargs.get("batch_size", 2048)
    batches = [
        trimmed_sentences[i : i + max_batch_size]
        for i in range(0, len(trimmed_sentences), max_batch_size)
    ]

    def _request(batch: list[str]):
        # Single API call; NotGiven() omits `dimensions` when no reduction is set.
        return self._client.embeddings.create(
            input=batch,
            model=self._model_name,
            encoding_format="float",
            dimensions=self._embed_dim or NotGiven(),
        )

    all_embeddings = []
    for batch in batches:
        try:
            response = _request(batch)
        except Exception as error:
            # Sleep due to too many requests, then retry with growing backoff.
            logger.info("Sleeping for 10 seconds due to error: %s", error)
            time.sleep(10)
            try:
                response = _request(batch)
            except Exception as error:
                logger.info("Sleeping for 60 seconds due to error: %s", error)
                time.sleep(60)
                response = _request(batch)  # last attempt; raises on failure
        all_embeddings.extend(self._to_numpy(response))

    return np.array(all_embeddings)
Expand All @@ -57,10 +110,15 @@ def _to_numpy(self, embedding_response) -> np.ndarray:

text_embedding_3_small = ModelMeta(
name="openai/text-embedding-3-small",
revision="1",
revision="2",
release_date="2024-01-25",
languages=None, # supported languages not specified
loader=partial(OpenAIWrapper, model_name="text-embedding-3-small"),
loader=partial(
OpenAIWrapper,
model_name="text-embedding-3-small",
tokenizer_name="cl100k_base",
max_tokens=8192,
),
max_tokens=8191,
embed_dim=1536,
open_weights=False,
Expand All @@ -74,10 +132,15 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
)
text_embedding_3_large = ModelMeta(
name="openai/text-embedding-3-large",
revision="1",
revision="2",
release_date="2024-01-25",
languages=None, # supported languages not specified
loader=partial(OpenAIWrapper, model_name="text-embedding-3-large"),
loader=partial(
OpenAIWrapper,
model_name="text-embedding-3-large",
tokenizer_name="cl100k_base",
max_tokens=8192,
),
max_tokens=8191,
embed_dim=3072,
open_weights=False,
Expand All @@ -89,10 +152,15 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
)
text_embedding_ada_002 = ModelMeta(
name="openai/text-embedding-ada-002",
revision="1",
revision="2",
release_date="2022-12-15",
languages=None, # supported languages not specified
loader=partial(OpenAIWrapper, model_name="text-embedding-ada-002"),
loader=partial(
OpenAIWrapper,
model_name="text-embedding-ada-002",
tokenizer_name="cl100k_base",
max_tokens=8192,
),
reference="https://openai.com/index/new-and-improved-embedding-model/",
max_tokens=8191,
embed_dim=1536,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ leaderboard = ["gradio>=5.7.1", "gradio_rangeslider>=0.0.8"]
flagembedding = ["FlagEmbedding"]
jina = ["einops>=0.8.0"]
flash_attention = ["flash-attn>=2.6.3"]
openai = ["openai>=1.41.0", "tiktoken>=0.8.0"]


[tool.coverage.report]
Expand Down
Loading