Skip to content

Commit

Permalink
Support using Embeddings Model exposed via OpenAI (compatible) API
Browse files Browse the repository at this point in the history
  • Loading branch information
debanjum committed Jan 8, 2025
1 parent e057838 commit 1189679
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/khoj/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def configure_server(
model.bi_encoder,
model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key,
model.embeddings_inference_endpoint_type,
query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config,
Expand Down
5 changes: 3 additions & 2 deletions src/khoj/database/management/commands/change_default_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from django.core.management.base import BaseCommand
from django.db import transaction
from django.db.models import Count, Q
from django.db.models import Q
from tqdm import tqdm

from khoj.database.adapters import get_default_search_model
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
from khoj.database.models import Entry, SearchModelConfig
from khoj.processor.embeddings import EmbeddingsModel

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -74,6 +74,7 @@ def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, searc
model.bi_encoder,
model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key,
model.embeddings_inference_endpoint_type,
query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config,
Expand Down
9 changes: 9 additions & 0 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel):
class ModelType(models.TextChoices):
TEXT = "text"

class ApiType(models.TextChoices):
HUGGINGFACE = "huggingface"
OPENAI = "openai"
LOCAL = "local"

# This is the model name exposed to users on their settings page
name = models.CharField(max_length=200, default="default")
# Type of content the model can generate embeddings for
Expand All @@ -501,6 +506,10 @@ class ModelType(models.TextChoices):
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API type to use for embeddings inference.
embeddings_inference_endpoint_type = models.CharField(
max_length=200, choices=ApiType.choices, default=ApiType.LOCAL
)
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
Expand Down
77 changes: 53 additions & 24 deletions src/khoj/processor/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import List

import openai
import requests
import tqdm
from sentence_transformers import CrossEncoder, SentenceTransformer
Expand All @@ -13,6 +14,7 @@
)
from torch import nn

from khoj.database.models import SearchModelConfig
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
from khoj.utils.rawconfig import SearchResponse

Expand All @@ -25,6 +27,7 @@ def __init__(
model_name: str = "thenlper/gte-small",
embeddings_inference_endpoint: str = None,
embeddings_inference_endpoint_api_key: str = None,
embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL,
query_encode_kwargs: dict = {},
docs_encode_kwargs: dict = {},
model_kwargs: dict = {},
Expand All @@ -37,15 +40,16 @@ def __init__(
self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key
with timer(f"Loaded embedding model {self.model_name}", logger):
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)

def inference_server_enabled(self) -> bool:
return self.api_key is not None and self.inference_endpoint is not None
self.inference_endpoint_type = embeddings_inference_endpoint_type
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
with timer(f"Loaded embedding model {self.model_name}", logger):
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)

def embed_query(self, query):
if self.inference_server_enabled():
return self.embed_with_api([query])[0]
if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
return self.embed_with_hf([query])[0]
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
return self.embed_with_openai([query])[0]
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]

@retry(
Expand All @@ -54,7 +58,7 @@ def embed_query(self, query):
stop=stop_after_attempt(5),
before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def embed_with_api(self, docs):
def embed_with_hf(self, docs):
payload = {"inputs": docs}
headers = {
"Authorization": f"Bearer {self.api_key}",
Expand All @@ -71,23 +75,48 @@ def embed_with_api(self, docs):
raise e
return response.json()["embeddings"]

@retry(
retry=retry_if_exception_type(requests.exceptions.HTTPError),
wait=wait_random_exponential(multiplier=1, max=10),
stop=stop_after_attempt(5),
before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def embed_with_openai(self, docs):
if self.inference_endpoint and "openai.azure.com" in self.inference_endpoint:
client = openai.AzureOpenAI(
api_key=self.api_key,
azure_endpoint=self.inference_endpoint,
api_version="2024-10-21",
)
else:
client = openai.OpenAI(
api_key=self.api_key,
base_url=self.inference_endpoint,
)
response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
return [item.embedding for item in response.data]

def embed_documents(self, docs):
if self.inference_server_enabled():
if "huggingface" not in self.inference_endpoint:
logger.warning(
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
)
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
embeddings = []
with tqdm.tqdm(total=len(docs)) as pbar:
for i in range(0, len(docs), 1000):
docs_to_embed = docs[i : i + 1000]
generated_embeddings = self.embed_with_api(docs_to_embed)
embeddings += generated_embeddings
pbar.update(1000)
return embeddings
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
embed_with_api = self.embed_with_hf
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
embed_with_api = self.embed_with_openai
else:
logger.warning(
f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead."
)
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
embeddings = []
with tqdm.tqdm(total=len(docs)) as pbar:
for i in range(0, len(docs), 1000):
docs_to_embed = docs[i : i + 1000]
generated_embeddings = embed_with_api(docs_to_embed)
embeddings += generated_embeddings
pbar.update(1000)
return embeddings


class CrossEncoderModel:
Expand Down

0 comments on commit 1189679

Please sign in to comment.