diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 900b3e30d..89b1e5552 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -249,6 +249,7 @@ def configure_server(
         model.bi_encoder,
         model.embeddings_inference_endpoint,
         model.embeddings_inference_endpoint_api_key,
+        model.embeddings_inference_endpoint_type,
         query_encode_kwargs=model.bi_encoder_query_encode_config,
         docs_encode_kwargs=model.bi_encoder_docs_encode_config,
         model_kwargs=model.bi_encoder_model_config,
diff --git a/src/khoj/database/management/commands/change_default_model.py b/src/khoj/database/management/commands/change_default_model.py
index d55fd44e0..43111dd98 100644
--- a/src/khoj/database/management/commands/change_default_model.py
+++ b/src/khoj/database/management/commands/change_default_model.py
@@ -3,11 +3,11 @@
 from django.core.management.base import BaseCommand
 from django.db import transaction
-from django.db.models import Count, Q
+from django.db.models import Q
 from tqdm import tqdm
 
 from khoj.database.adapters import get_default_search_model
-from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
+from khoj.database.models import Entry, SearchModelConfig
 from khoj.processor.embeddings import EmbeddingsModel
 
 logging.basicConfig(level=logging.INFO)
 
@@ -74,6 +74,7 @@ def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, searc
         model.bi_encoder,
         model.embeddings_inference_endpoint,
         model.embeddings_inference_endpoint_api_key,
+        model.embeddings_inference_endpoint_type,
         query_encode_kwargs=model.bi_encoder_query_encode_config,
         docs_encode_kwargs=model.bi_encoder_docs_encode_config,
         model_kwargs=model.bi_encoder_model_config,
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 68fae4345..29dc5e584 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel):
     class ModelType(models.TextChoices):
         TEXT = "text"
 
+    class ApiType(models.TextChoices):
+        HUGGINGFACE = "huggingface"
+        OPENAI = "openai"
+        LOCAL = "local"
+
     # This is the model name exposed to users on their settings page
     name = models.CharField(max_length=200, default="default")
     # Type of content the model can generate embeddings for
@@ -501,6 +506,10 @@ class ModelType(models.TextChoices):
     embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
     # Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
     embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
+    # Inference server API type to use for embeddings inference.
+    embeddings_inference_endpoint_type = models.CharField(
+        max_length=200, choices=ApiType.choices, default=ApiType.LOCAL
+    )
     # Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
     cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
     # Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index b224e7f51..0b29cef3b 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -1,6 +1,7 @@
 import logging
 from typing import List
 
+import openai
 import requests
 import tqdm
 from sentence_transformers import CrossEncoder, SentenceTransformer
@@ -13,6 +14,7 @@
 )
 from torch import nn
 
+from khoj.database.models import SearchModelConfig
 from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
 from khoj.utils.rawconfig import SearchResponse
 
@@ -25,6 +27,7 @@ def __init__(
         model_name: str = "thenlper/gte-small",
         embeddings_inference_endpoint: str = None,
         embeddings_inference_endpoint_api_key: str = None,
+        embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL,
         query_encode_kwargs: dict = {},
         docs_encode_kwargs: dict = {},
         model_kwargs: dict = {},
@@ -37,15 +40,16 @@ def __init__(
         self.model_name = model_name
         self.inference_endpoint = embeddings_inference_endpoint
         self.api_key = embeddings_inference_endpoint_api_key
-        with timer(f"Loaded embedding model {self.model_name}", logger):
-            self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
-
-    def inference_server_enabled(self) -> bool:
-        return self.api_key is not None and self.inference_endpoint is not None
+        self.inference_endpoint_type = embeddings_inference_endpoint_type
+        if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
+            with timer(f"Loaded embedding model {self.model_name}", logger):
+                self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
 
     def embed_query(self, query):
-        if self.inference_server_enabled():
-            return self.embed_with_api([query])[0]
+        if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
+            return self.embed_with_hf([query])[0]
+        elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
+            return self.embed_with_openai([query])[0]
         return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
 
     @retry(
@@ -54,7 +58,7 @@ def embed_query(self, query):
         stop=stop_after_attempt(5),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
     )
-    def embed_with_api(self, docs):
+    def embed_with_hf(self, docs):
         payload = {"inputs": docs}
         headers = {
             "Authorization": f"Bearer {self.api_key}",
@@ -71,23 +75,48 @@ def embed_query(self, query):
             raise e
         return response.json()["embeddings"]
 
+    @retry(
+        retry=retry_if_exception_type(openai.APIError),  # the openai client raises APIError, not requests' HTTPError
+        wait=wait_random_exponential(multiplier=1, max=10),
+        stop=stop_after_attempt(5),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+    )
+    def embed_with_openai(self, docs):
+        if self.inference_endpoint and "openai.azure.com" in self.inference_endpoint:
+            client = openai.AzureOpenAI(
+                api_key=self.api_key,
+                azure_endpoint=self.inference_endpoint,
+                api_version="2024-10-21",
+            )
+        else:
+            client = openai.OpenAI(
+                api_key=self.api_key,
+                base_url=self.inference_endpoint,
+            )
+        response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
+        return [item.embedding for item in response.data]
+
     def embed_documents(self, docs):
-        if self.inference_server_enabled():
-            if "huggingface" not in self.inference_endpoint:
-                logger.warning(
-                    f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
-                )
-                return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
-            # break up the docs payload in chunks of 1000 to avoid hitting rate limits
-            embeddings = []
-            with tqdm.tqdm(total=len(docs)) as pbar:
-                for i in range(0, len(docs), 1000):
-                    docs_to_embed = docs[i : i + 1000]
-                    generated_embeddings = self.embed_with_api(docs_to_embed)
-                    embeddings += generated_embeddings
-                    pbar.update(1000)
-            return embeddings
-        return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
+        if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
+            return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
+        elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
+            embed_with_api = self.embed_with_hf
+        elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
+            embed_with_api = self.embed_with_openai
+        else:
+            logger.warning(
+                f"Unsupported inference endpoint type: {self.inference_endpoint_type}. Generating embeddings locally instead."
+            )
+            return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
+        # break up the docs payload in chunks of 1000 to avoid hitting rate limits
+        embeddings = []
+        with tqdm.tqdm(total=len(docs)) as pbar:
+            for i in range(0, len(docs), 1000):
+                docs_to_embed = docs[i : i + 1000]
+                generated_embeddings = embed_with_api(docs_to_embed)
+                embeddings += generated_embeddings
+                pbar.update(1000)
+        return embeddings
 
 
 class CrossEncoderModel:
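
Note: a minimal sketch of how the new field slots into a search model config, assuming hypothetical model and endpoint values; only the field names come from the SearchModelConfig changes above.

    from khoj.database.models import SearchModelConfig

    # Hypothetical config: generate embeddings via an OpenAI-compatible API
    # instead of loading a local sentence-transformers model.
    SearchModelConfig.objects.create(
        name="openai-embeddings",                                   # name shown on the settings page
        bi_encoder="text-embedding-3-small",                        # model name sent to the API
        embeddings_inference_endpoint="https://api.openai.com/v1",  # used as base_url by openai.OpenAI
        embeddings_inference_endpoint_api_key="...",                # elided
        embeddings_inference_endpoint_type=SearchModelConfig.ApiType.OPENAI,
    )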
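
And a sketch of the resulting dispatch in EmbeddingsModel, again with hypothetical values. Because the endpoint type is not LOCAL, no SentenceTransformer is loaded at construction; embed_query and embed_documents route to embed_with_openai, with documents batched in chunks of 1000. Endpoints containing "openai.azure.com" get an AzureOpenAI client instead.

    from khoj.database.models import SearchModelConfig
    from khoj.processor.embeddings import EmbeddingsModel

    model = EmbeddingsModel(
        model_name="text-embedding-3-small",
        embeddings_inference_endpoint="https://api.openai.com/v1",
        embeddings_inference_endpoint_api_key="...",
        embeddings_inference_endpoint_type=SearchModelConfig.ApiType.OPENAI,
    )
    query_vector = model.embed_query("where are my notes on embeddings?")  # -> embed_with_openai
    doc_vectors = model.embed_documents(["first note", "second note"])     # batched via embed_with_openai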