Skip to content

Commit

Permalink
Add LocalAI provider
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley committed Jan 17, 2024
1 parent b0d6f8e commit 61cf705
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/shelloracle/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ def get_provider(name: str) -> type[Provider]:
:param name: the provider name
:return: the requested provider
"""
from .providers import Ollama, OpenAI
from .providers import Ollama, OpenAI, LocalAI
providers = {
Ollama.name: Ollama,
OpenAI.name: OpenAI
OpenAI.name: OpenAI,
LocalAI.name: LocalAI
}
return providers[name]
1 change: 1 addition & 0 deletions src/shelloracle/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .ollama import Ollama
from .openai import OpenAI
from .localai import LocalAI
49 changes: 49 additions & 0 deletions src/shelloracle/providers/localai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from collections.abc import AsyncIterator

from openai import APIError
from openai import AsyncOpenAI as OpenAIClient

from ..config import Setting
from ..provider import Provider, ProviderError


class LocalAI(Provider):
    """Provider backed by a LocalAI server's OpenAI-compatible REST API."""

    name = "LocalAI"

    # Connection settings for the local server; "lunademo" is LocalAI's demo model.
    host = Setting(default="localhost")
    port = Setting(default=8080)
    model = Setting(default="lunademo")
    system_prompt = Setting(
        default=(
            "Based on the following user description, generate a corresponding Bash command. Focus solely "
            "on interpreting the requirements and translating them into a single, executable Bash command. "
            "Ensure accuracy and relevance to the user's description. The output should be a valid Bash "
            "command that directly aligns with the user's intent, ready for execution in a command-line "
            "environment. Output nothing except for the command. No code block, no English explanation, "
            "no start/end tags."
        )
    )

    @property
    def endpoint(self) -> str:
        """Base URL of the LocalAI OpenAI-compatible API.

        Computed property because python descriptors need to be bound to an
        instance before access. LocalAI serves the OpenAI-compatible API under
        ``/v1`` — the client appends paths like ``/chat/completions`` to this
        base, so pointing at Ollama's ``/api/generate`` would build broken URLs.
        """
        return f"http://{self.host}:{self.port}/v1"

    def __init__(self):
        # LocalAI ignores the API key, but AsyncOpenAI raises at construction
        # when no key is supplied and OPENAI_API_KEY is unset — pass a dummy.
        self.client = OpenAIClient(base_url=self.endpoint, api_key="sk-no-key-required")

    async def generate(self, prompt: str) -> AsyncIterator[str]:
        """Stream the model's reply to *prompt* as text fragments.

        :param prompt: the user's natural-language description
        :return: an async iterator yielding chunks of the generated command
        :raises ProviderError: if the LocalAI API reports an error
        """
        try:
            stream = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": prompt}
                ],
                stream=True,
            )
            async for chunk in stream:
                # Delta content is None for role/medadata-only chunks; skip those.
                if chunk.choices[0].delta.content is not None:
                    yield chunk.choices[0].delta.content
        except APIError as e:
            # Name the actual provider in the error, not OpenAI.
            raise ProviderError(f"Something went wrong while querying LocalAI: {e}") from e

0 comments on commit 61cf705

Please sign in to comment.