From b6b5de38f939c3b9e65e0f1b1cbf69887c529a07 Mon Sep 17 00:00:00 2001
From: Vishnu Suresh
Date: Sun, 10 Nov 2024 23:28:42 -0800
Subject: [PATCH 1/2] Implement prompt caching for claude models

---
 .../model_handler/proprietary_model/claude.py | 81 +++++++++++++++----
 .../pyproject.toml                            |  2 +-
 2 files changed, 68 insertions(+), 15 deletions(-)

diff --git a/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py b/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
index 27ec930f6..392418afc 100644
--- a/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
+++ b/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
@@ -74,14 +74,25 @@ def _query_FC(self, inference_data: dict):
             "message": repr(inference_data["message"]),
             "tools": inference_data["tools"],
         }
-
-        return self.client.messages.create(
+        tools = inference_data["tools"]
+        messages = inference_data["message"]
+        caching_enabled = inference_data["caching_enabled"]
+
+        if caching_enabled: # Caching will only be enabled for multi_turn categories
+            messages[-1]['content'][0]["cache_control"] = {"type": "ephemeral"}
+            user_indices = [i for i, item in enumerate(messages) if item['role'] == 'user'] # Keeping the cache control blocks only in the last two user messages.
+            for i in user_indices[:-2]:
+                for content in messages[i]['content']:
+                    if 'cache_control' in content:
+                        del content['cache_control']
+
+        return self.client.beta.prompt_caching.messages.create(
             model=self.model_name.strip("-FC"),
             max_tokens=(
                 8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
-            ),  # 3.5 Sonnet has a higher max token limit
-            tools=inference_data["tools"],
-            messages=inference_data["message"],
+            ),
+            tools=tools,
+            messages=messages
         )
 
     def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
@@ -92,8 +103,12 @@ def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
             test_entry["question"][round_idx] = combine_consecutive_user_prompts(
                 test_entry["question"][round_idx]
             )
-
         inference_data["message"] = []
+
+        test_entry_id: str = test_entry["id"]
+        test_category: str = test_entry_id.rsplit("_", 1)[0]
+        inference_data["caching_enabled"] = 'multi_turn' in test_category #caching_enabled only for multi_turn category
+
         return inference_data
 
     def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
@@ -103,6 +118,9 @@ def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
         functions = func_doc_language_specific_pre_processing(functions, test_category)
         tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)
 
+        if 'multi_turn' in test_category:
+            tools[-1]['cache_control'] = {'type': 'ephemeral'}
+
         inference_data["tools"] = tools
 
         return inference_data
@@ -120,7 +138,7 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
                 tool_call_ids.append(content.id)
 
         model_responses = tool_call_outputs if tool_call_outputs else text_outputs
-
+
         model_responses_message_for_chat_history = api_response.content
 
         return {
@@ -129,17 +147,25 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
             "tool_call_ids": tool_call_ids,
             "input_token": api_response.usage.input_tokens,
             "output_token": api_response.usage.output_tokens,
+            "prompt_write_cache_token_count": getattr(api_response.usage, 'cache_creation_input_tokens', 0),
+            "prompt_read_cache_token_count": getattr(api_response.usage, 'cache_read_input_tokens', 0)
         }
 
     def add_first_turn_message_FC(
         self, inference_data: dict, first_turn_message: list[dict]
     ) -> dict:
+        if isinstance(first_turn_message[0]['content'], str):
+            message_content = [{"type": "text", "text": first_turn_message[0]['content']}]
+            first_turn_message[0]['content'] = message_content
         inference_data["message"].extend(first_turn_message)
         return inference_data
 
     def _add_next_turn_user_message_FC(
         self, inference_data: dict, user_message: list[dict]
     ) -> dict:
+        if isinstance(user_message[0]['content'], str):
+            message_content = [{"type": "text","text": user_message[0]['content']}]
+            user_message[0]['content'] = message_content
         inference_data["message"].extend(user_message)
         return inference_data
@@ -187,17 +213,32 @@ def _query_prompting(self, inference_data: dict):
             "message": repr(inference_data["message"]),
             "system_prompt": inference_data["system_prompt"],
         }
-
-        api_response = self.client.messages.create(
+        system_prompt_text = inference_data["system_prompt"]
+        system_prompt = [{"type": "text", "text": system_prompt_text}]
+        messages = inference_data["message"]
+        caching_enabled = inference_data["caching_enabled"]
+
+        if caching_enabled:
+            # Caching system prompt
+            system_prompt[0]['cache_control'] = {"type": "ephemeral"}
+
+            # Caching messages
+            messages[-1]['content'][0]["cache_control"] = {"type": "ephemeral"}
+            user_indices = [i for i, item in enumerate(messages) if item['role'] == 'user'] # Keeping the cache control blocks only in the last two user messages.
+            for i in user_indices[:-2]:
+                for content in messages[i]['content']:
+                    if 'cache_control' in content:
+                        del content['cache_control']
+
+        api_response = self.client.beta.prompt_caching.messages.create(
             model=self.model_name,
             max_tokens=(
                 8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
-            ),  # 3.5 Sonnet has a higher max token limit
+            ),
             temperature=self.temperature,
-            system=inference_data["system_prompt"],
-            messages=inference_data["message"],
+            system=system_prompt,
+            messages=messages
         )
-
         return api_response
 
     def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
@@ -218,25 +259,37 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
             test_entry["question"][round_idx] = combine_consecutive_user_prompts(
                 test_entry["question"][round_idx]
             )
+
+        test_entry_id: str = test_entry["id"]
+        test_category: str = test_entry_id.rsplit("_", 1)[0]
+        caching_enabled: bool = 'multi_turn' in test_category # caching enabled only for multi_turn category
 
-        return {"message": [], "system_prompt": system_prompt}
+        return {"message": [], "system_prompt": system_prompt, "caching_enabled": caching_enabled}
 
     def _parse_query_response_prompting(self, api_response: any) -> dict:
         return {
             "model_responses": api_response.content[0].text,
             "input_token": api_response.usage.input_tokens,
             "output_token": api_response.usage.output_tokens,
+            "prompt_write_cache_token_count": getattr(api_response.usage, 'cache_creation_input_tokens', 0),
+            "prompt_read_cache_token_count": getattr(api_response.usage, 'cache_read_input_tokens', 0),
         }
 
     def add_first_turn_message_prompting(
         self, inference_data: dict, first_turn_message: list[dict]
     ) -> dict:
+        if isinstance(first_turn_message[0]['content'], str):
+            message_content = [{"type": "text", "text": first_turn_message[0]['content']}]
+            first_turn_message[0]['content'] = message_content
         inference_data["message"].extend(first_turn_message)
         return inference_data
 
     def _add_next_turn_user_message_prompting(
         self, inference_data: dict, user_message: list[dict]
     ) -> dict:
+        if isinstance(user_message[0]['content'], str):
+            message_content = [{"type": "text","text": user_message[0]['content']}]
+            user_message[0]['content'] = message_content
         inference_data["message"].extend(user_message)
         return inference_data
diff --git a/berkeley-function-call-leaderboard/pyproject.toml b/berkeley-function-call-leaderboard/pyproject.toml
index 9c1337a27..64ef8ad65 100644
--- a/berkeley-function-call-leaderboard/pyproject.toml
+++ b/berkeley-function-call-leaderboard/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "tree-sitter-javascript==0.21.4",
     "openai==1.46.0",
     "mistralai==1.1.0",
-    "anthropic==0.31.1",
+    "anthropic==0.37.1",
     "cohere==5.5.8",
     "typer>=0.12.5",
     "tabulate>=0.9.0",

From ab215dc76397c9fb277fd165e0e15e7444fec4d7 Mon Sep 17 00:00:00 2001
From: "Huanzhi (Hans) Mao"
Date: Tue, 12 Nov 2024 16:57:18 -0800
Subject: [PATCH 2/2] fix

---
 .../model_handler/proprietary_model/claude.py | 134 ++++++++++--------
 1 file changed, 72 insertions(+), 62 deletions(-)

diff --git a/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py b/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
index 392418afc..f63984076 100644
--- a/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
+++ b/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
@@ -12,11 +12,12 @@
     convert_system_prompt_into_user_prompt,
     convert_to_function_call,
     convert_to_tool,
-    format_execution_results_prompting,
     extract_system_prompt,
+    format_execution_results_prompting,
     func_doc_language_specific_pre_processing,
     system_prompt_pre_processing_chat_model,
 )
+from bfcl.utils import is_multi_turn
 
 
 class ClaudeHandler(BaseHandler):
@@ -74,25 +75,26 @@ def _query_FC(self, inference_data: dict):
             "message": repr(inference_data["message"]),
             "tools": inference_data["tools"],
         }
-        tools = inference_data["tools"]
         messages = inference_data["message"]
-        caching_enabled = inference_data["caching_enabled"]
-
-        if caching_enabled: # Caching will only be enabled for multi_turn categories
-            messages[-1]['content'][0]["cache_control"] = {"type": "ephemeral"}
-            user_indices = [i for i, item in enumerate(messages) if item['role'] == 'user'] # Keeping the cache control blocks only in the last two user messages.
-            for i in user_indices[:-2]:
-                for content in messages[i]['content']:
-                    if 'cache_control' in content:
-                        del content['cache_control']
+
+        if inference_data["caching_enabled"]:
+            # Only add cache control to the last two user messages
+            # Remove previously set cache control flags from all user messages except the last two
+            count = 0
+            for message in reversed(messages):
+                if message["role"] == "user":
+                    if count < 2:
+                        message["content"][0]["cache_control"] = {"type": "ephemeral"}
+                    else:
+                        if "cache_control" in message["content"][0]:
+                            del message["content"][0]["cache_control"]
+                    count += 1
 
         return self.client.beta.prompt_caching.messages.create(
             model=self.model_name.strip("-FC"),
-            max_tokens=(
-                8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
-            ),
-            tools=tools,
-            messages=messages
+            max_tokens=(8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096),
+            tools=inference_data["tools"],
+            messages=messages,
         )
@@ -107,8 +109,11 @@ def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
 
         test_entry_id: str = test_entry["id"]
         test_category: str = test_entry_id.rsplit("_", 1)[0]
-        inference_data["caching_enabled"] = 'multi_turn' in test_category #caching_enabled only for multi_turn category
-
+        # caching enabled only for multi_turn category
+        inference_data["caching_enabled"] = (
+            is_multi_turn(test_category) and "claude-3-sonnet" not in self.model_name
+        )
+
         return inference_data
 
     def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
@@ -118,11 +123,19 @@ def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
         functions = func_doc_language_specific_pre_processing(functions, test_category)
         tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)
 
-        if 'multi_turn' in test_category:
-            tools[-1]['cache_control'] = {'type': 'ephemeral'}
-
+        if inference_data["caching_enabled"]:
+            # First time compiling tools, so adding cache control flag to the last tool
+            if "tools" not in inference_data:
+                tools[-1]["cache_control"] = {"type": "ephemeral"}
+            # This is the situation where the tools are already compiled and we are adding more tools to the existing tools (in miss_func category)
+            # We add the cache control flag to the last tool in the previous existing tools and the last tool in the new tools to maximize cache hit
+            else:
+                existing_tool_len = len(inference_data["tools"])
+                tools[existing_tool_len - 1]["cache_control"] = {"type": "ephemeral"}
+                tools[-1]["cache_control"] = {"type": "ephemeral"}
+
         inference_data["tools"] = tools
-
+
         return inference_data
 
     def _parse_query_response_FC(self, api_response: any) -> dict:
@@ -138,7 +151,7 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
                 tool_call_ids.append(content.id)
 
         model_responses = tool_call_outputs if tool_call_outputs else text_outputs
-
+
         model_responses_message_for_chat_history = api_response.content
 
         return {
@@ -147,25 +160,21 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
             "tool_call_ids": tool_call_ids,
             "input_token": api_response.usage.input_tokens,
             "output_token": api_response.usage.output_tokens,
-            "prompt_write_cache_token_count": getattr(api_response.usage, 'cache_creation_input_tokens', 0),
-            "prompt_read_cache_token_count": getattr(api_response.usage, 'cache_read_input_tokens', 0)
         }
 
     def add_first_turn_message_FC(
         self, inference_data: dict, first_turn_message: list[dict]
     ) -> dict:
-        if isinstance(first_turn_message[0]['content'], str):
-            message_content = [{"type": "text", "text": first_turn_message[0]['content']}]
-            first_turn_message[0]['content'] = message_content
+        for message in first_turn_message:
+            message["content"] = [{"type": "text", "text": message["content"]}]
         inference_data["message"].extend(first_turn_message)
         return inference_data
 
     def _add_next_turn_user_message_FC(
         self, inference_data: dict, user_message: list[dict]
     ) -> dict:
-        if isinstance(user_message[0]['content'], str):
-            message_content = [{"type": "text","text": user_message[0]['content']}]
-            user_message[0]['content'] = message_content
+        for message in user_message:
+            message["content"] = [{"type": "text", "text": message["content"]}]
         inference_data["message"].extend(user_message)
         return inference_data
@@ -213,31 +222,27 @@ def _query_prompting(self, inference_data: dict):
             "message": repr(inference_data["message"]),
             "system_prompt": inference_data["system_prompt"],
         }
-        system_prompt_text = inference_data["system_prompt"]
-        system_prompt = [{"type": "text", "text": system_prompt_text}]
-        messages = inference_data["message"]
-        caching_enabled = inference_data["caching_enabled"]
-
-        if caching_enabled:
-            # Caching system prompt
-            system_prompt[0]['cache_control'] = {"type": "ephemeral"}
-
-            # Caching messages
-            messages[-1]['content'][0]["cache_control"] = {"type": "ephemeral"}
-            user_indices = [i for i, item in enumerate(messages) if item['role'] == 'user'] # Keeping the cache control blocks only in the last two user messages.
-            for i in user_indices[:-2]:
-                for content in messages[i]['content']:
-                    if 'cache_control' in content:
-                        del content['cache_control']
+
+        if inference_data["caching_enabled"]:
+            # Cache the system prompt
+            inference_data["system_prompt"][0]["cache_control"] = {"type": "ephemeral"}
+            # Add cache control to the last two user messages as well
+            count = 0
+            for message in reversed(inference_data["message"]):
+                if message["role"] == "user":
+                    if count < 2:
+                        message["content"][0]["cache_control"] = {"type": "ephemeral"}
+                    else:
+                        if "cache_control" in message["content"][0]:
+                            del message["content"][0]["cache_control"]
+                    count += 1
 
         api_response = self.client.beta.prompt_caching.messages.create(
             model=self.model_name,
-            max_tokens=(
-                8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
-            ),
+            max_tokens=(8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096),
             temperature=self.temperature,
-            system=system_prompt,
-            messages=messages
+            system=inference_data["system_prompt"],
+            messages=inference_data["message"],
         )
         return api_response
@@ -254,42 +259,47 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
         # Claude takes in system prompt in a specific field, not in the message field, so we don't need to add it to the message
         system_prompt = extract_system_prompt(test_entry["question"][0])
 
+        system_prompt = [{"type": "text", "text": system_prompt}]
+
         # Claude doesn't allow consecutive user prompts, so we need to combine them
         for round_idx in range(len(test_entry["question"])):
             test_entry["question"][round_idx] = combine_consecutive_user_prompts(
                 test_entry["question"][round_idx]
             )
-
+
         test_entry_id: str = test_entry["id"]
         test_category: str = test_entry_id.rsplit("_", 1)[0]
-        caching_enabled: bool = 'multi_turn' in test_category # caching enabled only for multi_turn category
+        # caching enabled only for multi_turn category
+        caching_enabled: bool = (
+            is_multi_turn(test_category) and "claude-3-sonnet" not in self.model_name
+        )
-        return {"message": [], "system_prompt": system_prompt, "caching_enabled": caching_enabled}
+        return {
+            "message": [],
+            "system_prompt": system_prompt,
+            "caching_enabled": caching_enabled,
+        }
 
     def _parse_query_response_prompting(self, api_response: any) -> dict:
         return {
             "model_responses": api_response.content[0].text,
             "input_token": api_response.usage.input_tokens,
             "output_token": api_response.usage.output_tokens,
-            "prompt_write_cache_token_count": getattr(api_response.usage, 'cache_creation_input_tokens', 0),
-            "prompt_read_cache_token_count": getattr(api_response.usage, 'cache_read_input_tokens', 0),
         }
 
     def add_first_turn_message_prompting(
         self, inference_data: dict, first_turn_message: list[dict]
     ) -> dict:
-        if isinstance(first_turn_message[0]['content'], str):
-            message_content = [{"type": "text", "text": first_turn_message[0]['content']}]
-            first_turn_message[0]['content'] = message_content
+        for message in first_turn_message:
+            message["content"] = [{"type": "text", "text": message["content"]}]
         inference_data["message"].extend(first_turn_message)
         return inference_data
 
     def _add_next_turn_user_message_prompting(
         self, inference_data: dict, user_message: list[dict]
     ) -> dict:
-        if isinstance(user_message[0]['content'], str):
-            message_content = [{"type": "text","text": user_message[0]['content']}]
-            user_message[0]['content'] = message_content
+        for message in user_message:
+            message["content"] = [{"type": "text", "text": message["content"]}]
         inference_data["message"].extend(user_message)
         return inference_data
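For reference, a minimal standalone sketch of the request shape these patches produce, assuming anthropic==0.37.1 (the version pinned above). The tool definition and message text are illustrative placeholders and are not part of the BFCL handler; only the cache_control placement and the beta prompt-caching endpoint mirror what the patched code does.

# Illustrative sketch, not part of either patch. Assumes anthropic==0.37.1 and
# ANTHROPIC_API_KEY in the environment; the tool and message are placeholders.
import anthropic

client = anthropic.Anthropic()

tools = [
    {
        "name": "get_weather",  # hypothetical tool, for illustration only
        "description": "Get the current weather for a city.",
        "input_schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
        # Flagging the last tool caches the whole tool block as one prefix,
        # matching how the handler sets tools[-1]["cache_control"].
        "cache_control": {"type": "ephemeral"},
    }
]

messages = [
    {
        "role": "user",
        # Content must be a list of blocks so a cache_control flag can be attached,
        # which is why the handler wraps string content into text blocks.
        "content": [
            {
                "type": "text",
                "text": "What is the weather in Berkeley?",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    }
]

response = client.beta.prompt_caching.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=4096,
    tools=tools,
    messages=messages,
)

# Cache statistics reported by the beta endpoint (the fields patch 1 read via getattr)
print(response.usage.cache_creation_input_tokens, response.usage.cache_read_input_tokens)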