Support using Embeddings Model exposed via OpenAI (compatible) API #1051

Merged · 2 commits · Jan 15, 2025

Changes from 1 commit
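
This PR teaches Khoj to generate embeddings through any OpenAI-compatible API, alongside the existing options of a HuggingFace inference endpoint or a local SentenceTransformer model. A new embeddings_inference_endpoint_type field on SearchModelConfig selects which backend to use. A minimal setup sketch, assuming a Django shell on the Khoj server (the field names come from this diff; the endpoint URL, model name, and API key below are placeholders):

    # Hypothetical config; the URL, model name, and key are placeholders.
    from khoj.database.models import SearchModelConfig

    SearchModelConfig.objects.create(
        name="openai-embeddings",
        bi_encoder="text-embedding-3-small",  # model name sent to the embeddings API
        embeddings_inference_endpoint="https://api.openai.com/v1",
        embeddings_inference_endpoint_api_key="sk-placeholder",
        embeddings_inference_endpoint_type=SearchModelConfig.ApiType.OPENAI,
    )
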
1 change: 1 addition & 0 deletions src/khoj/configure.py
@@ -249,6 +249,7 @@ def configure_server(
         model.bi_encoder,
         model.embeddings_inference_endpoint,
         model.embeddings_inference_endpoint_api_key,
+        model.embeddings_inference_endpoint_type,
         query_encode_kwargs=model.bi_encoder_query_encode_config,
         docs_encode_kwargs=model.bi_encoder_docs_encode_config,
         model_kwargs=model.bi_encoder_model_config,
5 changes: 3 additions & 2 deletions src/khoj/database/management/commands/change_default_model.py
@@ -3,11 +3,11 @@

 from django.core.management.base import BaseCommand
 from django.db import transaction
-from django.db.models import Count, Q
+from django.db.models import Q
 from tqdm import tqdm

 from khoj.database.adapters import get_default_search_model
-from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
+from khoj.database.models import Entry, SearchModelConfig
 from khoj.processor.embeddings import EmbeddingsModel

 logging.basicConfig(level=logging.INFO)

@@ -74,6 +74,7 @@ def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, searc
         model.bi_encoder,
         model.embeddings_inference_endpoint,
         model.embeddings_inference_endpoint_api_key,
+        model.embeddings_inference_endpoint_type,
         query_encode_kwargs=model.bi_encoder_query_encode_config,
         docs_encode_kwargs=model.bi_encoder_docs_encode_config,
         model_kwargs=model.bi_encoder_model_config,
29 changes: 29 additions & 0 deletions (new Django migration file)
@@ -0,0 +1,29 @@
+# Generated by Django 5.0.10 on 2025-01-08 15:09
+
+from django.db import migrations, models
+
+
+def set_endpoint_type(apps, schema_editor):
+    SearchModelConfig = apps.get_model("database", "SearchModelConfig")
+    SearchModelConfig.objects.filter(embeddings_inference_endpoint__isnull=False).exclude(
+        embeddings_inference_endpoint=""
+    ).update(embeddings_inference_endpoint_type="huggingface")
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0078_khojuser_email_verification_code_expiry"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="searchmodelconfig",
+            name="embeddings_inference_endpoint_type",
+            field=models.CharField(
+                choices=[("huggingface", "Huggingface"), ("openai", "Openai"), ("local", "Local")],
+                default="local",
+                max_length=200,
+            ),
+        ),
+        migrations.RunPython(set_endpoint_type, reverse_code=migrations.RunPython.noop),
+    ]
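
The RunPython step backfills existing rows: before this PR, a config with an inference endpoint could only have pointed at a HuggingFace server, so those rows are tagged "huggingface", while everything else keeps the "local" default. A quick sanity check after running python manage.py migrate, sketched under the assumption of a Django shell (model and field names are from this diff):

    from khoj.database.models import SearchModelConfig

    # Every pre-existing config with an endpoint should now report "huggingface".
    for config in SearchModelConfig.objects.all():
        print(config.name, config.embeddings_inference_endpoint, config.embeddings_inference_endpoint_type)
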
9 changes: 9 additions & 0 deletions src/khoj/database/models/__init__.py
@@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel):
     class ModelType(models.TextChoices):
         TEXT = "text"

+    class ApiType(models.TextChoices):
+        HUGGINGFACE = "huggingface"
+        OPENAI = "openai"
+        LOCAL = "local"
+
     # This is the model name exposed to users on their settings page
     name = models.CharField(max_length=200, default="default")
     # Type of content the model can generate embeddings for

@@ -501,6 +506,10 @@ class ModelType(models.TextChoices):
     embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
     # Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
     embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
+    # Inference server API type to use for embeddings inference.
+    embeddings_inference_endpoint_type = models.CharField(
+        max_length=200, choices=ApiType.choices, default=ApiType.LOCAL
+    )
     # Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
     cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
     # Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
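
ApiType is a standard Django TextChoices enum: the lowercase member values ("huggingface", "openai", "local") are what the CharField stores, and the auto-generated labels ("Huggingface", "Openai", "Local") are what the migration above recorded in its choices list. A minimal sketch of the field's behavior, assuming a Django shell (names are from this diff):

    from khoj.database.models import SearchModelConfig

    config = SearchModelConfig(embeddings_inference_endpoint_type=SearchModelConfig.ApiType.OPENAI)
    assert config.embeddings_inference_endpoint_type == "openai"  # TextChoices members compare equal to their string value
    assert SearchModelConfig.ApiType.OPENAI.label == "Openai"  # auto-generated label, matching the migration
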
76 changes: 51 additions & 25 deletions src/khoj/processor/embeddings.py
@@ -1,6 +1,8 @@
 import logging
 from typing import List
+from urllib.parse import urlparse

+import openai
 import requests
 import tqdm
 from sentence_transformers import CrossEncoder, SentenceTransformer

@@ -13,7 +15,14 @@
 )
 from torch import nn

-from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
+from khoj.database.models import SearchModelConfig
+from khoj.utils.helpers import (
+    fix_json_dict,
+    get_device,
+    get_openai_client,
+    merge_dicts,
+    timer,
+)
 from khoj.utils.rawconfig import SearchResponse

 logger = logging.getLogger(__name__)

@@ -25,6 +34,7 @@ def __init__(
         model_name: str = "thenlper/gte-small",
         embeddings_inference_endpoint: str = None,
         embeddings_inference_endpoint_api_key: str = None,
+        embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL,
         query_encode_kwargs: dict = {},
         docs_encode_kwargs: dict = {},
         model_kwargs: dict = {},

@@ -37,15 +47,16 @@ def __init__(
         self.model_name = model_name
         self.inference_endpoint = embeddings_inference_endpoint
         self.api_key = embeddings_inference_endpoint_api_key
-        with timer(f"Loaded embedding model {self.model_name}", logger):
-            self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
-
-    def inference_server_enabled(self) -> bool:
-        return self.api_key is not None and self.inference_endpoint is not None
+        self.inference_endpoint_type = embeddings_inference_endpoint_type
+        if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
+            with timer(f"Loaded embedding model {self.model_name}", logger):
+                self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)

     def embed_query(self, query):
-        if self.inference_server_enabled():
-            return self.embed_with_api([query])[0]
+        if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
+            return self.embed_with_hf([query])[0]
+        elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
+            return self.embed_with_openai([query])[0]
         return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]

     @retry(

@@ -54,7 +65,7 @@ def embed_query(self, query):
         retry=retry_if_exception_type(requests.exceptions.HTTPError),
         wait=wait_random_exponential(multiplier=1, max=10),
         stop=stop_after_attempt(5),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
     )
-    def embed_with_api(self, docs):
+    def embed_with_hf(self, docs):
         payload = {"inputs": docs}
         headers = {
             "Authorization": f"Bearer {self.api_key}",

@@ -71,23 +82,38 @@ def embed_with_hf(self, docs):
             raise e
         return response.json()["embeddings"]

+    @retry(
+        retry=retry_if_exception_type(requests.exceptions.HTTPError),
+        wait=wait_random_exponential(multiplier=1, max=10),
+        stop=stop_after_attempt(5),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+    )
+    def embed_with_openai(self, docs):
+        client = get_openai_client(self.api_key, self.inference_endpoint)
+        response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
+        return [item.embedding for item in response.data]
+
     def embed_documents(self, docs):
-        if self.inference_server_enabled():
-            if "huggingface" not in self.inference_endpoint:
-                logger.warning(
-                    f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
-                )
-                return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
-            # break up the docs payload in chunks of 1000 to avoid hitting rate limits
-            embeddings = []
-            with tqdm.tqdm(total=len(docs)) as pbar:
-                for i in range(0, len(docs), 1000):
-                    docs_to_embed = docs[i : i + 1000]
-                    generated_embeddings = self.embed_with_api(docs_to_embed)
-                    embeddings += generated_embeddings
-                    pbar.update(1000)
-            return embeddings
-        return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
+        if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
+            return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
+        elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
+            embed_with_api = self.embed_with_hf
+        elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
+            embed_with_api = self.embed_with_openai
+        else:
+            logger.warning(
+                f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead."
+            )
+            return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
+        # break up the docs payload in chunks of 1000 to avoid hitting rate limits
+        embeddings = []
+        with tqdm.tqdm(total=len(docs)) as pbar:
+            for i in range(0, len(docs), 1000):
+                docs_to_embed = docs[i : i + 1000]
+                generated_embeddings = embed_with_api(docs_to_embed)
+                embeddings += generated_embeddings
+                pbar.update(1000)
+        return embeddings


 class CrossEncoderModel:
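
With the endpoint type threaded through, EmbeddingsModel only loads a local SentenceTransformer when the type is LOCAL; for the HUGGINGFACE and OPENAI types, queries and documents are embedded remotely, with documents batched 1000 at a time to avoid rate limits. A hedged usage sketch (constructor arguments mirror the new signature in this diff; the endpoint URL, model name, and API key are placeholders):

    from khoj.database.models import SearchModelConfig
    from khoj.processor.embeddings import EmbeddingsModel

    model = EmbeddingsModel(
        model_name="text-embedding-3-small",  # passed as the model parameter to the embeddings API
        embeddings_inference_endpoint="https://api.openai.com/v1",
        embeddings_inference_endpoint_api_key="sk-placeholder",
        embeddings_inference_endpoint_type=SearchModelConfig.ApiType.OPENAI,
    )
    query_vector = model.embed_query("how do I sync my notes?")
    doc_vectors = model.embed_documents(["note one", "note two"])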