Add support for Palmyra X 004 from Writer
samjulien committed Nov 12, 2024
1 parent 19490f1 commit 3118fd7
Showing 8 changed files with 230 additions and 1 deletion.
1 change: 1 addition & 0 deletions berkeley-function-call-leaderboard/.env.example
@@ -7,6 +7,7 @@ ANTHROPIC_API_KEY=
NVIDIA_API_KEY=nvapi-XXXXXX
YI_API_KEY=
GOGOAGENT_API_KEY=
WRITER_API_KEY=

# We use Vertex AI to inference Google Gemini models
VERTEX_AI_PROJECT_ID=
4 changes: 3 additions & 1 deletion berkeley-function-call-leaderboard/README.md
@@ -90,7 +90,7 @@ The evaluation script will automatically search for dataset files in the default

## Evaluating different models on the BFCL

- Make sure the model API keys are included in your `.env` file. Running proprietary models like GPTs, Claude, Mistral-X will require them.
+ Make sure the model API keys are included in your `.env` file. Running proprietary models like GPTs, Claude, Mistral-X, and Palmyra will require them.

```bash
OPENAI_API_KEY=sk-XXXXXX
@@ -100,6 +100,7 @@ ANTHROPIC_API_KEY=
NVIDIA_API_KEY=nvapi-XXXXXX
YI_API_KEY=
GOGOAGENT_API_KEY=
WRITER_API_KEY=

VERTEX_AI_PROJECT_ID=
VERTEX_AI_LOCATION=
@@ -198,6 +199,7 @@ Below is _a table of models we support_ to run our leaderboard evaluation against
|NousResearch/Hermes-2-Pro-Llama-3-{8B,70B} 💻| Function Calling|
|NousResearch/Hermes-2-Pro-Mistral-7B 💻| Function Calling|
|NousResearch/Hermes-2-Theta-Llama-3-{8B,70B} 💻| Function Calling|
|palmyra-x-004-FC | Function Calling|
|snowflake/arctic | Prompt|
|Salesforce/xLAM-1b-fc-r 💻| Function Calling|
|Salesforce/xLAM-7b-fc-r 💻| Function Calling|
@@ -1,6 +1,12 @@
from bfcl.model_handler.handler_map import local_inference_handler_map

MODEL_METADATA_MAPPING = {
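    # Entry format, inferred from the entries below: [display name, reference URL, organization, license]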
"palmyra-x-004-FC": [
"palmyra-x-004 (FC)",
"https://writer.com/engineering/actions-with-palmyra-x-004/",
"Writer",
"Proprietary",
],
"o1-preview-2024-09-12": [
"o1-preview-2024-09-12 (Prompt)",
"https://openai.com/index/introducing-openai-o1-preview/",
@@ -713,6 +719,7 @@
"command-r-plus-FC-optimized": 3,
"command-r-plus-optimized": 3,
"yi-large-fc": 3,
"palmyra-x-004-FC": 5,
}

OUTPUT_PRICE_PER_MILLION_TOKEN = {
@@ -771,6 +778,7 @@
"command-r-plus-FC-optimized": 15,
"command-r-plus-optimized": 15,
"yi-large-fc": 3,
"palmyra-x-004-FC": 12,
}

# The latency of the open-source models is hardcoded here.
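For context, a rough sketch of how these per-million-token rates translate into a per-request dollar cost. The input-price dict's name is cut off by the diff, so `INPUT_PRICE_PER_MILLION_TOKEN` below is an assumed name mirroring `OUTPUT_PRICE_PER_MILLION_TOKEN`:

```python
# Hypothetical cost of one palmyra-x-004-FC call at $5 input / $12 output
# per million tokens. INPUT_PRICE_PER_MILLION_TOKEN is an assumed name.
input_tokens, output_tokens = 1_200, 350
cost = (
    input_tokens / 1_000_000 * INPUT_PRICE_PER_MILLION_TOKEN["palmyra-x-004-FC"]
    + output_tokens / 1_000_000 * OUTPUT_PRICE_PER_MILLION_TOKEN["palmyra-x-004-FC"]
)
print(f"${cost:.4f}")  # -> $0.0102
```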
@@ -24,6 +24,7 @@
from bfcl.model_handler.proprietary_model.openai import OpenAIHandler
from bfcl.model_handler.proprietary_model.yi import YiHandler
from bfcl.model_handler.proprietary_model.gogoagent import GoGoAgentHandler
from bfcl.model_handler.proprietary_model.writer import WriterHandler

# TODO: Add Deepseek V2, meta-llama/Llama-3.1-405B-Instruct

@@ -82,6 +83,7 @@
"nvidia/nemotron-4-340b-instruct": NvidiaHandler,
"BitAgent/GoGoAgent": GoGoAgentHandler,
# "yi-large-fc": YiHandler, # Their API is under maintenance, and will not be back online in the near future
"palmyra-x-004-FC": WriterHandler,
}

# Inference through local hosting
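For orientation, a minimal sketch of how this mapping could be used to dispatch a run. The dict's variable name is truncated out of the hunk, so `handler_map` is a stand-in, and the runner logic is illustrative rather than the repository's actual code:

```python
# Illustrative dispatch over the model-to-handler mapping above;
# "handler_map" is a stand-in for the dict's real (truncated) name.
model_name, temperature = "palmyra-x-004-FC", 0.001
handler_cls = handler_map[model_name]           # resolves to WriterHandler
handler = handler_cls(model_name, temperature)  # constructs the Writer client
```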
@@ -11,3 +11,4 @@ class ModelStyle(Enum):
NEXUS = "nexus"
OSSMODEL = "ossmodel"
COHERE = "cohere"
WRITER = "writer"
@@ -0,0 +1,213 @@
import json
import os
from typing import Any

from bfcl.model_handler.base_handler import BaseHandler
from bfcl.model_handler.constant import DEFAULT_SYSTEM_PROMPT, GORILLA_TO_OPENAPI
from bfcl.model_handler.model_style import ModelStyle
from bfcl.model_handler.utils import (
convert_to_function_call,
convert_to_tool,
default_decode_ast_prompting,
default_decode_execute_prompting,
format_execution_results_prompting,
func_doc_language_specific_pre_processing,
system_prompt_pre_processing_chat_model,
convert_system_prompt_into_user_prompt,
combine_consecutive_user_prompts,
)
from writerai import Writer


class WriterHandler(BaseHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)
self.model_style = ModelStyle.WRITER
self.client = Writer(api_key=os.getenv("WRITER_API_KEY"))

def decode_ast(self, result, language="Python"):
if "FC" not in self.model_name:
return default_decode_ast_prompting(result, language)
else:
decoded_output = []
for invoked_function in result:
name = list(invoked_function.keys())[0]
params = json.loads(invoked_function[name])
decoded_output.append({name: params})
return decoded_output

def decode_execute(self, result):
if "FC" not in self.model_name:
return default_decode_execute_prompting(result)
else:
function_call = convert_to_function_call(result)
return function_call

#### FC methods ####

def _query_FC(self, inference_data: dict):
message: list[dict] = inference_data["message"]
tools = inference_data["tools"]
inference_data["inference_input_log"] = {"message": repr(message), "tools": tools}

if len(tools) > 0:
for tool in tools:
if "response" in tool["function"]:
del tool["function"]["response"]

api_response = self.client.chat.chat(
messages=message,
model=self.model_name.replace("-FC", ""),
temperature=self.temperature,
tools=tools,
tool_choice="auto",
)
else:
api_response = self.client.chat.chat(
messages=message,
model=self.model_name.replace("-FC", ""),
temperature=self.temperature,
)
return api_response

def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
inference_data["message"] = []
return inference_data

def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
functions: list = test_entry["function"]
test_category: str = test_entry["id"].rsplit("_", 1)[0]

functions = func_doc_language_specific_pre_processing(functions, test_category)
tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)

inference_data["tools"] = tools

return inference_data

    def _parse_query_response_FC(self, api_response: Any) -> dict:
try:
model_responses = [
{func_call.function.name: func_call.function.arguments}
for func_call in api_response.choices[0].message.tool_calls
]
tool_call_ids = [
func_call.id for func_call in api_response.choices[0].message.tool_calls
]
        except (AttributeError, TypeError):  # no tool calls returned; fall back to text content
model_responses = api_response.choices[0].message.content
tool_call_ids = []

model_responses_message_for_chat_history = api_response.choices[0].message

return {
"model_responses": model_responses,
"model_responses_message_for_chat_history": model_responses_message_for_chat_history,
"tool_call_ids": tool_call_ids,
"input_token": api_response.usage.prompt_tokens,
"output_token": api_response.usage.completion_tokens,
}

def add_first_turn_message_FC(
self, inference_data: dict, first_turn_message: list[dict]
) -> dict:
inference_data["message"].extend(first_turn_message)
return inference_data

def _add_next_turn_user_message_FC(
self, inference_data: dict, user_message: list[dict]
) -> dict:
inference_data["message"].extend(user_message)
return inference_data

def _add_assistant_message_FC(
self, inference_data: dict, model_response_data: dict
) -> dict:
inference_data["message"].append(
model_response_data["model_responses_message_for_chat_history"]
)
return inference_data

def _add_execution_results_FC(
self,
inference_data: dict,
execution_results: list[str],
model_response_data: dict,
) -> dict:
# Add the execution results to the current round result, one at a time
for execution_result, tool_call_id in zip(
execution_results, model_response_data["tool_call_ids"]
):
tool_message = {
"role": "tool",
"content": str(execution_result),
"tool_call_id": tool_call_id,
}
inference_data["message"].append(tool_message)

return inference_data

#### Prompting methods ####

def _query_prompting(self, inference_data: dict):
inference_data["inference_input_log"] = {"message": repr(inference_data["message"])}


api_response = self.client.chat.chat(
messages=inference_data["message"],
model=self.model_name,
temperature=self.temperature,
)

return api_response

def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
functions: list = test_entry["function"]
test_category: str = test_entry["id"].rsplit("_", 1)[0]

functions = func_doc_language_specific_pre_processing(functions, test_category)

test_entry["question"][0] = system_prompt_pre_processing_chat_model(
test_entry["question"][0], DEFAULT_SYSTEM_PROMPT, functions
)

return {"message": []}

    def _parse_query_response_prompting(self, api_response: Any) -> dict:
return {
"model_responses": api_response.choices[0].message.content,
"model_responses_message_for_chat_history": api_response.choices[0].message,
"input_token": api_response.usage.prompt_tokens,
"output_token": api_response.usage.completion_tokens,
}

def add_first_turn_message_prompting(
self, inference_data: dict, first_turn_message: list[dict]
) -> dict:
inference_data["message"].extend(first_turn_message)
return inference_data

def _add_next_turn_user_message_prompting(
self, inference_data: dict, user_message: list[dict]
) -> dict:
inference_data["message"].extend(user_message)
return inference_data

def _add_assistant_message_prompting(
self, inference_data: dict, model_response_data: dict
) -> dict:
inference_data["message"].append(
model_response_data["model_responses_message_for_chat_history"]
)
return inference_data

def _add_execution_results_prompting(
self, inference_data: dict, execution_results: list[str], model_response_data: dict
) -> dict:
formatted_results_message = format_execution_results_prompting(
inference_data, execution_results, model_response_data
)
inference_data["message"].append(
{"role": "user", "content": formatted_results_message}
)

return inference_data
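To make the FC decode path concrete, a short sketch of `decode_ast` on a tool-call result; the function name and arguments are invented for illustration, and constructing the handler assumes `WRITER_API_KEY` is set in the environment:

```python
# Illustrative input: the [{function_name: json_argument_string}, ...] shape
# produced by _parse_query_response_FC. The values below are made up.
handler = WriterHandler("palmyra-x-004-FC", temperature=0.001)
result = [{"get_weather": '{"city": "Berkeley", "unit": "celsius"}'}]
print(handler.decode_ast(result))
# -> [{'get_weather': {'city': 'Berkeley', 'unit': 'celsius'}}]
```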
@@ -254,6 +254,7 @@ def convert_to_tool(functions, mapping, model_style):
ModelStyle.OpenAI,
ModelStyle.Mistral,
ModelStyle.FIREWORK_AI,
ModelStyle.WRITER,
]:
oai_tool.append({"type": "function", "function": item})
return oai_tool
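The `WRITER` style thus reuses the OpenAI-compatible tool wire format. A hedged illustration of one converted entry; the function schema is invented for the example, not taken from the dataset:

```python
# Shape of one element appended to oai_tool for ModelStyle.WRITER.
tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}
```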
1 change: 1 addition & 0 deletions berkeley-function-call-leaderboard/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"tabulate>=0.9.0",
"google-cloud-aiplatform==1.70.0",
"mpmath==1.3.0",
"writer-sdk>=1.2.0"
]

[project.scripts]
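With `writer-sdk` installed, a minimal smoke test of the SDK call pattern used by `WriterHandler` can verify the setup; it assumes `WRITER_API_KEY` is exported (see `.env.example`) and mirrors the handler's `client.chat.chat` call:

```python
import os

from writerai import Writer

# Mirrors WriterHandler._query_prompting; assumes WRITER_API_KEY is set.
client = Writer(api_key=os.environ["WRITER_API_KEY"])
response = client.chat.chat(
    model="palmyra-x-004",
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0.001,
)
print(response.choices[0].message.content)
```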
