[BFCL] Prompt Caching for Claude Models (#751)
This PR adds prompt caching when running inference on Claude models. It
significantly reduces inference cost on BFCL's multi-turn datasets for the
following models (in both Function Calling and Prompting modes):
- Claude 3.5 Sonnet
- Claude 3 Haiku
- Claude 3 Opus

Summary of changes made (the caching markers are sketched below):

- Cached user messages
- Cached system prompt (for Prompting mode)
- Cached tools (for Function-Calling mode)
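
All three changes use the same mechanism: a `cache_control` block of type `ephemeral` attached to the relevant content. Below is a minimal sketch of that marking pattern (illustrative only, not the exact handler code; `mark_for_caching` is a hypothetical helper), assuming the system prompt and message contents are already lists of Anthropic content blocks:

```python
def mark_for_caching(
    system_prompt: list[dict], tools: list[dict], messages: list[dict]
) -> None:
    """Attach ephemeral cache_control markers so later turns can reuse the cached prefix."""
    # The system prompt must be a list of content blocks to carry cache_control.
    system_prompt[0]["cache_control"] = {"type": "ephemeral"}

    # Marking the last tool caches the entire tool list as one prefix.
    if tools:
        tools[-1]["cache_control"] = {"type": "ephemeral"}

    # Mark only the two most recent user messages and clear older markers,
    # since the API allows only a limited number of cache breakpoints.
    count = 0
    for message in reversed(messages):
        if message["role"] != "user":
            continue
        block = message["content"][0]
        if count < 2:
            block["cache_control"] = {"type": "ephemeral"}
        elif "cache_control" in block:
            del block["cache_control"]
        count += 1
```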

Please note:

- This implementation deliberately skips caching in single-turn cases, since there are no later turns that could benefit from cache reads (the gating check is sketched below).
- According to the [Anthropic guide](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-storage-and-sharing), using prompt caching **will not** affect model accuracy:
> Prompt caching has no effect on output token generation. The response you receive will be identical to what you would get if prompt caching was not used.
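
For context, here is a rough, self-contained sketch of how the multi-turn gate and the beta prompt-caching endpoint fit together. The function name `query_with_optional_caching` is illustrative; `is_multi_turn` is the BFCL helper the handler imports, and the `beta.prompt_caching` namespace is what the pinned `anthropic==0.37.1` SDK exposes:

```python
import anthropic

from bfcl.utils import is_multi_turn

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment


def query_with_optional_caching(
    test_category: str,
    model_name: str,
    system_prompt: list[dict],
    messages: list[dict],
):
    # Cache only for multi-turn categories; single-turn entries have no later
    # turns that could read the cache. Claude 3 Sonnet does not support caching.
    caching_enabled = (
        is_multi_turn(test_category) and "claude-3-sonnet" not in model_name
    )

    if caching_enabled and messages and messages[-1]["role"] == "user":
        # Content must be a list of blocks for cache_control to attach.
        messages[-1]["content"][0]["cache_control"] = {"type": "ephemeral"}
        system_prompt[0]["cache_control"] = {"type": "ephemeral"}

    return client.beta.prompt_caching.messages.create(
        model=model_name,
        max_tokens=4096,
        system=system_prompt,
        messages=messages,
    )
```

Cache savings show up in the response's `usage` fields (`cache_creation_input_tokens` and `cache_read_input_tokens`), which the prompt-caching beta reports alongside the usual input and output token counts.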

---------

Co-authored-by: Huanzhi (Hans) Mao <[email protected]>
VishnuSuresh27 and HuanzhiMao authored Nov 13, 2024
1 parent 19490f1 commit 5a42197
Showing 2 changed files with 79 additions and 16 deletions.
@@ -12,11 +12,12 @@
    convert_system_prompt_into_user_prompt,
    convert_to_function_call,
    convert_to_tool,
    format_execution_results_prompting,
    extract_system_prompt,
    format_execution_results_prompting,
    func_doc_language_specific_pre_processing,
    system_prompt_pre_processing_chat_model,
)
from bfcl.utils import is_multi_turn


class ClaudeHandler(BaseHandler):
@@ -74,14 +75,26 @@ def _query_FC(self, inference_data: dict):
"message": repr(inference_data["message"]),
"tools": inference_data["tools"],
}

return self.client.messages.create(
messages = inference_data["message"]

if inference_data["caching_enabled"]:
# Only add cache control to the last two user messages
# Remove previously set cache control flags from all user messages except the last two
count = 0
for message in reversed(messages):
if message["role"] == "user":
if count < 2:
message["content"][0]["cache_control"] = {"type": "ephemeral"}
else:
if "cache_control" in message["content"][0]:
del message["content"][0]["cache_control"]
count += 1

return self.client.beta.prompt_caching.messages.create(
model=self.model_name.strip("-FC"),
max_tokens=(
8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
), # 3.5 Sonnet has a higher max token limit
max_tokens=(8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096),
tools=inference_data["tools"],
messages=inference_data["message"],
messages=messages,
)

def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
@@ -92,8 +105,15 @@ def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
test_entry["question"][round_idx] = combine_consecutive_user_prompts(
test_entry["question"][round_idx]
)

inference_data["message"] = []

test_entry_id: str = test_entry["id"]
test_category: str = test_entry_id.rsplit("_", 1)[0]
# caching enabled only for multi_turn category
inference_data["caching_enabled"] = (
is_multi_turn(test_category) and "claude-3-sonnet" not in self.model_name
)

return inference_data

def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
@@ -103,8 +123,19 @@ def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
        functions = func_doc_language_specific_pre_processing(functions, test_category)
        tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)

        if inference_data["caching_enabled"]:
            # First time compiling tools, so adding cache control flag to the last tool
            if "tools" not in inference_data:
                tools[-1]["cache_control"] = {"type": "ephemeral"}
            # This is the situation where the tools are already compiled and we are adding more tools to the existing tools (in miss_func category)
            # We add the cache control flag to the last tool in the previous existing tools and the last tool in the new tools to maximize cache hit
            else:
                existing_tool_len = len(inference_data["tools"])
                tools[existing_tool_len - 1]["cache_control"] = {"type": "ephemeral"}
                tools[-1]["cache_control"] = {"type": "ephemeral"}

        inference_data["tools"] = tools

        return inference_data

    def _parse_query_response_FC(self, api_response: any) -> dict:
@@ -134,12 +165,16 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
    def add_first_turn_message_FC(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        for message in first_turn_message:
            message["content"] = [{"type": "text", "text": message["content"]}]
        inference_data["message"].extend(first_turn_message)
        return inference_data

    def _add_next_turn_user_message_FC(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        for message in user_message:
            message["content"] = [{"type": "text", "text": message["content"]}]
        inference_data["message"].extend(user_message)
        return inference_data

@@ -188,16 +223,27 @@ def _query_prompting(self, inference_data: dict):
"system_prompt": inference_data["system_prompt"],
}

api_response = self.client.messages.create(
if inference_data["caching_enabled"]:
# Cache the system prompt
inference_data["system_prompt"][0]["cache_control"] = {"type": "ephemeral"}
# Add cache control to the last two user messages as well
count = 0
for message in reversed(inference_data["message"]):
if message["role"] == "user":
if count < 2:
message["content"][0]["cache_control"] = {"type": "ephemeral"}
else:
if "cache_control" in message["content"][0]:
del message["content"][0]["cache_control"]
count += 1

api_response = self.client.beta.prompt_caching.messages.create(
model=self.model_name,
max_tokens=(
8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
), # 3.5 Sonnet has a higher max token limit
max_tokens=(8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096),
temperature=self.temperature,
system=inference_data["system_prompt"],
messages=inference_data["message"],
)

return api_response

def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
@@ -213,13 +259,26 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
        # Claude takes in system prompt in a specific field, not in the message field, so we don't need to add it to the message
        system_prompt = extract_system_prompt(test_entry["question"][0])

        system_prompt = [{"type": "text", "text": system_prompt}]

        # Claude doesn't allow consecutive user prompts, so we need to combine them
        for round_idx in range(len(test_entry["question"])):
            test_entry["question"][round_idx] = combine_consecutive_user_prompts(
                test_entry["question"][round_idx]
            )

        return {"message": [], "system_prompt": system_prompt}
        test_entry_id: str = test_entry["id"]
        test_category: str = test_entry_id.rsplit("_", 1)[0]
        # caching enabled only for multi_turn category
        caching_enabled: bool = (
            is_multi_turn(test_category) and "claude-3-sonnet" not in self.model_name
        )

        return {
            "message": [],
            "system_prompt": system_prompt,
            "caching_enabled": caching_enabled,
        }

    def _parse_query_response_prompting(self, api_response: any) -> dict:
        return {
@@ -231,12 +290,16 @@ def _parse_query_response_prompting(self, api_response: any) -> dict:
    def add_first_turn_message_prompting(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        for message in first_turn_message:
            message["content"] = [{"type": "text", "text": message["content"]}]
        inference_data["message"].extend(first_turn_message)
        return inference_data

    def _add_next_turn_user_message_prompting(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        for message in user_message:
            message["content"] = [{"type": "text", "text": message["content"]}]
        inference_data["message"].extend(user_message)
        return inference_data

2 changes: 1 addition & 1 deletion berkeley-function-call-leaderboard/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
"tree-sitter-javascript==0.21.4",
"openai==1.46.0",
"mistralai==1.1.0",
"anthropic==0.31.1",
"anthropic==0.37.1",
"cohere==5.5.8",
"typer>=0.12.5",
"tabulate>=0.9.0",
