diff --git a/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py b/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
index 27ec930f6..f63984076 100644
--- a/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
+++ b/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
@@ -12,11 +12,12 @@
     convert_system_prompt_into_user_prompt,
     convert_to_function_call,
     convert_to_tool,
-    format_execution_results_prompting,
     extract_system_prompt,
+    format_execution_results_prompting,
     func_doc_language_specific_pre_processing,
     system_prompt_pre_processing_chat_model,
 )
+from bfcl.utils import is_multi_turn
 
 
 class ClaudeHandler(BaseHandler):
@@ -74,14 +75,26 @@ def _query_FC(self, inference_data: dict):
             "message": repr(inference_data["message"]),
             "tools": inference_data["tools"],
         }
-
-        return self.client.messages.create(
+        messages = inference_data["message"]
+
+        if inference_data["caching_enabled"]:
+            # Only add cache control to the last two user messages
+            # Remove previously set cache control flags from all user messages except the last two
+            count = 0
+            for message in reversed(messages):
+                if message["role"] == "user":
+                    if count < 2:
+                        message["content"][0]["cache_control"] = {"type": "ephemeral"}
+                    else:
+                        if "cache_control" in message["content"][0]:
+                            del message["content"][0]["cache_control"]
+                    count += 1
+
+        return self.client.beta.prompt_caching.messages.create(
             model=self.model_name.strip("-FC"),
-            max_tokens=(
-                8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
-            ),  # 3.5 Sonnet has a higher max token limit
+            max_tokens=(8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096),
             tools=inference_data["tools"],
-            messages=inference_data["message"],
+            messages=messages,
         )
 
     def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
@@ -92,8 +105,15 @@ def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> di
             test_entry["question"][round_idx] = combine_consecutive_user_prompts(
                 test_entry["question"][round_idx]
             )
-
         inference_data["message"] = []
+
+        test_entry_id: str = test_entry["id"]
+        test_category: str = test_entry_id.rsplit("_", 1)[0]
+        # caching enabled only for multi_turn category
+        inference_data["caching_enabled"] = (
+            is_multi_turn(test_category) and "claude-3-sonnet" not in self.model_name
+        )
+
         return inference_data
 
     def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
@@ -103,8 +123,19 @@ def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
         functions = func_doc_language_specific_pre_processing(functions, test_category)
         tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)
 
+        if inference_data["caching_enabled"]:
+            # First time compiling tools, so adding cache control flag to the last tool
+            if "tools" not in inference_data:
+                tools[-1]["cache_control"] = {"type": "ephemeral"}
+            # This is the situation where the tools are already compiled and we are adding more tools to the existing tools (in miss_func category)
+            # We add the cache control flag to the last tool in the previous existing tools and the last tool in the new tools to maximize cache hit
+            else:
+                existing_tool_len = len(inference_data["tools"])
+                tools[existing_tool_len - 1]["cache_control"] = {"type": "ephemeral"}
+                tools[-1]["cache_control"] = {"type": "ephemeral"}
+
         inference_data["tools"] = tools
-
+
         return inference_data
 
     def _parse_query_response_FC(self, api_response: any) -> dict:
@@ -134,12 +165,16 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
     def add_first_turn_message_FC(
         self, inference_data: dict, first_turn_message: list[dict]
     ) -> dict:
+        for message in first_turn_message:
+            message["content"] = [{"type": "text", "text": message["content"]}]
         inference_data["message"].extend(first_turn_message)
         return inference_data
 
     def _add_next_turn_user_message_FC(
         self, inference_data: dict, user_message: list[dict]
     ) -> dict:
+        for message in user_message:
+            message["content"] = [{"type": "text", "text": message["content"]}]
         inference_data["message"].extend(user_message)
         return inference_data
 
@@ -188,16 +223,27 @@ def _query_prompting(self, inference_data: dict):
             "system_prompt": inference_data["system_prompt"],
         }
 
-        api_response = self.client.messages.create(
+        if inference_data["caching_enabled"]:
+            # Cache the system prompt
+            inference_data["system_prompt"][0]["cache_control"] = {"type": "ephemeral"}
+            # Add cache control to the last two user messages as well
+            count = 0
+            for message in reversed(inference_data["message"]):
+                if message["role"] == "user":
+                    if count < 2:
+                        message["content"][0]["cache_control"] = {"type": "ephemeral"}
+                    else:
+                        if "cache_control" in message["content"][0]:
+                            del message["content"][0]["cache_control"]
+                    count += 1
+
+        api_response = self.client.beta.prompt_caching.messages.create(
             model=self.model_name,
-            max_tokens=(
-                8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
-            ),  # 3.5 Sonnet has a higher max token limit
+            max_tokens=(8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096),
             temperature=self.temperature,
             system=inference_data["system_prompt"],
             messages=inference_data["message"],
         )
-
         return api_response
 
     def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
@@ -213,13 +259,26 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
         # Claude takes in system prompt in a specific field, not in the message field, so we don't need to add it to the message
         system_prompt = extract_system_prompt(test_entry["question"][0])
 
+        system_prompt = [{"type": "text", "text": system_prompt}]
+
         # Claude doesn't allow consecutive user prompts, so we need to combine them
         for round_idx in range(len(test_entry["question"])):
             test_entry["question"][round_idx] = combine_consecutive_user_prompts(
                 test_entry["question"][round_idx]
             )
 
-        return {"message": [], "system_prompt": system_prompt}
+        test_entry_id: str = test_entry["id"]
+        test_category: str = test_entry_id.rsplit("_", 1)[0]
+        # caching enabled only for multi_turn category
+        caching_enabled: bool = (
+            is_multi_turn(test_category) and "claude-3-sonnet" not in self.model_name
+        )
+
+        return {
+            "message": [],
+            "system_prompt": system_prompt,
+            "caching_enabled": caching_enabled,
+        }
 
     def _parse_query_response_prompting(self, api_response: any) -> dict:
         return {
@@ -231,12 +290,16 @@ def _parse_query_response_prompting(self, api_response: any) -> dict:
     def add_first_turn_message_prompting(
         self, inference_data: dict, first_turn_message: list[dict]
     ) -> dict:
+        for message in first_turn_message:
+            message["content"] = [{"type": "text", "text": message["content"]}]
         inference_data["message"].extend(first_turn_message)
         return inference_data
 
     def _add_next_turn_user_message_prompting(
         self, inference_data: dict, user_message: list[dict]
     ) -> dict:
+        for message in user_message:
+            message["content"] = [{"type": "text", "text": message["content"]}]
         inference_data["message"].extend(user_message)
         return inference_data
 
diff --git a/berkeley-function-call-leaderboard/pyproject.toml b/berkeley-function-call-leaderboard/pyproject.toml
index abdc32666..e29f9acb2 100644
--- a/berkeley-function-call-leaderboard/pyproject.toml
+++ b/berkeley-function-call-leaderboard/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
    "tree-sitter-javascript==0.21.4",
    "openai==1.46.0",
    "mistralai==1.1.0",
-   "anthropic==0.31.1",
+   "anthropic==0.37.1",
    "cohere==5.5.8",
    "typer>=0.12.5",
    "tabulate>=0.9.0",
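
Illustration (not part of the patch): a minimal, self-contained sketch of the request shape this handler now builds, assuming anthropic==0.37.1 from the dependency bump above. The tool definition and user message are made-up placeholders; only the cache_control markers and the beta prompt-caching endpoint mirror what the diff does.

    # Sketch only: mirrors how ClaudeHandler attaches cache_control markers above.
    # Assumes anthropic==0.37.1 and ANTHROPIC_API_KEY in the environment.
    # The tool and message contents are hypothetical, not BFCL test data.
    import anthropic

    client = anthropic.Anthropic()

    tools = [
        {
            "name": "get_weather",  # hypothetical tool
            "description": "Get the current weather for a city.",
            "input_schema": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
            # Cache breakpoint on the last tool, as _compile_tools does
            "cache_control": {"type": "ephemeral"},
        }
    ]

    messages = [
        {
            "role": "user",
            # Content must be a list of blocks so cache_control can be attached,
            # which is why add_first_turn_message_FC wraps plain strings above
            "content": [
                {
                    "type": "text",
                    "text": "What's the weather in Berkeley?",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        }
    ]

    response = client.beta.prompt_caching.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=4096,
        tools=tools,
        messages=messages,
    )
    # The caching beta reports cache_creation_input_tokens / cache_read_input_tokens here
    print(response.usage)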