diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index ae9ef71f7..86943ff18 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -13,7 +13,7 @@
 
 
 from application.core.settings import settings
-from application.llm.openai import OpenAILLM, AzureOpenAILLM
+from application.llm.llm_creator import LLMCreator
 from application.vectorstore.faiss import FaissStore
 from application.error import bad_request
 
@@ -128,16 +128,8 @@ def is_azure_configured():
 
 
 def complete_stream(question, docsearch, chat_history, api_key, conversation_id):
-    if is_azure_configured():
-        llm = AzureOpenAILLM(
-            openai_api_key=api_key,
-            openai_api_base=settings.OPENAI_API_BASE,
-            openai_api_version=settings.OPENAI_API_VERSION,
-            deployment_name=settings.AZURE_DEPLOYMENT_NAME,
-        )
-    else:
-        logger.debug("plain OpenAI")
-        llm = OpenAILLM(api_key=api_key)
+    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
+
 
     docs = docsearch.search(question, k=2)
     # join all page_content together with a newline
@@ -270,16 +262,8 @@ def api_answer():
         # Note if you have used other embeddings than OpenAI, you need to change the embeddings
         docsearch = FaissStore(vectorstore, embeddings_key)
 
-        if is_azure_configured():
-            llm = AzureOpenAILLM(
-                openai_api_key=api_key,
-                openai_api_base=settings.OPENAI_API_BASE,
-                openai_api_version=settings.OPENAI_API_VERSION,
-                deployment_name=settings.AZURE_DEPLOYMENT_NAME,
-            )
-        else:
-            logger.debug("plain OpenAI")
-            llm = OpenAILLM(api_key=api_key)
+
+        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
 
 
 
diff --git a/application/core/settings.py b/application/core/settings.py
index d127c293b..1479beb38 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -4,7 +4,7 @@
 
 
 class Settings(BaseSettings):
-    LLM_NAME: str = "openai_chat"
+    LLM_NAME: str = "openai"
     EMBEDDINGS_NAME: str = "openai_text-embedding-ada-002"
     CELERY_BROKER_URL: str = "redis://localhost:6379/0"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
new file mode 100644
index 000000000..a7ffc0f65
--- /dev/null
+++ b/application/llm/llm_creator.py
@@ -0,0 +1,30 @@
+from application.llm.openai import OpenAILLM, AzureOpenAILLM
+from application.llm.sagemaker import SagemakerAPILLM
+from application.llm.huggingface import HuggingFaceLLM
+
+
+class LLMCreator:
+    """Factory that maps a backend name to its LLM class."""
+
+    # Registry of supported backends, keyed by lower-case name.
+    llms = {
+        'openai': OpenAILLM,
+        'azure_openai': AzureOpenAILLM,
+        'sagemaker': SagemakerAPILLM,
+        'huggingface': HuggingFaceLLM,
+    }
+
+    @classmethod
+    def create_llm(cls, type, *args, **kwargs):
+        """Instantiate the LLM class registered under ``type``.
+
+        The lookup is case-insensitive; ``args`` and ``kwargs`` are passed
+        to the selected class unchanged.
+
+        Raises:
+            ValueError: if no class is registered under ``type``.
+        """
+        llm_class = cls.llms.get(type.lower())
+        if llm_class is None:
+            raise ValueError(f"No LLM class found for type {type}")
+        return llm_class(*args, **kwargs)
diff --git a/application/llm/openai.py b/application/llm/openai.py
index 23e5fab0e..34d568549 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -1,4 +1,5 @@
 from application.llm.base import BaseLLM
+from application.core.settings import settings
 
 class OpenAILLM(BaseLLM):
 
@@ -44,9 +45,14 @@ class AzureOpenAILLM(OpenAILLM):
 
     def __init__(self, openai_api_key, openai_api_base, openai_api_version, deployment_name):
         super().__init__(openai_api_key)
-        self.api_base = openai_api_base
-        self.api_version = openai_api_version
-        self.deployment_name = deployment_name
+        # Connection details are read from application settings; the
+        # constructor arguments of the same names are not consulted.
+        # NOTE(review): LLMCreator.create_llm() is called with only
+        # api_key=..., which does not match this signature - confirm the
+        # 'azure_openai' path before relying on it.
+        self.api_base = settings.OPENAI_API_BASE
+        self.api_version = settings.OPENAI_API_VERSION
+        self.deployment_name = settings.AZURE_DEPLOYMENT_NAME
 
     def _get_openai(self):
         openai = super()._get_openai()
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
new file mode 100644
index 000000000..9ef5d0afe
--- /dev/null
+++ b/application/llm/sagemaker.py
@@ -0,0 +1,38 @@
+from application.llm.base import BaseLLM
+from application.core.settings import settings
+import requests
+import json
+
+
+class SagemakerAPILLM(BaseLLM):
+    """LLM backend that POSTs prompts to the URL in settings.SAGEMAKER_API_URL."""
+
+    def __init__(self, *args, **kwargs):
+        # Positional and keyword arguments are accepted but not used.
+        self.url = settings.SAGEMAKER_API_URL
+
+    def gen(self, model, engine, messages, stream=False, **kwargs):
+        """Send one prompt to the endpoint and return its 'answer' field.
+
+        The prompt is built from the first message's content (used as the
+        context) and the last message's content (used as the question).
+        """
+        context = messages[0]['content']
+        user_question = messages[-1]['content']
+        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
+        response = requests.post(
+            url=self.url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+            },
+            data=json.dumps({"input": prompt})
+        )
+        # Surface HTTP errors directly instead of failing later on a missing key.
+        response.raise_for_status()
+
+        return response.json()['answer']
+
+    def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+        """Streaming is not implemented for this backend."""
+        raise NotImplementedError("Sagemaker does not support streaming")