Skip to content

Commit

Permalink
handle parallel function calls from gemini (ShishirPatil#406)
Browse files Browse the repository at this point in the history
Handle parallel function calls for Gemini handler for the Berkeley
Function Calling Leaderboard.

This PR does NOT change values in BFCL.

Co-authored-by: Xiaowei Li <[email protected]>
  • Loading branch information
vandyxiaowei and Xiaowei Li authored May 9, 2024
1 parent 0eb02bb commit 42f9d28
Showing 1 changed file with 13 additions and 19 deletions.
32 changes: 13 additions & 19 deletions berkeley-function-call-leaderboard/model_handler/gemini_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ def _query_gemini(self, user_query, functions):
}

# NOTE: To run the gemini model, you need to provide your own GCP project ID, which can be found in the GCP console.
if self.model_name == "gemini-1.5-pro-preview-0409":
API_URL = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/{YOUR_GCP_PROJECT_ID_HERE}/locations/us-central1/publishers/google/models/gemini-1.5-pro-preview-0409:generateContent"
else:
API_URL = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/{YOUR_GCP_PROJECT_ID_HERE}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent"
API_URL = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/{YOUR_GCP_PROJECT_ID_HERE}/locations/us-central1/publishers/google/models/" + self.model_name + ":generateContent"
headers = {
"Authorization": "Bearer " + token,
"Content-Type": "application/json",
Expand All @@ -65,22 +62,19 @@ def _query_gemini(self, user_query, functions):
"output_tokens": 0,
"latency": latency,
}
contents = result["candidates"][0]["content"]["parts"][0]
if "functionCall" in contents:
if (
"name" in contents["functionCall"]
and "args" in contents["functionCall"]
):
result = {
contents["functionCall"]["name"]: json.dumps(
contents["functionCall"]["args"]
)
}

parts = []
for part in result["candidates"][0]["content"]["parts"]:
if "functionCall" in part:
if (
"name" in part["functionCall"]
and "args" in part["functionCall"]
):
parts.append({part["functionCall"]["name"]: json.dumps(part["functionCall"]["args"])})
else:
parts.append("Parsing error: " + json.dumps(part["functionCall"]))
else:
result = "Parsing error: " + json.dumps(contents["functionCall"])
else:
result = contents["text"]
parts.append(part["text"])
result = parts
metatdata = {}
metatdata["input_tokens"] = json.loads(response.content)["usageMetadata"][
"promptTokenCount"
Expand Down

0 comments on commit 42f9d28

Please sign in to comment.