Skip to content

Commit

Permalink
[BFCL] Multi Turn Dataset Fix (Base Category) (#723)
Browse files Browse the repository at this point in the history
This PR fixes ambiguous prompts and several incorrect ground truth
entries in the multi_turn_base category. After this PR, the
multi_turn_base entries should be bug-free.

Following #719 and #722, this is also part of the effort to thoroughly
fix bugs in the multi-turn categories. More PRs will follow in the
next few days.

---------

Co-authored-by: Charlie Cheng-Jie Ji
<[email protected]>
Co-authored-by: Fanjia-Yan
<[email protected]>
Co-authored-by: VishnuSuresh27
<[email protected]>
  • Loading branch information
HuanzhiMao authored Oct 30, 2024
1 parent 4c16dbb commit a79d891
Show file tree
Hide file tree
Showing 18 changed files with 509 additions and 364 deletions.
2 changes: 1 addition & 1 deletion berkeley-function-call-leaderboard/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

All notable changes to the Berkeley Function Calling Leaderboard will be documented in this file.

- [Oct 24, 2024] [#719](https://github.com/ShishirPatil/gorilla/pull/719), [#722](https://github.com/ShishirPatil/gorilla/pull/722): Bug fix in the dataset and ground truth for the multi-turn categories.
- [Oct 24, 2024] [#719](https://github.com/ShishirPatil/gorilla/pull/719), [#722](https://github.com/ShishirPatil/gorilla/pull/722), [#723](https://github.com/ShishirPatil/gorilla/pull/723): Bug fix in the dataset and ground truth for the multi-turn categories.
- [Oct 17, 2024] [#683](https://github.com/ShishirPatil/gorilla/pull/683): Bug fix for the multi turn categories for ambiguity in action intention and function parameters.
- [Oct 17, 2024] [#709](https://github.com/ShishirPatil/gorilla/pull/709): Rephrase question prompt for Java and JavaScript categories to improve clarity and action intent.
- [Oct 17, 2024] [#708](https://github.com/ShishirPatil/gorilla/pull/708): Update the ground truth for the REST category to be up-to-date with the latest API response structure.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def touch(self, file_name: str) -> Union[None, Dict[str, str]]:
return {"error": f"touch: cannot touch '{file_name}': File exists"}

self._current_dir._add_file(file_name)
return None
return {}

def echo(
self, content: str, file_name: Optional[str] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __eq__(self, value: object) -> bool:

return True

def generate_id(self):
def _generate_id(self):
"""
Generate a unique ID for a message.
Expand All @@ -113,7 +113,7 @@ def generate_id(self):
while new_id in self.generated_ids:
new_id = self._random.randint(10000, 99999)
self.generated_ids.add(new_id)
return {"new_id": new_id}
return new_id

def list_users(self) -> Dict[str, List[str]]:
"""
Expand Down Expand Up @@ -184,7 +184,7 @@ def send_message(self, receiver_id: str, message: str) -> Dict[str, Union[str, b
if receiver_id not in self.user_map.values():
return {"error": f"Receiver ID '{receiver_id}' not found."}
# Generate a unique message ID
message_id = self.generate_id()
message_id = self._generate_id()
# Store the message in the inbox
self.inbox.append({receiver_id: message})
self.message_count += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _load_scenario(self, scenario: dict, long_context=False) -> None:
"authenticated", DEFAULT_STATE_COPY["authenticated"]
)
self.tweets = scenario.get("tweets", DEFAULT_STATE_COPY["tweets"])
self.tweets = {int(k): v for k, v in self.tweets.items()} # Convert tweet keys from string to int from loaded scenario
self.comments = scenario.get("comments", DEFAULT_STATE_COPY["comments"])
self.retweets = scenario.get("retweets", DEFAULT_STATE_COPY["retweets"])
self.following_list = scenario.get(
Expand Down Expand Up @@ -79,8 +80,8 @@ def post_tweet(
Args:
content (str): Content of the tweet.
tags (List[str]): [Optional] List of tags for the tweet. Tag name should start with #.
mentions (List[str]): [Optional] List of users mentioned in the tweet. Mention name should start with @.
tags (List[str]): [Optional] List of tags for the tweet. Tag name should start with #. This is only relevant if the user wants to add tags to the tweet.
mentions (List[str]): [Optional] List of users mentioned in the tweet. Mention name should start with @. This is only relevant if the user wants to add mentions to the tweet.
Returns:
id (int): ID of the posted tweet.
username (str): Username of the poster.
Expand Down Expand Up @@ -113,7 +114,7 @@ def retweet(self, tweet_id: int) -> Dict[str, str]:
"""
if not self.authenticated:
return {"error": "User not authenticated. Please authenticate before retweeting."}

if tweet_id not in self.tweets:
return {"error": f"Tweet with ID {tweet_id} not found."}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random
from copy import deepcopy
from datetime import datetime, time
from datetime import datetime, time, timedelta
from typing import Dict, List, Optional, Union

from .long_context import (
Expand All @@ -12,18 +13,24 @@
WATCH_LIST_EXTENSION,
)

CURRENT_TIME = datetime(2024, 9, 1, 10, 30)

DEFAULT_STATE = {
"orders": {
12345: {
"id": 12345,
"order_type": "Buy",
"symbol": "AAPL",
"price": 210.65,
"num_shares": 10,
"amount": 10,
"status": "Completed",
},
12446: {
"id": 12446,
"order_type": "Sell",
"symbol": "GOOG",
"price": 2840.56,
"num_shares": 5,
"amount": 5,
"status": "Pending",
},
},
Expand Down Expand Up @@ -116,6 +123,7 @@
},
"watch_list": ["NVDA"],
"transaction_history": [],
"random_seed": 1053520,
}


Expand All @@ -124,14 +132,14 @@ class TradingBot:
A class representing a trading bot for executing stock trades and managing a trading account.
Attributes:
orders (Dict[int, Dict[str, Union[str, float, int]]]): A dictionary of orders, keyed by order ID.
orders (Dict[int, Dict[str, Union[str, float, int]]]): A dictionary of orders for purchasing and selling of stock, keyed by order ID.
account_info (Dict[str, Union[int, float]]): Information about the trading account.
authenticated (bool): Whether the user is currently authenticated.
market_status (str): The current status of the market ('Open' or 'Closed').
order_counter (int): A counter for generating unique order IDs.
stocks (Dict[str, Dict[str, Union[float, int]]]): Information about various stocks.
watch_list (List[str]): A list of stock symbols being watched.
transaction_history (List[Dict[str, Union[str, float, int]]]): A history of transactions.
transaction_history (List[Dict[str, Union[str, float, int]]]): A history of trading account related transactions.
"""

def __init__(self):
Expand Down Expand Up @@ -177,6 +185,31 @@ def _load_scenario(self, scenario: dict, long_context=False) -> None:
"transaction_history", DEFAULT_STATE_COPY["transaction_history"]
)
self.long_context = long_context
self._random = random.Random(
(scenario.get("random_seed", DEFAULT_STATE_COPY["random_seed"]))
)

def _generate_transaction_timestamp(self) -> str:
    """
    Generate a pseudo-random timestamp for a transaction.

    The timestamp lies within the 24 hours that start at CURRENT_TIME
    (both ends inclusive, whole-second resolution) and is drawn from
    self._random, so the sequence of timestamps follows that generator's seed.

    Returns:
        timestamp (str): Timestamp formatted as 'YYYY-MM-DD HH:MM:SS'.
    """
    # Size of the window, in whole seconds, in which the transaction may fall.
    window_seconds = int(timedelta(days=1).total_seconds())

    # Draw an offset from CURRENT_TIME instead of converting naive datetimes
    # to epoch seconds and back: that round trip is interpreted in the host's
    # local timezone, so its result can vary with the machine's DST rules.
    offset_seconds = self._random.randint(0, window_seconds)

    transaction_time = CURRENT_TIME + timedelta(seconds=offset_seconds)
    return transaction_time.strftime("%Y-%m-%d %H:%M:%S")

def get_current_time(self) -> Dict[str, str]:
"""
Expand All @@ -185,8 +218,7 @@ def get_current_time(self) -> Dict[str, str]:
Returns:
current_time (str): Current time in HH:MM AM/PM format.
"""
current_time = datetime(2024, 9, 1, 10, 30)
return {"current_time": current_time.strftime("%I:%M %p")}
return {"current_time": CURRENT_TIME.strftime("%I:%M %p")}

def update_market_status(self, current_time_str: str) -> Dict[str, str]:
"""
Expand All @@ -196,7 +228,7 @@ def update_market_status(self, current_time_str: str) -> Dict[str, str]:
current_time_str (str): Current time in HH:MM AM/PM format.
Returns:
status (str): Status of the market ('Open' or 'Closed').
status (str): Status of the market. [Enum]: ["Open", "Closed"]
"""
market_open_time = time(9, 30) # Market opens at 9:30 AM
market_close_time = time(16, 0) # Market closes at 4:00 PM
Expand Down Expand Up @@ -248,8 +280,8 @@ def get_stock_info(self, symbol: str) -> Dict[str, Union[float, int, str]]:
price (float): Current price of the stock.
percent_change (float): Percentage change in stock price.
volume (float): Trading volume of the stock.
MA5 (float): 5-day Moving Average of the stock.
MA20 (float): 20-day Moving Average of the stock.
MA(5) (float): 5-day Moving Average of the stock.
MA(20) (float): 20-day Moving Average of the stock.
"""
if symbol not in self.stocks:
return {"error": f"Stock with symbol '{symbol}' not found."}
Expand All @@ -271,7 +303,7 @@ def get_order_details(self, order_id: int) -> Dict[str, Union[str, float, int]]:
symbol (str): Symbol of the stock in the order.
price (float): Price at which the order was placed.
num_shares (int): Number of shares in the order.
status (str): Current status of the order.
status (str): Current status of the order. [Enum]: ["Open", "Pending", "Completed", "Cancelled"]
"""
if order_id not in self.orders:
return {
Expand Down Expand Up @@ -343,13 +375,17 @@ def place_order(

order_id = self.order_counter
self.orders[order_id] = {
"id": order_id,
"order_type": order_type,
"symbol": symbol,
"price": price,
"amount": amount,
"status": "Open",
}
self.order_counter += 1
# The returned status is "Pending" to indicate that the order has been placed but not yet executed.
# When the order is polled later, its status will show as "Open".
# This simulates the delay between placing an order and its execution.
return {
"order_id": order_id,
"order_type": order_type,
Expand Down Expand Up @@ -388,7 +424,7 @@ def make_transaction(
{
"type": "deposit",
"amount": amount,
"timestamp": self.get_current_time(),
"timestamp": self._generate_transaction_timestamp(),
}
)
return {
Expand All @@ -403,7 +439,7 @@ def make_transaction(
{
"type": "withdrawal",
"amount": amount,
"timestamp": self.get_current_time(),
"timestamp": self._generate_transaction_timestamp(),
}
)
return {
Expand Down Expand Up @@ -483,7 +519,7 @@ def fund_account(self, amount: float) -> Dict[str, Union[str, float]]:
return {"error": "Funding amount must be positive."}
self.account_info["balance"] += amount
self.transaction_history.append(
{"type": "funding", "amount": amount, "timestamp": self.get_current_time()}
{"type": "deposit", "amount": amount, "timestamp": self._generate_transaction_timestamp()}
)
return {
"status": "Account funded successfully",
Expand Down Expand Up @@ -524,6 +560,28 @@ def get_watchlist(self) -> Dict[str, List[str]]:
watch_list.extend(WATCH_LIST_EXTENSION)
return watch_list
return {"watchlist": self.watch_list}

def get_order_history(self) -> Dict[str, List[Dict[str, Union[str, int, float]]]]:
    """
    Get the stock order history.

    Returns:
        history (List[Dict]): List of orders in the order history.
            - id (int): Order ID.
            - order_type (str): Type of the order. [Enum]: ["Buy", "Sell"]
            - symbol (str): Symbol of the stock in the order.
            - price (float): Price at which the order was placed.
            - amount (int): Number of shares in the order.
            - status (str): Current status of the order. [Enum]: ["Open", "Pending", "Completed", "Cancelled"]
        error (str): Present instead of `history` when the user is not authenticated.
    """
    if not self.authenticated:
        # Return the error as a plain dict (not wrapped in a list) so the
        # result always matches the annotated Dict return type and callers
        # can test for the "error" key directly.
        return {
            "error": "User not authenticated. Please log in to view order history."
        }

    # Snapshot the order records so callers get a list independent of later
    # insertions into self.orders.
    return {"history": list(self.orders.values())}

def get_transaction_history(
self, start_date: Optional[str] = None, end_date: Optional[str] = None
Expand All @@ -536,7 +594,10 @@ def get_transaction_history(
end_date (str): [Optional] End date for the history (format: 'YYYY-MM-DD').
Returns:
history (List[Dict[str, str]]): List of transactions within the specified date range.
transaction_history (List[Dict]): List of transactions within the specified date range.
- type (str): Type of transaction. [Enum]: ["deposit", "withdrawal"]
- amount (float): Amount involved in the transaction.
- timestamp (str): Timestamp of the transaction, formatted as 'YYYY-MM-DD HH:MM:SS'.
"""
if not self.authenticated:
return [
Expand All @@ -558,13 +619,15 @@ def get_transaction_history(
filtered_history = [
transaction
for transaction in self.transaction_history
if start <= datetime.strptime(transaction["timestamp"], "%I:%M %p") <= end
if start
<= datetime.strptime(transaction["timestamp"], "%Y-%m-%d %H:%M:%S")
<= end
]

if self.long_context:
filtered_history.extend(TRANSACTION_HISTORY_EXTENSION)

return {"history": filtered_history}
return {"transaction_history": filtered_history}

def update_stock_price(
self, symbol: str, new_price: float
Expand Down Expand Up @@ -670,6 +733,8 @@ def notify_price_change(self, stocks: List[str], threshold: float) -> Dict[str,
]

if changed_stocks:
return {"notification": f"Stocks {', '.join(changed_stocks)} have significant price changes."}
return {
"notification": f"Stocks {', '.join(changed_stocks)} have significant price changes."
}
else:
return {"notification": "No significant price changes in the selected stocks."}
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def get_flight_cost(
("ATV", "GFD"): 240,
("PHV", "GFD"): 220,
("LHR", "CDG"): 100,
("OKD", "LAX"): 220
}

# Ensure the travel_from and travel_to is a tuple in the correct order (from, to)
Expand Down Expand Up @@ -639,8 +640,8 @@ def compute_exchange_rate(
Compute the exchange rate between two currencies
Args:
base_currency (str): The base currency
target_currency (str): The target currency
base_currency (str): The base currency. [Enum]: USD, RMB, EUR, JPY, GBP, CAD, AUD, INR, RUB, BRL, MXN
target_currency (str): The target currency. [Enum]: USD, RMB, EUR, JPY, GBP, CAD, AUD, INR, RUB, BRL, MXN
value (float): The value to convert
Returns:
exchanged_value (float): The value after the exchange
Expand All @@ -656,7 +657,7 @@ def compute_exchange_rate(
("USD", "INR"): 70,
("USD", "RUB"): 60,
("USD", "BRL"): 3.8,
("USD", "MXN"): 20,
("USD", "MXN"): 20
}
for key, val in exchange_rates.items():
if base_currency == key[0] and target_currency == key[1]:
Expand Down Expand Up @@ -744,7 +745,7 @@ def get_nearest_airport_by_city(self, location: str) -> Dict[str, str]:
Get the nearest airport to the given location
Args:
location (str): The name of the location.
location (str): The name of the location. [Enum]: Rivermist, Stonebrook, Maplecrest, Silverpine, Shadowridge, London, Paris, Sunset Valley, Oakendale, Willowbend, Crescent Hollow, Autumnville, Pinehaven, Greenfield, San Francisco, Los Angeles, New York, Chicago, Boston, Beijing, Hong Kong, Rome, Tokyo
Returns:
nearest_airport (str): The nearest airport to the given location
"""
Expand Down
Loading

0 comments on commit a79d891

Please sign in to comment.