[BFCL] Prompt Caching for Claude Models #751

Merged (5 commits) on Nov 13, 2024

Changes from 1 commit:
Implement prompt caching for claude models
VishnuSuresh27 committed Nov 11, 2024
commit b6b5de38f939c3b9e65e0f1b1cbf69887c529a07
@@ -74,14 +74,25 @@ def _query_FC(self, inference_data: dict):
"message": repr(inference_data["message"]),
"tools": inference_data["tools"],
}

return self.client.messages.create(
tools = inference_data["tools"]
messages = inference_data["message"]
caching_enabled = inference_data["caching_enabled"]

if caching_enabled: # Caching will only be enabled for multi_turn categories
messages[-1]['content'][0]["cache_control"] = {"type": "ephemeral"}
user_indices = [i for i, item in enumerate(messages) if item['role'] == 'user'] # Keeping the cache control blocks only in the last two user messages.
for i in user_indices[:-2]:
for content in messages[i]['content']:
if 'cache_control' in content:
del content['cache_control']

return self.client.beta.prompt_caching.messages.create(
model=self.model_name.strip("-FC"),
max_tokens=(
8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
), # 3.5 Sonnet has a higher max token limit
tools=inference_data["tools"],
messages=inference_data["message"],
),
tools=tools,
messages=messages
)

def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
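The cache bookkeeping introduced above reappears almost verbatim in _query_prompting further down. Read in isolation it amounts to a small helper; a minimal sketch under that reading (the standalone function name prune_cache_control is illustrative, not part of this commit):

def prune_cache_control(messages: list[dict]) -> None:
    """Mark the newest message as an ephemeral cache breakpoint and strip
    cache_control markers from all but the last two user messages, staying
    within Anthropic's per-request limit on cache breakpoints."""
    messages[-1]["content"][0]["cache_control"] = {"type": "ephemeral"}
    user_indices = [i for i, m in enumerate(messages) if m["role"] == "user"]
    for i in user_indices[:-2]:
        for block in messages[i]["content"]:
            block.pop("cache_control", None)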
@@ -92,8 +103,12 @@ def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> di
             test_entry["question"][round_idx] = combine_consecutive_user_prompts(
                 test_entry["question"][round_idx]
             )

         inference_data["message"] = []

+        test_entry_id: str = test_entry["id"]
+        test_category: str = test_entry_id.rsplit("_", 1)[0]
+        inference_data["caching_enabled"] = 'multi_turn' in test_category  # caching_enabled only for multi_turn category
+
         return inference_data

     def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
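The caching flag is derived purely from the test entry id, with no extra configuration; for example (the id below is hypothetical):

test_entry_id = "multi_turn_base_42"             # hypothetical BFCL test id
test_category = test_entry_id.rsplit("_", 1)[0]  # "multi_turn_base"
caching_enabled = "multi_turn" in test_category  # True, so caching is switched on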
@@ -103,6 +118,9 @@ def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
         functions = func_doc_language_specific_pre_processing(functions, test_category)
         tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)

+        if 'multi_turn' in test_category:
+            tools[-1]['cache_control'] = {'type': 'ephemeral'}
+
         inference_data["tools"] = tools

         return inference_data
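Attaching cache_control to the last tool marks the end of the tool list as a cache breakpoint, so the full set of (often lengthy) tool definitions is written to the cache once and reused on later turns. A sketch of the resulting structure, with hypothetical tool names and empty schemas:

tools = [
    {"name": "get_weather", "input_schema": {"type": "object", "properties": {}}},
    {
        "name": "send_email",
        "input_schema": {"type": "object", "properties": {}},
        "cache_control": {"type": "ephemeral"},  # the breakpoint covers every tool above it
    },
]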
@@ -120,7 +138,7 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
                 tool_call_ids.append(content.id)

         model_responses = tool_call_outputs if tool_call_outputs else text_outputs

         model_responses_message_for_chat_history = api_response.content

         return {
@@ -129,17 +147,25 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
"tool_call_ids": tool_call_ids,
"input_token": api_response.usage.input_tokens,
"output_token": api_response.usage.output_tokens,
"prompt_write_cache_token_count": getattr(api_response.usage, 'cache_creation_input_tokens', 0),
"prompt_read_cache_token_count": getattr(api_response.usage, 'cache_read_input_tokens', 0)
}

def add_first_turn_message_FC(
self, inference_data: dict, first_turn_message: list[dict]
) -> dict:
if isinstance(first_turn_message[0]['content'], str):
message_content = [{"type": "text", "text": first_turn_message[0]['content']}]
first_turn_message[0]['content'] = message_content
inference_data["message"].extend(first_turn_message)
return inference_data

def _add_next_turn_user_message_FC(
self, inference_data: dict, user_message: list[dict]
) -> dict:
if isinstance(user_message[0]['content'], str):
message_content = [{"type": "text","text": user_message[0]['content']}]
user_message[0]['content'] = message_content
inference_data["message"].extend(user_message)
return inference_data

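The isinstance checks above exist because cache_control can only sit on a content block, not on a plain string, so messages whose content arrives as a bare string are first rewritten into Anthropic's list-of-blocks form. A minimal sketch of that normalization (the helper name is illustrative):

def ensure_block_content(message: dict) -> dict:
    """Rewrite a plain-string content field into a list of content blocks so a
    cache_control entry can later be attached to the first block."""
    if isinstance(message["content"], str):
        message["content"] = [{"type": "text", "text": message["content"]}]
    return message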
@@ -187,17 +213,32 @@ def _query_prompting(self, inference_data: dict):
"message": repr(inference_data["message"]),
"system_prompt": inference_data["system_prompt"],
}

api_response = self.client.messages.create(
system_prompt_text = inference_data["system_prompt"]
system_prompt = [{"type": "text", "text": system_prompt_text}]
messages = inference_data["message"]
caching_enabled = inference_data["caching_enabled"]

if caching_enabled:
# Caching system prompt
system_prompt[0]['cache_control'] = {"type": "ephemeral"}

# Caching messages
messages[-1]['content'][0]["cache_control"] = {"type": "ephemeral"}
user_indices = [i for i, item in enumerate(messages) if item['role'] == 'user'] # Keeping the cache control blocks only in the last two user messages.
for i in user_indices[:-2]:
for content in messages[i]['content']:
if 'cache_control' in content:
del content['cache_control']

api_response = self.client.beta.prompt_caching.messages.create(
model=self.model_name,
max_tokens=(
8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
), # 3.5 Sonnet has a higher max token limit
),
temperature=self.temperature,
system=inference_data["system_prompt"],
messages=inference_data["message"],
system=system_prompt,
messages=messages
)

return api_response

def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
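In prompting mode the system prompt itself becomes a cacheable block, so system is passed as a list of content blocks rather than a string. A sketch of the request this produces, assuming a valid ANTHROPIC_API_KEY and that the pinned anthropic SDK exposes the beta prompt-caching namespace used in the diff; the model name and prompt text are placeholders:

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
response = client.beta.prompt_caching.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=4096,
    temperature=0.0,
    system=[{
        "type": "text",
        "text": "Long, reusable system prompt...",  # placeholder prompt text
        "cache_control": {"type": "ephemeral"},     # cache the system prompt
    }],
    messages=[{
        "role": "user",
        "content": [{
            "type": "text",
            "text": "What is 2 + 2?",                # placeholder user turn
            "cache_control": {"type": "ephemeral"},  # breakpoint after the latest turn
        }],
    }],
)
print(response.usage)  # cache_creation_input_tokens / cache_read_input_tokens show cache activity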
@@ -218,25 +259,37 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
             test_entry["question"][round_idx] = combine_consecutive_user_prompts(
                 test_entry["question"][round_idx]
             )

-        return {"message": [], "system_prompt": system_prompt}
+        test_entry_id: str = test_entry["id"]
+        test_category: str = test_entry_id.rsplit("_", 1)[0]
+        caching_enabled: bool = 'multi_turn' in test_category  # caching enabled only for multi_turn category
+
+        return {"message": [], "system_prompt": system_prompt, "caching_enabled": caching_enabled}

     def _parse_query_response_prompting(self, api_response: any) -> dict:
         return {
             "model_responses": api_response.content[0].text,
             "input_token": api_response.usage.input_tokens,
             "output_token": api_response.usage.output_tokens,
+            "prompt_write_cache_token_count": getattr(api_response.usage, 'cache_creation_input_tokens', 0),
+            "prompt_read_cache_token_count": getattr(api_response.usage, 'cache_read_input_tokens', 0),
         }

     def add_first_turn_message_prompting(
         self, inference_data: dict, first_turn_message: list[dict]
     ) -> dict:
+        if isinstance(first_turn_message[0]['content'], str):
+            message_content = [{"type": "text", "text": first_turn_message[0]['content']}]
+            first_turn_message[0]['content'] = message_content
         inference_data["message"].extend(first_turn_message)
         return inference_data

     def _add_next_turn_user_message_prompting(
         self, inference_data: dict, user_message: list[dict]
     ) -> dict:
+        if isinstance(user_message[0]['content'], str):
+            message_content = [{"type": "text", "text": user_message[0]['content']}]
+            user_message[0]['content'] = message_content
         inference_data["message"].extend(user_message)
         return inference_data

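Both response parsers report cache activity through two extra counters, read defensively with getattr so runs where the SDK or model returns no cache statistics simply record zero. That extraction, pulled out on its own (the function name is illustrative):

def cache_token_counts(usage) -> dict:
    """Pull cache write/read token counts from an Anthropic usage object,
    defaulting to 0 when the fields are absent."""
    return {
        "prompt_write_cache_token_count": getattr(usage, "cache_creation_input_tokens", 0),
        "prompt_read_cache_token_count": getattr(usage, "cache_read_input_tokens", 0),
    }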
berkeley-function-call-leaderboard/pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ dependencies = [
"tree-sitter-javascript==0.21.4",
"openai==1.46.0",
"mistralai==1.1.0",
"anthropic==0.31.1",
"anthropic==0.37.1",
"cohere==5.5.8",
"typer>=0.12.5",
"tabulate>=0.9.0",