From d186eed9d4135eef176d163690cff6064108eaee Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 9 Jan 2025 02:50:52 +0700 Subject: [PATCH] Support using Embeddings Model exposed via OpenAI (compatible) API --- src/khoj/configure.py | 1 + .../commands/change_default_model.py | 5 +- ...nfig_embeddings_inference_endpoint_type.py | 29 +++++++ src/khoj/database/models/__init__.py | 9 +++ src/khoj/processor/embeddings.py | 76 +++++++++++++------ 5 files changed, 93 insertions(+), 27 deletions(-) create mode 100644 src/khoj/database/migrations/0079_searchmodelconfig_embeddings_inference_endpoint_type.py diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 900b3e30d..89b1e5552 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -249,6 +249,7 @@ def configure_server( model.bi_encoder, model.embeddings_inference_endpoint, model.embeddings_inference_endpoint_api_key, + model.embeddings_inference_endpoint_type, query_encode_kwargs=model.bi_encoder_query_encode_config, docs_encode_kwargs=model.bi_encoder_docs_encode_config, model_kwargs=model.bi_encoder_model_config, diff --git a/src/khoj/database/management/commands/change_default_model.py b/src/khoj/database/management/commands/change_default_model.py index d55fd44e0..43111dd98 100644 --- a/src/khoj/database/management/commands/change_default_model.py +++ b/src/khoj/database/management/commands/change_default_model.py @@ -3,11 +3,11 @@ from django.core.management.base import BaseCommand from django.db import transaction -from django.db.models import Count, Q +from django.db.models import Q from tqdm import tqdm from khoj.database.adapters import get_default_search_model -from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig +from khoj.database.models import Entry, SearchModelConfig from khoj.processor.embeddings import EmbeddingsModel logging.basicConfig(level=logging.INFO) @@ -74,6 +74,7 @@ def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, searc model.bi_encoder, model.embeddings_inference_endpoint, model.embeddings_inference_endpoint_api_key, + model.embeddings_inference_endpoint_type, query_encode_kwargs=model.bi_encoder_query_encode_config, docs_encode_kwargs=model.bi_encoder_docs_encode_config, model_kwargs=model.bi_encoder_model_config, diff --git a/src/khoj/database/migrations/0079_searchmodelconfig_embeddings_inference_endpoint_type.py b/src/khoj/database/migrations/0079_searchmodelconfig_embeddings_inference_endpoint_type.py new file mode 100644 index 000000000..a05668da5 --- /dev/null +++ b/src/khoj/database/migrations/0079_searchmodelconfig_embeddings_inference_endpoint_type.py @@ -0,0 +1,29 @@ +# Generated by Django 5.0.10 on 2025-01-08 15:09 + +from django.db import migrations, models + + +def set_endpoint_type(apps, schema_editor): + SearchModelConfig = apps.get_model("database", "SearchModelConfig") + SearchModelConfig.objects.filter(embeddings_inference_endpoint__isnull=False).exclude( + embeddings_inference_endpoint="" + ).update(embeddings_inference_endpoint_type="huggingface") + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0078_khojuser_email_verification_code_expiry"), + ] + + operations = [ + migrations.AddField( + model_name="searchmodelconfig", + name="embeddings_inference_endpoint_type", + field=models.CharField( + choices=[("huggingface", "Huggingface"), ("openai", "Openai"), ("local", "Local")], + default="local", + max_length=200, + ), + ), + migrations.RunPython(set_endpoint_type, reverse_code=migrations.RunPython.noop), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 68fae4345..29dc5e584 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel): class ModelType(models.TextChoices): TEXT = "text" + class ApiType(models.TextChoices): + HUGGINGFACE = "huggingface" + OPENAI = "openai" + LOCAL = "local" + # This is the model name exposed to users on their settings page name = models.CharField(max_length=200, default="default") # Type of content the model can generate embeddings for @@ -501,6 +506,10 @@ class ModelType(models.TextChoices): embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) # Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) + # Inference server API type to use for embeddings inference. + embeddings_inference_endpoint_type = models.CharField( + max_length=200, choices=ApiType.choices, default=ApiType.LOCAL + ) # Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) # Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index b224e7f51..0e7b66578 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -1,6 +1,8 @@ import logging from typing import List +from urllib.parse import urlparse +import openai import requests import tqdm from sentence_transformers import CrossEncoder, SentenceTransformer @@ -13,7 +15,14 @@ ) from torch import nn -from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer +from khoj.database.models import SearchModelConfig +from khoj.utils.helpers import ( + fix_json_dict, + get_device, + get_openai_client, + merge_dicts, + timer, +) from khoj.utils.rawconfig import SearchResponse logger = logging.getLogger(__name__) @@ -25,6 +34,7 @@ def __init__( model_name: str = "thenlper/gte-small", embeddings_inference_endpoint: str = None, embeddings_inference_endpoint_api_key: str = None, + embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL, query_encode_kwargs: dict = {}, docs_encode_kwargs: dict = {}, model_kwargs: dict = {}, @@ -37,15 +47,16 @@ def __init__( self.model_name = model_name self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key - with timer(f"Loaded embedding model {self.model_name}", logger): - self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) - - def inference_server_enabled(self) -> bool: - return self.api_key is not None and self.inference_endpoint is not None + self.inference_endpoint_type = embeddings_inference_endpoint_type + if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL: + with timer(f"Loaded embedding model {self.model_name}", logger): + self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) def embed_query(self, query): - if self.inference_server_enabled(): - return self.embed_with_api([query])[0] + if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE: + return self.embed_with_hf([query])[0] + elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI: + return self.embed_with_openai([query])[0] return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0] @retry( @@ -54,7 +65,7 @@ def embed_query(self, query): stop=stop_after_attempt(5), before_sleep=before_sleep_log(logger, logging.DEBUG), ) - def embed_with_api(self, docs): + def embed_with_hf(self, docs): payload = {"inputs": docs} headers = { "Authorization": f"Bearer {self.api_key}", @@ -71,23 +82,38 @@ def embed_with_api(self, docs): raise e return response.json()["embeddings"] + @retry( + retry=retry_if_exception_type(requests.exceptions.HTTPError), + wait=wait_random_exponential(multiplier=1, max=10), + stop=stop_after_attempt(5), + before_sleep=before_sleep_log(logger, logging.DEBUG), + ) + def embed_with_openai(self, docs): + client = get_openai_client(self.api_key, self.inference_endpoint) + response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float") + return [item.embedding for item in response.data] + def embed_documents(self, docs): - if self.inference_server_enabled(): - if "huggingface" not in self.inference_endpoint: - logger.warning( - f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead." - ) - return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() - # break up the docs payload in chunks of 1000 to avoid hitting rate limits - embeddings = [] - with tqdm.tqdm(total=len(docs)) as pbar: - for i in range(0, len(docs), 1000): - docs_to_embed = docs[i : i + 1000] - generated_embeddings = self.embed_with_api(docs_to_embed) - embeddings += generated_embeddings - pbar.update(1000) - return embeddings - return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else [] + if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL: + return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else [] + elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE: + embed_with_api = self.embed_with_hf + elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI: + embed_with_api = self.embed_with_openai + else: + logger.warning( + f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead." + ) + return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() + # break up the docs payload in chunks of 1000 to avoid hitting rate limits + embeddings = [] + with tqdm.tqdm(total=len(docs)) as pbar: + for i in range(0, len(docs), 1000): + docs_to_embed = docs[i : i + 1000] + generated_embeddings = embed_with_api(docs_to_embed) + embeddings += generated_embeddings + pbar.update(1000) + return embeddings class CrossEncoderModel: