Commit: Add support for Palmyra X 004 from Writer
Showing 8 changed files with 230 additions and 1 deletion. The two substantive changes expanded on this page are shown below.
The first change adds a `WRITER` entry to the `ModelStyle` enum:

```diff
@@ -11,3 +11,4 @@ class ModelStyle(Enum):
     NEXUS = "nexus"
     OSSMODEL = "ossmodel"
     COHERE = "cohere"
+    WRITER = "writer"
```
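The new member only takes effect once model names are routed to the new handler, which happens in one of the other changed files not expanded on this page. A hedged sketch of what that registration plausibly looks like (the dict name and model-name strings are assumptions; the `-FC` suffix convention is taken from the handler code below):

```python
# Hypothetical registration sketch -- the real mapping lives in a file
# not shown on this page; the names here are assumptions.
from bfcl.model_handler.proprietary_model.writer import WriterHandler

handler_map = {
    "palmyra-x-004": WriterHandler,     # prompting variant (assumed name)
    "palmyra-x-004-FC": WriterHandler,  # function-calling variant (assumed name)
}
```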
The second change is the new file `berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/writer.py` (213 additions), shown below in annotated chunks. First, the imports and the constructor: the handler tags itself with the new `ModelStyle.WRITER` and builds a `writerai` SDK client from the `WRITER_API_KEY` environment variable.
```python
import json
import os
from typing import Any

from bfcl.model_handler.base_handler import BaseHandler
from bfcl.model_handler.constant import DEFAULT_SYSTEM_PROMPT, GORILLA_TO_OPENAPI
from bfcl.model_handler.model_style import ModelStyle
from bfcl.model_handler.utils import (
    convert_to_function_call,
    convert_to_tool,
    default_decode_ast_prompting,
    default_decode_execute_prompting,
    format_execution_results_prompting,
    func_doc_language_specific_pre_processing,
    system_prompt_pre_processing_chat_model,
    convert_system_prompt_into_user_prompt,
    combine_consecutive_user_prompts,
)
from writerai import Writer


class WriterHandler(BaseHandler):
    def __init__(self, model_name, temperature) -> None:
        super().__init__(model_name, temperature)
        self.model_style = ModelStyle.WRITER
        # The SDK client authenticates via the WRITER_API_KEY env var.
        self.client = Writer(api_key=os.getenv("WRITER_API_KEY"))
```
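A minimal instantiation sketch, assuming a valid key in `WRITER_API_KEY` (the model name and temperature are illustrative):

```python
import os

os.environ["WRITER_API_KEY"] = "<your-api-key>"  # illustrative; use a real key
handler = WriterHandler(model_name="palmyra-x-004-FC", temperature=0.001)
# A model name containing "FC" selects the native function-calling path
# below; any other name falls back to the prompting path.
```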
The two decode hooks branch on whether the model name contains `FC`; non-FC variants go through the default prompting decoders:

```python
    def decode_ast(self, result, language="Python"):
        if "FC" not in self.model_name:
            return default_decode_ast_prompting(result, language)
        else:
            # Each entry is {function_name: json_encoded_argument_string};
            # parse the arguments into a real dict.
            decoded_output = []
            for invoked_function in result:
                name = list(invoked_function.keys())[0]
                params = json.loads(invoked_function[name])
                decoded_output.append({name: params})
            return decoded_output

    def decode_execute(self, result):
        if "FC" not in self.model_name:
            return default_decode_execute_prompting(result)
        else:
            function_call = convert_to_function_call(result)
            return function_call
```
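To make the FC decode path concrete: the parsed model response is a list of single-key dicts whose values are JSON-encoded argument strings, and `decode_ast` turns each value into a dict. A worked example (function name and arguments are made up):

```python
raw = [{"get_weather": '{"city": "Berkeley", "unit": "celsius"}'}]
decoded = handler.decode_ast(raw)  # handler with an "FC" model name
assert decoded == [{"get_weather": {"city": "Berkeley", "unit": "celsius"}}]
```

`decode_execute` hands the same structure to `convert_to_function_call`, which, judging by its use here, renders each call into an executable form for the execution test categories.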
Native function calling: the query method logs the input, strips the `response` field from each tool document, and calls Writer's chat endpoint, dropping the `-FC` suffix to recover the real model name:

```python
    #### FC methods ####

    def _query_FC(self, inference_data: dict):
        message: list[dict] = inference_data["message"]
        tools = inference_data["tools"]
        inference_data["inference_input_log"] = {"message": repr(message), "tools": tools}

        if len(tools) > 0:
            # Remove the "response" (return-value) section from each function
            # doc; it is not part of the tool schema sent to the API.
            for tool in tools:
                if "response" in tool["function"]:
                    del tool["function"]["response"]

            api_response = self.client.chat.chat(
                messages=message,
                model=self.model_name.replace("-FC", ""),
                temperature=self.temperature,
                tools=tools,
                tool_choice="auto",
            )
        else:
            api_response = self.client.chat.chat(
                messages=message,
                model=self.model_name.replace("-FC", ""),
                temperature=self.temperature,
            )
        return api_response

    def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
        inference_data["message"] = []
        return inference_data

    def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
        functions: list = test_entry["function"]
        test_category: str = test_entry["id"].rsplit("_", 1)[0]

        functions = func_doc_language_specific_pre_processing(functions, test_category)
        tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)

        inference_data["tools"] = tools

        return inference_data
```
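The `del tool["function"]["response"]` in `_query_FC` implies the compiled tools follow an OpenAI-style layout. A sketch of one compiled entry (field values are illustrative; the structure is assumed from how `_query_FC` indexes into it):

```python
tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
        # An optional "response" section here would be stripped by _query_FC
        # before the request goes out.
    },
}
```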
Parsing the response: if the message carries tool calls, collect them as `{name: raw_argument_string}` pairs along with their call IDs; otherwise fall back to the plain text content.

```python
    def _parse_query_response_FC(self, api_response: Any) -> dict:
        try:
            model_responses = [
                {func_call.function.name: func_call.function.arguments}
                for func_call in api_response.choices[0].message.tool_calls
            ]
            tool_call_ids = [
                func_call.id for func_call in api_response.choices[0].message.tool_calls
            ]
        except Exception:
            # message.tool_calls is absent or None: treat the reply as text.
            model_responses = api_response.choices[0].message.content
            tool_call_ids = []

        model_responses_message_for_chat_history = api_response.choices[0].message

        return {
            "model_responses": model_responses,
            "model_responses_message_for_chat_history": model_responses_message_for_chat_history,
            "tool_call_ids": tool_call_ids,
            "input_token": api_response.usage.prompt_tokens,
            "output_token": api_response.usage.completion_tokens,
        }
```
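For a reply containing a single tool call, the parsed dict would look roughly like this (the call ID format and token counts are illustrative):

```python
parsed = {
    "model_responses": [{"get_weather": '{"city": "Berkeley"}'}],
    "model_responses_message_for_chat_history": ...,  # the raw SDK message object
    "tool_call_ids": ["call_abc123"],
    "input_token": 412,
    "output_token": 23,
}
```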
Chat-history bookkeeping for the FC path: user turns and the assistant message are appended as-is, and each execution result goes back as a `tool` message tied to its call ID.

```python
    def add_first_turn_message_FC(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        inference_data["message"].extend(first_turn_message)
        return inference_data

    def _add_next_turn_user_message_FC(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        inference_data["message"].extend(user_message)
        return inference_data

    def _add_assistant_message_FC(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        inference_data["message"].append(
            model_response_data["model_responses_message_for_chat_history"]
        )
        return inference_data

    def _add_execution_results_FC(
        self,
        inference_data: dict,
        execution_results: list[str],
        model_response_data: dict,
    ) -> dict:
        # Add the execution results to the current round result, one at a time
        for execution_result, tool_call_id in zip(
            execution_results, model_response_data["tool_call_ids"]
        ):
            tool_message = {
                "role": "tool",
                "content": str(execution_result),
                "tool_call_id": tool_call_id,
            }
            inference_data["message"].append(tool_message)

        return inference_data
```
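So after one round with one executed call, the history gains the assistant's message followed by a tool message of the form:

```python
{"role": "tool", "content": "72.5", "tool_call_id": "call_abc123"}  # values illustrative
```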
Finally, the prompting path for non-`FC` model names: function docs are folded into the system prompt up front, the model is queried without a tools parameter, and execution results come back as plain user messages.

```python
    #### Prompting methods ####

    def _query_prompting(self, inference_data: dict):
        inference_data["inference_input_log"] = {"message": repr(inference_data["message"])}

        api_response = self.client.chat.chat(
            messages=inference_data["message"],
            model=self.model_name,
            temperature=self.temperature,
        )

        return api_response

    def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
        functions: list = test_entry["function"]
        test_category: str = test_entry["id"].rsplit("_", 1)[0]

        functions = func_doc_language_specific_pre_processing(functions, test_category)

        # Fold the function docs into the system prompt of the first turn.
        test_entry["question"][0] = system_prompt_pre_processing_chat_model(
            test_entry["question"][0], DEFAULT_SYSTEM_PROMPT, functions
        )

        return {"message": []}

    def _parse_query_response_prompting(self, api_response: Any) -> dict:
        return {
            "model_responses": api_response.choices[0].message.content,
            "model_responses_message_for_chat_history": api_response.choices[0].message,
            "input_token": api_response.usage.prompt_tokens,
            "output_token": api_response.usage.completion_tokens,
        }

    def add_first_turn_message_prompting(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        inference_data["message"].extend(first_turn_message)
        return inference_data

    def _add_next_turn_user_message_prompting(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        inference_data["message"].extend(user_message)
        return inference_data

    def _add_assistant_message_prompting(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        inference_data["message"].append(
            model_response_data["model_responses_message_for_chat_history"]
        )
        return inference_data

    def _add_execution_results_prompting(
        self, inference_data: dict, execution_results: list[str], model_response_data: dict
    ) -> dict:
        formatted_results_message = format_execution_results_prompting(
            inference_data, execution_results, model_response_data
        )
        inference_data["message"].append(
            {"role": "user", "content": formatted_results_message}
        )

        return inference_data
```
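Putting the prompting path together, a hedged end-to-end sketch (the entry shape follows the Gorilla function-doc format as assumed from how the methods above index into it; in practice the bfcl runner drives these calls):

```python
entry = {
    "id": "simple_0",  # the category "simple" is parsed off the id
    "function": [{
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "dict",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }],
    "question": [[{"role": "user", "content": "What's the weather in Berkeley?"}]],
}

handler = WriterHandler(model_name="palmyra-x-004", temperature=0.001)
data = handler._pre_query_processing_prompting(entry)  # injects functions into the system prompt
data = handler.add_first_turn_message_prompting(data, entry["question"][0])
response = handler._query_prompting(data)              # one client.chat.chat call
print(handler._parse_query_response_prompting(response)["model_responses"])
```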