Skip to content

Commit

Permalink
[BFCL] Multi Turn Dataset and Possible Answer Fix (ShishirPatil#683)
Browse files Browse the repository at this point in the history
This PR addresses ambiguity in action intention and function parameters.

Total number of entries affected: 

- multi_turn_base: 

This will affect the leaderboard score. We will update it in a separate
PR.

---------

Co-authored-by: Fanjia-Yan <[email protected]>
Co-authored-by: Jason <[email protected]>
Co-authored-by: AndyChenYH <[email protected]>
Co-authored-by: Charlie Cheng-Jie Ji <[email protected]>
Co-authored-by: VishnuSuresh27 <[email protected]>
Co-authored-by: Charlie Cheng-Jie Ji <[email protected]>
Co-authored-by: Fanjia Yan <[email protected]>
Co-authored-by: Shishir Patil <[email protected]>
  • Loading branch information
9 people committed Nov 11, 2024
1 parent 1586799 commit fa8acb3
Show file tree
Hide file tree
Showing 10 changed files with 430 additions and 432 deletions.
1 change: 1 addition & 0 deletions berkeley-function-call-leaderboard/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

All notable changes to the Berkeley Function Calling Leaderboard will be documented in this file.

- [Oct 17, 2024] [#683](https://github.com/ShishirPatil/gorilla/pull/683): Bug fix for the multi turn categories to resolve ambiguity in action intention and function parameters.
- [Oct 17, 2024] [#709](https://github.com/ShishirPatil/gorilla/pull/709): Rephrase question prompt for Java and JavaScript categories to improve clarity and action intent.
- [Oct 17, 2024] [#708](https://github.com/ShishirPatil/gorilla/pull/708): Update the ground truth for the REST category to be up-to-date with the latest API response structure.
- [Oct 16, 2024] [#701](https://github.com/ShishirPatil/gorilla/pull/701): Bug fix in the multi turn function source code for `TravelAPI`.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
import random
from typing import Dict, List, Optional, Union


from copy import deepcopy
# Baseline state for a MessageAPI workspace. It contains mutable members
# (a set and nested dicts), so readers take a deepcopy() of it before use
# rather than referencing this object directly.
DEFAULT_STATE = {
    # presumably the IDs already handed out by generate_id() -- TODO confirm
    "generated_ids": set(),
    "user_count": 4,
    # user name -> user ID
    "user_map": {
        "Alice": "USR001",
        "Bob": "USR002",
        "Catherine": "USR003",
        "Daniel": "USR004",
    },
    # message ID -> message record (sender, receiver, text)
    "inbox": {
        1: {
            "sender_id": "USR001",
            "receiver_id": "USR002",
            "message": "My name is Alice. I want to connect.",
        },
        2: {
            "sender_id": "USR001",
            "receiver_id": "USR003",
            "message": "Could you upload the file?",
        },
        3: {
            "sender_id": "USR001",
            "receiver_id": "USR004",
            "message": "Could you upload the file?",
        },
        4: {
            "sender_id": "USR001",
            "receiver_id": "USR002",
            "message": "I am busy.",
        },
        5: {
            "sender_id": "USR001",
            "receiver_id": "USR002",
            "message": "I am on leave.",
        },
    },
    # NOTE(review): starts at 0 even though the inbox is pre-populated;
    # confirm whether anything still relies on this counter.
    "message_count": 0,
    # No user is logged in initially.
    "current_user": None,
}

class MessageAPI:
"""
A class representing a Message API for managing user interactions in a workspace.
Expand All @@ -21,7 +52,7 @@ class MessageAPI:
list_users(): List all users in the workspace.
get_user_id(user: str): Get the user ID for a given username.
login(user_id: str): Log in a user.
send_message(sender_id: str, receiver_id: str, message: str): Send a message to another user.
send_message(receiver_id: str, message: str): Send a message to another user.
view_messages_received(): View messages received by the current user.
view_messages_sent(): View messages sent by the current user.
delete_message(receiver_id: str, message_index: int): Delete a sent message.
Expand All @@ -34,43 +65,13 @@ def __init__(self):
"""
Initialize the MessageAPI with a workspace ID.
"""
self.generated_ids = set()
self.user_count: int = 4
self.user_map: Dict[str, str] = {
"Alice": "USR001",
"Bob": "USR002",
"Catherine": "USR003",
"Daniel": "USR004",
}
self.inbox: Dict[int, Dict[str, Union[str, int]]] = {
1: {
"sender_id": "USR001",
"receiver_id": "USR002",
"message": "My name is Alice. I want to connect.",
},
2: {
"sender_id": "USR002",
"receiver_id": "USR003",
"message": "Could you upload the file?",
},
3: {
"sender_id": "USR002",
"receiver_id": "USR004",
"message": "Could you upload the file?",
},
4: {
"sender_id": "USR003",
"receiver_id": "USR002",
"message": "I am busy.",
},
5: {
"sender_id": "USR004",
"receiver_id": "USR002",
"message": "I am on leave.",
},
}
self.message_count: int = 0 # useless(?)
self.current_user: Optional[str] = None
DEFAULT_STATE_COPY = deepcopy(DEFAULT_STATE)
self.generated_ids = DEFAULT_STATE_COPY["generated_ids"]
self.user_count: int = DEFAULT_STATE_COPY["user_count"]
self.user_map: Dict[str, str] = DEFAULT_STATE_COPY["user_map"]
self.inbox: Dict[int, Dict[str, Union[str, int]]] = DEFAULT_STATE_COPY["inbox"]
self.message_count: int = DEFAULT_STATE_COPY["message_count"]
self.current_user: Optional[str] = DEFAULT_STATE_COPY["current_user"]

def _load_scenario(self, scenario: dict, long_context=False) -> None:
"""
Expand All @@ -79,9 +80,11 @@ def _load_scenario(self, scenario: dict, long_context=False) -> None:
Args:
scenario (dict): A dictionary containing message data.
"""
DEFAULT_STATE_COPY = deepcopy(DEFAULT_STATE)
self._random = random.Random((scenario.get("random_seed", 200191)))
self.user_count = scenario.get("user_count", 4)
self.current_user = scenario.get("current_user", None)
self.user_count = scenario.get("user_count", DEFAULT_STATE_COPY["user_count"])
self.current_user = scenario.get("current_user", DEFAULT_STATE_COPY["current_user"])
self.user_map = scenario.get("user_map", DEFAULT_STATE_COPY["user_map"])

def __eq__(self, value: object) -> bool:
if not isinstance(value, MessageAPI):
Expand Down Expand Up @@ -153,12 +156,11 @@ def message_login(self, user_id: str) -> Dict[str, Union[str, bool]]:
}

def send_message(
self, sender_id: str, receiver_id: str, message: str
self, receiver_id: str, message: str
) -> Dict[str, Union[str, bool]]:
"""
Send a message to a user.
Args:
sender_id (str): User ID of the user sending the message.
receiver_id (str): User ID of the user to send the message to.
message (str): Message to be sent.
Returns:
Expand All @@ -176,7 +178,7 @@ def send_message(
message_id = self.generate_id()
# Store the message in the inbox
self.inbox[message_id] = {
"sender_id": sender_id,
"sender_id": self.current_user,
"receiver_id": receiver_id,
"message": message,
}
Expand All @@ -188,12 +190,11 @@ def send_message(
}

def delete_message(
self, sender_id: str, receiver_id: str, message_id: int
self, receiver_id: str, message_id: int
) -> Dict[str, Union[bool, str]]:
"""
Delete a message sent to a user.
Args:
sender_id (str): User ID of the user sending the message.
receiver_id (str): User ID of the user to send the message to.
message_id (int): ID of the message to be deleted.
Returns:
Expand All @@ -215,7 +216,7 @@ def delete_message(
return {"error": "You do not have permission to delete this message."}
# Check if the sender and receiver match the input arguments
if (
message_data["sender_id"] != sender_id
message_data["sender_id"] != self.current_user
or message_data["receiver_id"] != receiver_id
):
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def multi_turn_checker(
), f"Model instances and ground truth instances do not match in length for turn {turn_index}. Model instances: {len(model_instances)}, Ground truth instances: {len(ground_truth_instances)}"
assert set(model_instances.keys()) == set(ground_truth_instances.keys())

# Check the status of the instances
# Check the state of the instances
state_check_result = state_checker(model_instances, ground_truth_instances)
if not state_check_result["valid"]:
return state_check_result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,7 @@ def format_execution_results_prompting(


def default_decode_ast_prompting(result, language="Python"):
result = result.strip()
result = result.rstrip("\n")
result = result.lstrip('\n')
result = result.strip("`\n ")
if not result.startswith("["):
result = "[" + result
if not result.endswith("]"):
Expand All @@ -750,9 +748,7 @@ def default_decode_ast_prompting(result, language="Python"):


def default_decode_execute_prompting(result):
result = result.strip()
result = result.rstrip("\n")
result = result.lstrip('\n')
result = result.strip("`\n ")
if not result.startswith("["):
result = "[" + result
if not result.endswith("]"):
Expand Down
Loading

0 comments on commit fa8acb3

Please sign in to comment.