Add NVIDIA function calling models
nvbagade authored and aw632 committed Aug 26, 2024
1 parent 30124c4 commit 66a34e7
Showing 8 changed files with 177 additions and 86 deletions.
5 changes: 5 additions & 0 deletions berkeley-function-call-leaderboard/README.md
@@ -131,6 +131,11 @@ Below is *a table of models we support* to run our leaderboard evaluation against.
|THUDM/glm-4-9b-chat 💻| Function Calling|
|ibm-granite/granite-20b-functioncalling 💻| Function Calling|
|yi-large-fc | Function Calling|
|meta/llama-3.1-70b-instruct-FC | Function Calling|
|nv-mistralai/mistral-nemo-12b-instruct-FC | Function Calling|
|mistralai/mistral-large-2-instruct-FC | Function Calling|
|mistralai/mistral-7b-instruct-v0.3-FC | Function Calling|


Here, {MODEL} 💻 denotes a model that must be hosted locally and served via vllm, while {MODEL} denotes a model accessed through API calls. A trailing `-FC` indicates that the model supports native function calling. You can check out the table summarizing feature support across models [here](https://gorilla.cs.berkeley.edu/blogs/8_berkeley_function_calling_leaderboard.html#prompt).
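The NVIDIA-hosted entries added here are served through NVIDIA's OpenAI-compatible endpoint (`https://integrate.api.nvidia.com/v1`), which the handler changes later in this commit read from the `OPENAI_BASE_URL` environment variable. Below is a rough illustrative sketch of calling one of these models directly with the `openai` client; the `NVIDIA_API_KEY` variable and the `get_weather` tool are hypothetical and shown only to demonstrate the function-calling path.

```python
import json
import os

from openai import OpenAI

# Point the standard OpenAI client at NVIDIA's OpenAI-compatible endpoint.
# NVIDIA_API_KEY is an assumed environment variable holding an NVIDIA API key.
client = OpenAI(
    base_url="https://integrate.api.nvidia.com/v1",
    api_key=os.getenv("NVIDIA_API_KEY"),
)

# Hypothetical tool definition in the OpenAI tools format.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="meta/llama-3.1-70b-instruct",  # leaderboard name minus the "-FC" suffix
    messages=[{"role": "user", "content": "What is the weather in Berkeley?"}],
    tools=tools,
)

# Function-calling models return structured tool calls rather than free text.
for call in response.choices[0].message.tool_calls or []:
    print(call.function.name, json.loads(call.function.arguments))
```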

1 change: 0 additions & 1 deletion berkeley-function-call-leaderboard/eval_checker/checker.py
@@ -328,7 +328,6 @@ def simple_function_checker(
}

func_name = convert_func_name(func_name, model_name)

# Check if function name matches
if func_name not in model_output:
result["valid"] = False
@@ -107,6 +107,36 @@
"OpenAI",
"Proprietary",
],
"meta/llama-3.1-70b-instruct-FC": [
"meta/llama-3.1-70b-instruct (FC)",
"https://integrate.api.nvidia.com/v1",
"NVIDIA",
"MIT",
],
"nv-mistralai/mistral-nemo-12b-instruct-FC": [
"nv-mistralai/mistral-nemo-12b-instruct (FC)",
"https://integrate.api.nvidia.com/v1",
"NVIDIA",
"MIT",
],
"meta/llama-3.1-405b-instruct-FC": [
"meta/llama-3.1-405b-instruct (FC)",
"https://integrate.api.nvidia.com/v1",
"NVIDIA",
"MIT",
],
"mistralai/mistral-large-2-instruct-FC": [
"mistralai/mistral-large-2-instruct (FC)",
"https://integrate.api.nvidia.com/v1",
"NVIDIA",
"MIT",
],
"mistralai/mistral-7b-instruct-v0.3-FC": [
"mistralai/mistral-7b-instruct-v0.3 (FC)",
"https://integrate.api.nvidia.com/v1",
"NVIDIA",
"MIT",
],
"gpt-4o-2024-05-13": [
"GPT-4o-2024-05-13 (Prompt)",
"https://openai.com/index/hello-gpt-4o/",
@@ -532,7 +562,7 @@
"https://huggingface.co/Salesforce/xLAM-7b-fc-r",
"Salesforce",
"cc-by-nc-4.0",
]
],
}

INPUT_PRICE_PER_MILLION_TOKEN = {
@@ -673,7 +703,7 @@
"ibm-granite/granite-20b-functioncalling",
"THUDM/glm-4-9b-chat",
"Salesforce/xLAM-1b-fc-r",
"Salesforce/xLAM-7b-fc-r"
"Salesforce/xLAM-7b-fc-r",
]

# Price obtained from Azure: 22.032 per hour for 8 V100, Pay As You Go total price
@@ -833,7 +863,10 @@ def api_status_sanity_check_rest():
errors.append((data, status))

if correct_count != len(ground_truth_replaced):
raise BadAPIStatusError(errors, f"{len(ground_truth_replaced) - correct_count} / {len(ground_truth_replaced)}")
raise BadAPIStatusError(
errors,
f"{len(ground_truth_replaced) - correct_count} / {len(ground_truth_replaced)}",
)


def api_status_sanity_check_executable():
@@ -857,26 +890,37 @@ def api_status_sanity_check_executable():
errors.append((data, status))

if correct_count != len(ground_truth):
raise BadAPIStatusError(errors, f"{len(ground_truth) - correct_count} / {len(ground_truth)}")
raise BadAPIStatusError(
errors, f"{len(ground_truth) - correct_count} / {len(ground_truth)}"
)


def display_api_status_error(rest_error, executable_error, display_success=False):
if not rest_error and not executable_error:
if display_success:
print("🟢 All API Status Test Passed!")
return None

print(f"\n{RED_FONT}{'-' * 18} Executable Categories' Error Bounds Based on API Health Status {'-' * 18}{RESET}\n")

RED_FONT = "\033[91m"
RESET = "\033[0m"

print(
f"\n{RED_FONT}{'-' * 18} Executable Categories' Error Bounds Based on API Health Status {'-' * 18}{RESET}\n"
)

if rest_error:
print(f"❗️ Warning: Unable to verify health of executable APIs used in executable test category (REST). Please contact API provider.\n")
print(
f"❗️ Warning: Unable to verify health of executable APIs used in executable test category (REST). Please contact API provider.\n"
)
print(f"{rest_error.error_rate} APIs affected:\n")
for data, status in rest_error.errors:
print(f" - Test Case: {data['ground_truth']}")
print(f" Error Type: {status['error_type']}\n")

if executable_error:
print(f"❗️ Warning: Unable to verify health of executable APIs used in executable test categories (Non-REST). Please contact API provider.\n")
print(
f"❗️ Warning: Unable to verify health of executable APIs used in executable test categories (Non-REST). Please contact API provider.\n"
)
print(f"{executable_error.error_rate} APIs affected:\n")
for data, status in executable_error.errors:
print(f" - Test Case: {data['ground_truth'][0]}")
5 changes: 5 additions & 0 deletions berkeley-function-call-leaderboard/model_handler/constant.py
@@ -155,6 +155,11 @@
"THUDM/glm-4-9b-chat",
"ibm-granite/granite-20b-functioncalling",
"yi-large-fc",
"meta/llama-3.1-70b-instruct-FC",
"meta/llama-3.1-405b-instruct-FC",
"nv-mistralai/mistral-nemo-12b-instruct-FC",
"mistralai/mistral-large-2-instruct-FC",
"mistralai/mistral-7b-instruct-v0.3-FC"
]

TEST_FILE_MAPPING = {
123 changes: 65 additions & 58 deletions berkeley-function-call-leaderboard/model_handler/gpt_handler.py
@@ -1,3 +1,5 @@
from openai import OpenAI
import os, time, json, logging
from model_handler.handler import BaseHandler
from model_handler.model_style import ModelStyle
from model_handler.utils import (
@@ -13,94 +15,99 @@
USER_PROMPT_FOR_CHAT_MODEL,
DEFAULT_SYSTEM_PROMPT,
)
from openai import OpenAI
import os, time, json


class OpenAIHandler(BaseHandler):
def __init__(self, model_name, temperature=0.001, top_p=1, max_tokens=1000) -> None:
super().__init__(model_name, temperature, top_p, max_tokens)
self.model_style = ModelStyle.OpenAI
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
self.client = OpenAI(
base_url=os.getenv("OPENAI_BASE_URL"), api_key=os.getenv("OPENAI_API_KEY")
)

def inference(self, prompt, functions, test_category):
# Chatting model
if "FC" not in self.model_name:
functions = func_doc_language_specific_pre_processing(functions, test_category)
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(functions, test_category)
message = [{"role": "user", "content": prompt}]
if type(functions) is not list:
functions = [functions]

oai_tool = convert_to_tool(
functions, GORILLA_TO_OPENAPI, self.model_style, test_category
)

logging.debug("Message: %s", str(message))
logging.debug("Tools: %s", str(oai_tool))
metadata = {}

if "FC" not in self.model_name or len(oai_tool) == 0:
message = [
{
"role": "system",
"content": SYSTEM_PROMPT_FOR_CHAT_MODEL,
},
{
"role": "user",
"content": USER_PROMPT_FOR_CHAT_MODEL.format(
user_prompt=prompt, functions=str(functions)
),
},
]

prompt = system_prompt_pre_processing(prompt, DEFAULT_SYSTEM_PROMPT)
prompt = user_prompt_pre_processing_chat_model(prompt, USER_PROMPT_FOR_CHAT_MODEL, test_category, functions)
message = prompt

start_time = time.time()
response = self.client.chat.completions.create(
messages=message,
model=self.model_name,
model=self.model_name.replace("-FC", ""),
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
)
latency = time.time() - start_time
metadata["input_tokens"] = response.usage.prompt_tokens
metadata["output_tokens"] = response.usage.completion_tokens
metadata["latency"] = latency
result = response.choices[0].message.content
# Function call model
else:
functions = func_doc_language_specific_pre_processing(functions, test_category)
logging.debug("Response: %s", str(result))
return result, metadata

message = prompt
oai_tool = convert_to_tool(
functions, GORILLA_TO_OPENAPI, self.model_style, test_category
)
start_time = time.time()
if len(oai_tool) > 0:
response = self.client.chat.completions.create(
messages=message,
model=self.model_name.replace("-FC", ""),
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
tools=oai_tool,
)
else:
response = self.client.chat.completions.create(
messages=message,
model=self.model_name.replace("-FC", ""),
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
)
latency = time.time() - start_time
try:
result = [
{func_call.function.name: func_call.function.arguments}
for func_call in response.choices[0].message.tool_calls
]
except:
result = response.choices[0].message.content
metadata = {}
oai_tool = convert_to_tool(
functions, GORILLA_TO_OPENAPI, self.model_style, test_category
)
start_time = time.time()
response = self.client.chat.completions.create(
messages=message,
model=self.model_name.replace("-FC", ""),
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
tools=oai_tool,
)
latency = time.time() - start_time
result = response.choices[0].message.content

if response.choices[0].message.tool_calls:
result = [
{func_call.function.name: func_call.function.arguments}
for func_call in response.choices[0].message.tool_calls
]
logging.debug("Response: %s", str(result))
metadata["input_tokens"] = response.usage.prompt_tokens
metadata["output_tokens"] = response.usage.completion_tokens
metadata["latency"] = latency
return result,metadata
def decode_ast(self,result,language="Python"):
return result, metadata

def decode_ast(self, result, language="Python"):
if "FC" not in self.model_name:
func = result
if " " == func[0]:
func = func[1:]
if not func.startswith("["):
func = "[" + func
if not func.endswith("]"):
func = func + "]"
decoded_output = ast_parse(func,language)
decoded_output = ast_parse(result, language)
else:
decoded_output = []
for invoked_function in result:
name = list(invoked_function.keys())[0]
params = json.loads(invoked_function[name])
decoded_output.append({name: params})
return decoded_output
def decode_execute(self,result):

def decode_execute(self, result):
if "FC" not in self.model_name:
func = result
if " " == func[0]:
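The rewritten handler above routes both prompting and function-calling models through a single chat-completions call, strips the `-FC` suffix from the model name before sending it, and, when tool calls come back, stores them as `{function_name: arguments_json}` dicts that `decode_ast` later parses with `json.loads`. A condensed, standalone sketch of that decoding step, using made-up response data for illustration:

```python
import json

# Illustrative tool-call payload shaped like the handler's "result" list:
# one dict per tool call, mapping function name -> JSON-encoded arguments.
result = [{"get_weather": '{"city": "Berkeley"}'}]

# Mirrors the -FC branch of decode_ast above: each entry's arguments
# string is decoded from JSON into a Python dict.
decoded_output = []
for invoked_function in result:
    name = list(invoked_function.keys())[0]
    params = json.loads(invoked_function[name])
    decoded_output.append({name: params})

print(decoded_output)  # [{'get_weather': {'city': 'Berkeley'}}]
```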
@@ -92,6 +92,11 @@
"snowflake/arctic": ArcticHandler,
"ibm-granite/granite-20b-functioncalling": GraniteHandler,
"nvidia/nemotron-4-340b-instruct": NvidiaHandler,
"meta/llama-3.1-405b-instruct-FC": OpenAIHandler,
"meta/llama-3.1-70b-instruct-FC": OpenAIHandler,
"nv-mistralai/mistral-nemo-12b-instruct-FC": OpenAIHandler,
"mistralai/mistral-large-2-instruct-FC": OpenAIHandler,
"mistralai/mistral-7b-instruct-v0.3-FC": OpenAIHandler,
"THUDM/glm-4-9b-chat": GLMHandler,
"yi-large-fc": YiHandler,
"Salesforce/xLAM-1b-fc-r": xLAMHandler,
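With these entries, the new NVIDIA models dispatch to `OpenAIHandler`. A rough usage sketch follows, assuming the mapping above is exposed as `handler_map` in `model_handler/handler_map.py` (name and module path assumed) and that `OPENAI_BASE_URL`/`OPENAI_API_KEY` already point at the NVIDIA endpoint:

```python
from model_handler.handler_map import handler_map  # assumed module path

model_name = "meta/llama-3.1-70b-instruct-FC"
handler_cls = handler_map[model_name]  # -> OpenAIHandler for the new NVIDIA entries
handler = handler_cls(model_name)      # defaults: temperature=0.001, top_p=1, max_tokens=1000
print(type(handler).__name__)          # "OpenAIHandler"
```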
@@ -1,4 +1,4 @@
import time, os
import time,os
from openai import OpenAI
from model_handler.handler import BaseHandler
from model_handler.model_style import ModelStyle
