[BFCL] Add Prompt Caching for Claude Models #727

Closed
@@ -68,6 +68,8 @@ def inference_multi_turn_FC(

total_input_token_count: list[list[float]] = []
total_output_token_count: list[list[float]] = []
total_prompt_write_cache_token_count: list[list[float]] = []
total_prompt_read_cache_token_count: list[list[float]] = []
total_latency: list[list[float]] = []
all_model_response: list[list] = (
[]
@@ -113,6 +115,8 @@ def inference_multi_turn_FC(
current_round_input_token_count: list[float] = []
current_round_output_token_count: list[float] = []
current_round_latency: list[float] = []
current_round_prompt_write_cache_token_count: list[float] = []
current_round_prompt_read_cache_token_count: list[float] = []

count = 0
while True:
@@ -144,6 +148,8 @@ def inference_multi_turn_FC(
# Process the metadata
current_round_input_token_count.append(model_response_data["input_token"])
current_round_output_token_count.append(model_response_data["output_token"])
current_round_prompt_write_cache_token_count.append(model_response_data["prompt_write_cache_token_count"])
current_round_prompt_read_cache_token_count.append(model_response_data["prompt_read_cache_token_count"])
current_round_latency.append(query_latency)

# Try decoding the model response
@@ -221,6 +227,8 @@ def inference_multi_turn_FC(
all_debugging_log.append(current_round_debugging_log)
total_input_token_count.append(current_round_input_token_count)
total_output_token_count.append(current_round_output_token_count)
total_prompt_write_cache_token_count.append(current_round_prompt_write_cache_token_count)
total_prompt_read_cache_token_count.append(current_round_prompt_read_cache_token_count)
total_latency.append(current_round_latency)

if force_quit:
@@ -231,6 +239,8 @@
metadata["debugging_log"] = all_debugging_log
metadata["input_token_count"] = total_input_token_count
metadata["output_token_count"] = total_output_token_count
metadata["prompt_write_cache_token_count"] = total_prompt_write_cache_token_count
metadata["prompt_read_cache_token_count"] = total_prompt_read_cache_token_count
metadata["latency"] = total_latency

return all_model_response, metadata
@@ -252,6 +262,8 @@ def inference_multi_turn_prompting(

total_input_token_count: list[list[float]] = []
total_output_token_count: list[list[float]] = []
total_prompt_write_cache_token_count: list[list[float]] = []
total_prompt_read_cache_token_count: list[list[float]] = []
total_latency: list[list[float]] = []
all_model_response: list[list] = (
[]
@@ -294,6 +306,8 @@ def inference_multi_turn_prompting(
current_round_input_token_count: list[float] = []
current_round_output_token_count: list[float] = []
current_round_latency: list[float] = []
current_round_prompt_write_cache_token_count: list[float] = []
current_round_prompt_read_cache_token_count: list[float] = []

count = 0
while True:
@@ -324,6 +338,8 @@ def inference_multi_turn_prompting(
# Process the metadata
current_round_input_token_count.append(model_response_data["input_token"])
current_round_output_token_count.append(model_response_data["output_token"])
current_round_prompt_write_cache_token_count.append(model_response_data.get("prompt_write_cache_token_count", 0))
current_round_prompt_read_cache_token_count.append(model_response_data.get("prompt_read_cache_token_count", 0))
current_round_latency.append(query_latency)

# Try decoding the model response
@@ -401,6 +417,8 @@ def inference_multi_turn_prompting(
all_debugging_log.append(current_round_debugging_log)
total_input_token_count.append(current_round_input_token_count)
total_output_token_count.append(current_round_output_token_count)
total_prompt_write_cache_token_count.append(current_round_prompt_write_cache_token_count)
total_prompt_read_cache_token_count.append(current_round_prompt_read_cache_token_count)
total_latency.append(current_round_latency)

if force_quit:
@@ -411,6 +429,8 @@ def inference_multi_turn_prompting(
metadata["debugging_log"] = all_debugging_log
metadata["input_token_count"] = total_input_token_count
metadata["output_token_count"] = total_output_token_count
metadata["prompt_write_cache_token_count"] = total_prompt_write_cache_token_count
metadata["prompt_read_cache_token_count"] = total_prompt_read_cache_token_count
metadata["latency"] = total_latency

return all_model_response, metadata
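The nested lists above record per-step token counts for every round. As a point of reference, here is a minimal sketch, not part of this PR, of how those metadata fields could be rolled up after a run; the helper name and the hit-ratio definition are illustrative, and it assumes the metadata layout produced by the two inference functions above.

```python
def summarize_cache_usage(metadata: dict) -> dict:
    """Aggregate the nested per-round token counts collected by inference_multi_turn_*."""

    def total(rounds: list[list[float]]) -> float:
        return sum(sum(step_counts) for step_counts in rounds)

    cache_write = total(metadata["prompt_write_cache_token_count"])
    cache_read = total(metadata["prompt_read_cache_token_count"])
    # Anthropic reports cached tokens separately from usage.input_tokens,
    # so the three buckets together approximate the full prompt size.
    uncached_input = total(metadata["input_token_count"])

    prompt_total = cache_write + cache_read + uncached_input
    return {
        "cache_write_tokens": cache_write,
        "cache_read_tokens": cache_read,
        "uncached_input_tokens": uncached_input,
        # Fraction of prompt tokens that were served from the cache.
        "cache_hit_ratio": cache_read / prompt_total if prompt_total else 0.0,
    }
```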
@@ -23,7 +23,6 @@
from bfcl.model_handler.proprietary_model.nvidia import NvidiaHandler
from bfcl.model_handler.proprietary_model.openai import OpenAIHandler
from bfcl.model_handler.proprietary_model.yi import YiHandler

# TODO: Add Deepseek V2, meta-llama/Llama-3.1-405B-Instruct

# Inference through API calls
@@ -1,7 +1,8 @@
import json
import os

from anthropic import Anthropic
from transformers import GPT2TokenizerFast

from anthropic.types import TextBlock, ToolUseBlock
from bfcl.model_handler.base_handler import BaseHandler
from bfcl.model_handler.constant import GORILLA_TO_OPENAPI
@@ -18,13 +19,19 @@
system_prompt_pre_processing_chat_model,
)

# Silence the HuggingFace tokenizers warning that appears when the process forks after
# the tokenizer has been used ("The current process just got forked, after parallelism
# has already been used. Disabling parallelism to avoid deadlocks...").
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class ClaudeHandler(BaseHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)
self.model_style = ModelStyle.Anthropic
self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

def decode_ast(self, result, language="Python"):
if "FC" not in self.model_name:
func = result
@@ -75,15 +82,47 @@ def _query_FC(self, inference_data: dict):
"tools": inference_data["tools"],
}

return self.client.messages.create(
tools = inference_data["tools"]

if tools:
# Use the GPT-2 tokenizer to approximate the token count of each tool definition
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

cumulative_token_count = 0
cacheable_tools = []
cache_control_applied = False

for idx, tool in enumerate(tools):
tool_text = json.dumps(tool)
token_count = len(tokenizer.encode(tool_text))
cumulative_token_count += token_count
cacheable_tools.append(tool)

# Apply cache_control once the cumulative token count reaches the 1024-token caching minimum
if not cache_control_applied and cumulative_token_count >= 1024:
cacheable_tools[-1]['cache_control'] = {'type': 'ephemeral'}
cache_control_applied = True
print(f"Caching all tools with cumulative tokens {cumulative_token_count}")

if not cache_control_applied:
# Prompts below the 1024-token minimum are not eligible for caching, so no cache_control is added
print("Cumulative token count did not reach 1024. No caching applied.")
else:
cacheable_tools = tools

response = self.client.beta.prompt_caching.messages.create(
model=self.model_name.strip("-FC"),
max_tokens=(
8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
), # 3.5 Sonnet has a higher max token limit
tools=inference_data["tools"],
messages=inference_data["message"],
),
tools=cacheable_tools,
messages=inference_data["message"]
)

return response

def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
for round_idx in range(len(test_entry["question"])):
test_entry["question"][round_idx] = convert_system_prompt_into_user_prompt(
@@ -129,6 +168,8 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
"tool_call_ids": tool_call_ids,
"input_token": api_response.usage.input_tokens,
"output_token": api_response.usage.output_tokens,
"prompt_write_cache_token_count": getattr(api_response.usage, 'cache_creation_input_tokens', 0),
"prompt_read_cache_token_count": getattr(api_response.usage, 'cache_read_input_tokens', 0),
}

def add_first_turn_message_FC(
@@ -188,17 +229,41 @@ def _query_prompting(self, inference_data: dict):
"system_prompt": inference_data["system_prompt"],
}

api_response = self.client.messages.create(
# Use the GPT-2 tokenizer to approximate the system prompt's token count
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

system_prompt_text = inference_data["system_prompt"]
token_count = len(tokenizer.encode(system_prompt_text))

system_prompt = [
{
"type": "text",
"text": system_prompt_text
}
]

# Decide whether to apply caching based on token count
if token_count >= 1024:
# Add 'cache_control' to the system prompt
system_prompt[0]['cache_control'] = {"type": "ephemeral"}
print(f"Caching system prompt with token count: {token_count}")
else:
print(f"System prompt token count ({token_count}) is less than 1024. No caching applied.")


# Use the beta prompt caching endpoint
api_response = self.client.beta.prompt_caching.messages.create(
model=self.model_name,
max_tokens=(
8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
), # 3.5 Sonnet has a higher max token limit
),
temperature=self.temperature,
system=inference_data["system_prompt"],
messages=inference_data["message"],
system=system_prompt,
messages=inference_data["message"]
)

return api_response


def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
functions: list = test_entry["function"]
@@ -226,6 +291,8 @@ def _parse_query_response_prompting(self, api_response: any) -> dict:
"model_responses": api_response.content[0].text,
"input_token": api_response.usage.input_tokens,
"output_token": api_response.usage.output_tokens,
"prompt_write_cache_token_count": getattr(api_response.usage, 'cache_creation_input_tokens', 0),
"prompt_read_cache_token_count": getattr(api_response.usage, 'cache_read_input_tokens', 0),
}

def add_first_turn_message_prompting(
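The two cache usage fields parsed above are easiest to see with back-to-back requests: the first call writes the cache, the second reads it. A hedged sketch, again not part of this PR, assuming `anthropic==0.37.1` and `ANTHROPIC_API_KEY`; the oversized system prompt and model choice are illustrative only.

```python
# Illustrative only (not part of this PR): observe cache writes vs. cache reads
# across two consecutive calls that share the same cached system prompt.
import os

from anthropic import Anthropic

client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

# Padded well past the 1024-token minimum so the prefix is eligible for caching.
long_system_prompt = "You are a careful function-calling assistant. " * 400

def cached_call(user_text: str):
    return client.beta.prompt_caching.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=256,
        system=[
            {
                "type": "text",
                "text": long_system_prompt,
                "cache_control": {"type": "ephemeral"},
            }
        ],
        messages=[{"role": "user", "content": user_text}],
    )

first = cached_call("What is 2 + 2?")
second = cached_call("And 3 + 3?")  # made within the cache TTL, so it should hit the cache

for label, resp in (("first", first), ("second", second)):
    print(
        label,
        "cache_write:", getattr(resp.usage, "cache_creation_input_tokens", 0),
        "cache_read:", getattr(resp.usage, "cache_read_input_tokens", 0),
    )
```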
@@ -262,3 +329,4 @@ def _add_execution_results_prompting(
)

return inference_data

berkeley-function-call-leaderboard/pyproject.toml (2 additions, 1 deletion)
@@ -24,7 +24,8 @@ dependencies = [
"tree-sitter-javascript==0.21.4",
"openai==1.46.0",
"mistralai==1.1.0",
"anthropic==0.31.1",
"anthropic==0.37.1",
"transformers==4.46.0",
"cohere==5.5.8",
"typer>=0.12.5",
"tabulate>=0.9.0",