fix gemini model setup issue #414

Merged (1 commit, Jan 11, 2025)
32 changes: 19 additions & 13 deletions aios/llm_core/adapter.py
@@ -65,9 +65,9 @@ def __init__(
variables making this needless.
"""
if isinstance(llm_name, list) != isinstance(llm_backend, list):
raise ValueError
raise ValueError("llm_name and llm_backend do not be the same type")
elif isinstance(llm_backend, list) and len(llm_name) != len(llm_backend):
raise ValueError
raise ValueError("llm_name and llm_backend must have the same length")

self.llm_name = llm_name if isinstance(llm_name, list) else [llm_name]
self.max_gpu_memory = max_gpu_memory
@@ -76,9 +76,7 @@ def __init__(
self.log_mode = log_mode
self.llm_backend = llm_backend if isinstance(llm_backend, list) else [llm_backend]
self.context_manager = SimpleContextManager() if use_context_manager else None

if strategy == RouterStrategy.SIMPLE:
self.strategy = SimpleStrategy(self.llm_name)


# Set all supported API keys
api_providers = {
@@ -111,6 +109,8 @@ def __init__(
else:
print(f"- Not found in config.yaml or environment variables")


# Format model names to match backend or instantiate local backends
for idx in range(len(self.llm_name)):
if self.llm_backend[idx] is None:
@@ -140,18 +140,23 @@ def __init__(
case None:
continue
case _:
prefix = self.llm_backend[idx] + "/"
is_formatted = self.llm_name[idx].startswith(prefix)

# Google backwards compatibility fix
if self.llm_backend[idx] == "google":
self.llm_backend[idx] = "gemini"
if is_formatted:
self.llm_name[idx] = "gemini/" + self.llm_name[idx].split("/")[1]
continue

prefix = self.llm_backend[idx] + "/"
is_formatted = self.llm_name[idx].startswith(prefix)

if not is_formatted:
self.llm_name[idx] = prefix + self.llm_name[idx]

if strategy == RouterStrategy.SIMPLE:
self.strategy = SimpleStrategy(self.llm_name)

def tool_calling_input_format(self, messages: list, tools: list) -> list:
"""Integrate tool information into the messages for open-sourced LLMs
@@ -265,12 +270,13 @@ def address_syscall(
llm_syscall.set_start_time(time.time())

restored_context = None

if self.context_manager:
pid = llm_syscall.get_pid()
if self.context_manager.check_restoration(pid):
restored_context = self.context_manager.gen_recover(pid)

if restored_context is not None:
if restored_context:
messages += [{
"role": "assistant",
"content": "" + restored_context,
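To make the gemini fix easier to follow, here is a minimal standalone sketch of the normalization that the updated default match/case branch performs on each (llm_name, llm_backend) pair; the function name and the example calls are illustrative only, not part of the adapter's API.

def normalize(llm_name: str, llm_backend: str) -> tuple[str, str]:
    # Backwards compatibility: the "google" backend is now routed through the "gemini" provider
    if llm_backend == "google":
        llm_backend = "gemini"

    # Prefix the model name with its backend unless it already carries that prefix
    prefix = llm_backend + "/"
    if not llm_name.startswith(prefix):
        llm_name = prefix + llm_name
    return llm_name, llm_backend

# Illustrative inputs and the formatted names the adapter would end up with:
print(normalize("gemini-1.5-flash", "google"))         # ('gemini/gemini-1.5-flash', 'gemini')
print(normalize("gemini/gemini-1.5-flash", "google"))  # ('gemini/gemini-1.5-flash', 'gemini')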
24 changes: 16 additions & 8 deletions aios/llm_core/local.py
@@ -3,6 +3,8 @@

import os

from aios.config.config_manager import config

class HfLocalBackend:
def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None):
print("\n=== HfLocalBackend Initialization ===")
@@ -35,12 +37,12 @@ def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None
self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"

def inference_online(self, messages, temperature, stream=False):
return str(completion(
return completion(
model="huggingface/" + self.model_name,
messages=messages,
temperature=temperature,
api_base=self.hostname,
))
).choices[0].message.content

def __call__(
self,
@@ -83,7 +85,8 @@ def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None
self.model_name = model_name
self.device = device
self.max_gpu_memory = max_gpu_memory
self.hostname = hostname
self.hostname = hostname or "http://localhost:8001"  # default to the local vLLM server (port 8001, as in scripts/run_agent.sh)

# If a hostname is given, then this vLLM instance is hosted as a web server.
# Therefore, do not start the AIOS-based vLLM instance.
@@ -98,19 +101,22 @@ def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None
tensor_parallel_size=1 if max_gpu_memory is None else len(max_gpu_memory)
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# __init__ has no temperature argument, so keep default sampling params here;
# the per-call temperature is applied in __call__ below.
self.sampling_params = vllm.SamplingParams()

except ImportError:
raise ImportError("Could not import vllm Python package"
"Please install it with `pip install python`")
except Exception as err:
print("Error loading vllm model:", err)

def inference_online(self, messages, temperature, stream=False):
return str(completion(
return completion(
model="hosted_vllm/" + self.model_name,
messages=messages,
temperature=temperature,
api_base=self.hostname,
))
).choices[0].message.content

def __call__(
self,
@@ -121,14 +127,16 @@ def __call__(
if self.hostname is not None:
return self.inference_online(messages, temperature, stream=stream)

assert vllm
assert self.model
assert self.sampling_params
if stream:
raise NotImplementedError("Streaming is not supported by the local vLLM backend")

# Build sampling params from the per-call temperature; the defaults from __init__
# cannot see this argument.
sampling_params = vllm.SamplingParams(temperature=temperature)
prompt = self.tokenizer.apply_chat_template(messages,
tokenize=False)
response = self.model.generate(prompt, sampling_params)
result = response[0].outputs[0].text

return result
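For context, a minimal sketch of what the reworked inference_online path in VLLMLocalBackend is expected to do when a hosted endpoint is used; it assumes litellm is installed and a vLLM server is reachable at http://localhost:8001 (the port used in scripts/run_agent.sh), and the helper name is illustrative rather than part of the module.

from litellm import completion

def hosted_vllm_chat(model_name: str, messages: list, temperature: float = 0.0) -> str:
    # litellm routes "hosted_vllm/<model>" requests to the OpenAI-compatible server at api_base
    response = completion(
        model="hosted_vllm/" + model_name,
        messages=messages,
        temperature=temperature,
        api_base="http://localhost:8001",  # assumed local vLLM server, as in scripts/run_agent.sh
    )
    # Return the assistant text rather than str(response), which would serialize the whole object
    return response.choices[0].message.content

# Example usage (assumes the server from scripts/run_agent.sh is already running):
# print(hosted_vllm_chat("meta-llama/Meta-Llama-3-8B-Instruct",
#                        [{"role": "user", "content": "Tell me what is core idea of AIOS"}]))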
25 changes: 23 additions & 2 deletions scripts/run_agent.sh
@@ -1,13 +1,34 @@
# run agent with gemini-1.5-flash
run-agent \
--llm_name llama3:8b \
--llm_backend ollama \
--llm_name gemini-1.5-flash \
--llm_backend google \
--agent_name_or_path demo_author/demo_agent \
--task "Tell me what is core idea of AIOS" \
--aios_kernel_url http://localhost:8000

# run agent with gpt-4o-mini using openai
run-agent \
--llm_name gpt-4o-mini \
--llm_backend openai \
--agent_name_or_path demo_author/demo_agent \
--task "Tell me what is core idea of AIOS" \
--aios_kernel_url http://localhost:8000

# run agent with meta-llama/Meta-Llama-3-8B-Instruct using vllm
vllm serve meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --port 8001 # start the vllm server (in a separate terminal)
run-agent \
--llm_name meta-llama/Meta-Llama-3-8B-Instruct \
--llm_backend vllm \
--agent_name_or_path demo_author/demo_agent \
--task "Tell me what is core idea of AIOS" \
--aios_kernel_url http://localhost:8000

# run agent with llama3:8b using ollama
ollama pull llama3:8b # pull the model
ollama serve # start the ollama server
run-agent \
--llm_name llama3:8b \
--llm_backend ollama \
--agent_name_or_path demo_author/demo_agent \
--task "Tell me what is core idea of AIOS" \
--aios_kernel_url http://localhost:8000
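
# Note: the gemini-1.5-flash example above also needs a Gemini API key, supplied either in
# the AIOS config.yaml or as an environment variable before starting the kernel
# (exact variable name assumed here):
# export GEMINI_API_KEY=<your-key>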