fix gemini model set up issue (#414)
dongyuanjushi authored Jan 11, 2025
1 parent d1b538a commit a8b07f3
Showing 3 changed files with 58 additions and 23 deletions.
32 changes: 19 additions & 13 deletions aios/llm_core/adapter.py
@@ -65,9 +65,9 @@ def __init__(
variables making this needless.
"""
if isinstance(llm_name, list) != isinstance(llm_backend, list):
raise ValueError
raise ValueError("llm_name and llm_backend do not be the same type")
elif isinstance(llm_backend, list) and len(llm_name) != len(llm_backend):
raise ValueError
raise ValueError("llm_name and llm_backend do not have the same length")

self.llm_name = llm_name if isinstance(llm_name, list) else [llm_name]
self.max_gpu_memory = max_gpu_memory
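
The hunk above tightens the constructor's argument validation: llm_name and llm_backend must both be scalars or both be lists of matching length. A standalone sketch of the intended check, with an illustrative helper name that is not part of the adapter's API:

```python
def validate_llm_args(llm_name, llm_backend):
    # Both arguments must be lists, or both must be scalars.
    if isinstance(llm_name, list) != isinstance(llm_backend, list):
        raise ValueError("llm_name and llm_backend are not the same type")
    # When they are lists, every model name needs a matching backend.
    if isinstance(llm_backend, list) and len(llm_name) != len(llm_backend):
        raise ValueError("llm_name and llm_backend do not have the same length")
    # Normalize to lists, mirroring what __init__ does next.
    names = llm_name if isinstance(llm_name, list) else [llm_name]
    backends = llm_backend if isinstance(llm_backend, list) else [llm_backend]
    return names, backends

validate_llm_args(["gpt-4o-mini", "llama3:8b"], ["openai", "ollama"])  # passes
# validate_llm_args(["gpt-4o-mini"], "openai")  # raises ValueError (mixed types)
```
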
@@ -76,9 +76,7 @@ def __init__(
self.log_mode = log_mode
self.llm_backend = llm_backend if isinstance(llm_backend, list) else [llm_backend]
self.context_manager = SimpleContextManager() if use_context_manager else None

if strategy == RouterStrategy.SIMPLE:
self.strategy = SimpleStrategy(self.llm_name)


# Set all supported API keys
api_providers = {
@@ -111,6 +109,8 @@ def __init__(
else:
print(f"- Not found in config.yaml or environment variables")

# breakpoint()

# Format model names to match backend or instantiate local backends
for idx in range(len(self.llm_name)):
if self.llm_backend[idx] is None:
@@ -140,18 +140,23 @@ def __init__(
case None:
continue
case _:
prefix = self.llm_backend[idx] + "/"
is_formatted = self.llm_name[idx].startswith(prefix)

# Google backwards compatibility fix
if self.llm_backend[idx] == "google":
self.llm_backend[idx] = "gemini"
if is_formatted:
self.llm_name[idx] = "gemini/" + self.llm_name[idx].split("/")[1]
continue

# Google backwards compatibility fix

prefix = self.llm_backend[idx] + "/"
is_formatted = self.llm_name[idx].startswith(prefix)

# if not is_formatted:
# self.llm_name[idx] = "gemini/" + self.llm_name[idx].split("/")[1]
# continue

if not is_formatted:
self.llm_name[idx] = prefix + self.llm_name[idx]

if strategy == RouterStrategy.SIMPLE:
self.strategy = SimpleStrategy(self.llm_name)

def tool_calling_input_format(self, messages: list, tools: list) -> list:
"""Integrate tool information into the messages for open-sourced LLMs
@@ -265,12 +270,13 @@ def address_syscall(
llm_syscall.set_start_time(time.time())

restored_context = None

if self.context_manager:
pid = llm_syscall.get_pid()
if self.context_manager.check_restoration(pid):
restored_context = self.context_manager.gen_recover(pid)

if restored_context is not None:
if restored_context:
messages += [{
"role": "assistant",
"content": "" + restored_context,
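
The core of the adapter.py change is the reordered name normalization: the deprecated google backend is aliased to gemini before the litellm-style backend/model prefix is applied, so an unformatted name such as gemini-1.5-flash now ends up as gemini/gemini-1.5-flash instead of falling through unprefixed, which appears to be the setup issue (#414) reports. A simplified, standalone sketch of the new behavior (in the commit this logic lives inside __init__, not a separate function):

```python
def normalize(llm_names, llm_backends):
    """Prefix each model name with its backend, aliasing 'google' to 'gemini'."""
    for idx in range(len(llm_names)):
        backend = llm_backends[idx]
        if backend is None:
            continue  # no backend given; leave the name untouched

        # Google backwards compatibility fix: "google" now routes via "gemini".
        if backend == "google":
            backend = "gemini"
            llm_backends[idx] = "gemini"

        prefix = backend + "/"
        if not llm_names[idx].startswith(prefix):
            llm_names[idx] = prefix + llm_names[idx]
    return llm_names

# The run_agent.sh example: --llm_name gemini-1.5-flash --llm_backend google
print(normalize(["gemini-1.5-flash"], ["google"]))  # ['gemini/gemini-1.5-flash']
```
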
24 changes: 16 additions & 8 deletions aios/llm_core/local.py
@@ -3,6 +3,8 @@

import os

from aios.config.config_manager import config

class HfLocalBackend:
def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None):
print("\n=== HfLocalBackend Initialization ===")
@@ -35,12 +37,12 @@ def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None
self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"

def inference_online(self, messages, temperature, stream=False):
return str(completion(
return completion(
model="huggingface/" + self.model_name,
messages=messages,
temperature=temperature,
api_base=self.hostname,
))
).choices[0].message.content

def __call__(
self,
@@ -83,7 +85,8 @@ def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None
self.model_name = model_name
self.device = device
self.max_gpu_memory = max_gpu_memory
self.hostname = hostname
# self.hostname = hostname
self.hostname = "http://localhost:8001"

# If a hostname is given, then this vLLM instance is hosted as a web server.
# Therefore, do not start the AIOS-based vLLM instance.
@@ -98,19 +101,22 @@ def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None
tensor_parallel_size=1 if max_gpu_memory is None else len(max_gpu_memory)
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.sampling_params = vllm.SamplingParams(temperature=temperature)

except ImportError:
raise ImportError("Could not import vllm Python package"
"Please install it with `pip install python`")
except Exception as err:
print("Error loading vllm model:", err)

def inference_online(self, messages, temperature, stream=False):
return str(completion(
return completion(
model="hosted_vllm/" + self.model_name,
messages=messages,
temperature=temperature,
api_base=self.hostname,
))
).choices[0].message.content

def __call__(
self,
@@ -121,14 +127,16 @@ def __call__(
if self.hostname is not None:
return self.inference_online(messages, temperature, stream=stream)

assert vllm
assert self.model
assert self.sampling_params
# breakpoint()
if stream:
raise NotImplementedError

parameters = vllm.SamplingParams(temperature=temperature)
# parameters = vllm.SamplingParams(temperature=temperature)
prompt = self.tokenizer.apply_chat_template(messages,
tokenize=False)
response = self.model.generate(prompt, parameters)
response = self.model.generate(prompt, self.sampling_params)
result = response[0].outputs[0].text

return result
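
Both backends in local.py also switch from str(completion(...)) to returning the message text from the LiteLLM response object. A minimal sketch of that pattern, assuming a reachable LiteLLM-compatible endpoint; the wrapper function and its defaults are illustrative:

```python
from litellm import completion

def chat_text(model, messages, temperature=1.0, api_base=None):
    # completion() returns a ModelResponse; str() would serialize the whole
    # object, while the generated text lives at choices[0].message.content.
    response = completion(
        model=model,          # e.g. "hosted_vllm/meta-llama/Meta-Llama-3-8B-Instruct"
        messages=messages,
        temperature=temperature,
        api_base=api_base,    # e.g. "http://localhost:8001" for a local vLLM server
    )
    return response.choices[0].message.content

# Example against the vLLM server started in scripts/run_agent.sh:
# chat_text(
#     "hosted_vllm/meta-llama/Meta-Llama-3-8B-Instruct",
#     [{"role": "user", "content": "Tell me what is core idea of AIOS"}],
#     api_base="http://localhost:8001",
# )
```
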
25 changes: 23 additions & 2 deletions scripts/run_agent.sh
@@ -1,13 +1,34 @@
# run agent with gemini-1.5-flash
run-agent \
--llm_name llama3:8b \
--llm_backend ollama \
--llm_name gemini-1.5-flash \
--llm_backend google \
--agent_name_or_path demo_author/demo_agent \
--task "Tell me what is core idea of AIOS" \
--aios_kernel_url http://localhost:8000

# run agent with gpt-4o-mini using openai
run-agent \
--llm_name gpt-4o-mini \
--llm_backend openai \
--agent_name_or_path demo_author/demo_agent \
--task "Tell me what is core idea of AIOS" \
--aios_kernel_url http://localhost:8000

# run agent with meta-llama/Meta-Llama-3-8B-Instruct using vllm
vllm serve meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --port 8001 # start the vllm server
run-agent \
--llm_name meta-llama/Meta-Llama-3-8B-Instruct \
--llm_backend vllm \
--agent_name_or_path demo_author/demo_agent \
--task "Tell me what is core idea of AIOS" \
--aios_kernel_url http://localhost:8000

# run agent with llama3:8b using ollama
ollama pull llama3:8b # pull the model
ollama serve # start the ollama server
run-agent \
--llm_name llama3:8b \
--llm_backend ollama \
--agent_name_or_path demo_author/demo_agent \
--task "Tell me what is core idea of AIOS" \
--aios_kernel_url http://localhost:8000
