fix: use singleton in llama_cpp #1013

Merged: 2 commits, Jun 25, 2024
application/llm/llama_cpp.py: 49 changes (26 additions, 23 deletions)
@@ -1,9 +1,30 @@
 from application.llm.base import BaseLLM
 from application.core.settings import settings
+import threading
+
+class LlamaSingleton:
+    _instances = {}
+    _lock = threading.Lock()  # Add a lock for thread synchronization
+
+    @classmethod
+    def get_instance(cls, llm_name):
+        if llm_name not in cls._instances:
+            try:
+                from llama_cpp import Llama
+            except ImportError:
+                raise ImportError(
+                    "Please install llama_cpp using pip install llama-cpp-python"
+                )
+            cls._instances[llm_name] = Llama(model_path=llm_name, n_ctx=2048)
+        return cls._instances[llm_name]
+
+    @classmethod
+    def query_model(cls, llm, prompt, **kwargs):
+        with cls._lock:
+            return llm(prompt, **kwargs)
 
 
 class LlamaCpp(BaseLLM):
 
     def __init__(
         self,
         api_key=None,
@@ -12,41 +33,23 @@
         *args,
         **kwargs,
     ):
-        global llama
-        try:
-            from llama_cpp import Llama
-        except ImportError:
-            raise ImportError(
-                "Please install llama_cpp using pip install llama-cpp-python"
-            )
-
         super().__init__(*args, **kwargs)
         self.api_key = api_key
         self.user_api_key = user_api_key
-        llama = Llama(model_path=llm_name, n_ctx=2048)
+        self.llama = LlamaSingleton.get_instance(llm_name)
 
     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-
-        result = llama(prompt, max_tokens=150, echo=False)
-
-        # import sys
-        # print(result['choices'][0]['text'].split('### Answer \n')[-1], file=sys.stderr)
-
+        result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False)
         return result["choices"][0]["text"].split("### Answer \n")[-1]
 
     def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-
-        result = llama(prompt, max_tokens=150, echo=False, stream=stream)
-
-        # import sys
-        # print(list(result), file=sys.stderr)
-
+        result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream)
         for item in result:
             for choice in item["choices"]:
-                yield choice["text"]
+                yield choice["text"]
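
The LlamaSingleton class added above keeps a single Llama object per model path and routes every call through a class-level lock, so constructing several LlamaCpp instances no longer reloads the model and non-streaming calls from different threads run one at a time. A minimal usage sketch, not part of the PR: it assumes llama-cpp-python is installed and that the GGUF path below (which is hypothetical) exists.

    from application.llm.llama_cpp import LlamaSingleton

    MODEL_PATH = "/models/llama-2-7b.Q4_K_M.gguf"  # hypothetical model path

    llm = LlamaSingleton.get_instance(MODEL_PATH)       # loads the model on first use
    same_llm = LlamaSingleton.get_instance(MODEL_PATH)  # returns the cached instance
    assert llm is same_llm

    # query_model takes the class-level lock, so simultaneous non-streaming calls
    # from other threads wait their turn on the shared model.
    result = LlamaSingleton.query_model(
        llm,
        "### Instruction \n Say hello \n ### Answer \n",
        max_tokens=32,
        echo=False,
    )
    print(result["choices"][0]["text"])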

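_raw_gen_stream passes stream=True through to the same query_model helper and yields each chunk's text. Below is a sketch of consuming such a stream directly, again hypothetical and using the same assumed model path; note that with stream=True llama-cpp-python hands back a generator, so the lock in query_model is only held while that generator is created, not while the chunks are produced.

    from application.llm.llama_cpp import LlamaSingleton

    llm = LlamaSingleton.get_instance("/models/llama-2-7b.Q4_K_M.gguf")  # hypothetical path
    prompt = "### Instruction \n Summarize the context \n ### Context \n DocsGPT docs \n ### Answer \n"

    # With stream=True the call returns an iterator of partial completions,
    # mirroring the loop at the end of _raw_gen_stream above.
    for chunk in LlamaSingleton.query_model(llm, prompt, max_tokens=150, echo=False, stream=True):
        for choice in chunk["choices"]:
            print(choice["text"], end="", flush=True)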
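Codecov marks the new lines as untested. A hypothetical test sketch that exercises the caching behaviour without installing llama-cpp-python, by injecting a fake llama_cpp module before the lazy import inside get_instance runs; the file path and test name are illustrative, and it assumes the application package is importable in the test environment.

    import sys
    import types
    from unittest import mock


    def test_get_instance_caches_one_llama_per_model_path():
        # Stand in for the optional llama_cpp dependency.
        fake_llama_cpp = types.ModuleType("llama_cpp")
        fake_llama_cpp.Llama = mock.MagicMock(name="Llama")

        with mock.patch.dict(sys.modules, {"llama_cpp": fake_llama_cpp}):
            from application.llm.llama_cpp import LlamaSingleton

            LlamaSingleton._instances.clear()  # isolate from other tests
            first = LlamaSingleton.get_instance("/models/fake.gguf")
            second = LlamaSingleton.get_instance("/models/fake.gguf")

        assert first is second  # the same cached object comes back
        fake_llama_cpp.Llama.assert_called_once_with(model_path="/models/fake.gguf", n_ctx=2048)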