From 07cb9e160668b47ca181bbd47619f9ce035a5412 Mon Sep 17 00:00:00 2001 From: "Huanzhi (Hans) Mao" Date: Sun, 21 Jul 2024 23:39:45 -0700 Subject: [PATCH] [BFCL] Fix language_specific_pre_processing for Java and JavaScript Test Category (#538) In our BFCL official communication channels, including the evaluation manual blog, GitHub issue replies (such as #424), and our Discord channel, we have previously stated the following: > For Java and JavaScript test category, before querying the model, we do some pre-processing on the prompt and function document. Specifically, at the end of the prompt, we will explicitly state that `the provided function is in Java 8/JavaScript/Python syntax`. And for parameter types that are not native to JSON, we will change their type to `String` (since `String` is JSON compatible) and add in the parameter description that `f" This is Java/JavaScript {value['type']} in string representation."` > As an example, when expecting type `ArrayList`, model will get the instruction that this is a `String` type parameter with the parameter description containing `"This is Java ArrayList in string representation."`, and thus the model should output the value in `String` format (eg, `"new ArrayList<>(Arrays.asList(10, 20, 30))"`), which is JSON compatible. However, the code for `language_specific_pre_processing` did not implement this correctly. Due to an indentation issue, the parameter description was only modified when the parameter type was `any`, and the part where the parameter type is cast to `String` was never implemented. This issue was unnoticed until PR #516 was merged because of the double-casting problem. It *significantly impacts* the evaluation score for the Java and JavaScript categories. We will update the leaderboard very soon. This PR: - Addresses the above issue and ensures that the evaluation logic aligns with the previously described behaviour - Updates two entries in the JavaScript dataset, due to their parameters missing a `description` field - Index: `14, 45` We sincerely apologize for the oversight. --- berkeley-function-call-leaderboard/README.md | 1 + ...illa_openfunctions_v1_test_javascript.json | 4 +- .../model_handler/claude_fc_handler.py | 2 +- .../model_handler/claude_prompt_handler.py | 4 +- .../model_handler/cohere_handler.py | 4 +- .../model_handler/databricks_handler.py | 2 +- .../model_handler/firework_ai_handler.py | 2 +- .../model_handler/gemini_handler.py | 2 +- .../model_handler/glm_handler.py | 2 +- .../model_handler/gorilla_handler.py | 2 +- .../model_handler/gpt_handler.py | 4 +- .../model_handler/granite_handler.py | 2 +- .../model_handler/mistral_handler.py | 4 +- .../model_handler/nexus_handler.py | 2 +- .../model_handler/nvidia_handler.py | 2 +- .../model_handler/oss_handler.py | 2 +- .../model_handler/utils.py | 37 ++++++++++--------- 17 files changed, 40 insertions(+), 38 deletions(-) diff --git a/berkeley-function-call-leaderboard/README.md b/berkeley-function-call-leaderboard/README.md index a53e7c6e8b..47b02710ad 100644 --- a/berkeley-function-call-leaderboard/README.md +++ b/berkeley-function-call-leaderboard/README.md @@ -208,6 +208,7 @@ Some companies have proposed some optimization strategies in their models' handl ## Changelog +* [July 21, 2024] [#538](https://github.com/ShishirPatil/gorilla/pull/538): Fix `language_specific_pre_processing` function to properly handle pre-processing for prompts and function docs in Java and JavaScript test categories. All entries in these categories are affected. * [July 16, 2024] [#525](https://github.com/ShishirPatil/gorilla/pull/525), [#536](https://github.com/ShishirPatil/gorilla/pull/536): Add new model `ibm-granite/granite-20b-functioncalling` to the leaderboard. * [July 10, 2024] [#522](https://github.com/ShishirPatil/gorilla/pull/522): Bug fix in the evaluation dataset for Executable Parallel Multiple category. This includes updates to both prompts and function docs. 2 entries are affected. * [July 8, 2024] [#516](https://github.com/ShishirPatil/gorilla/pull/516): Fix double-casting issue in `model_handler` for Java and JavaScript test categories. diff --git a/berkeley-function-call-leaderboard/data/gorilla_openfunctions_v1_test_javascript.json b/berkeley-function-call-leaderboard/data/gorilla_openfunctions_v1_test_javascript.json index 6264b4c286..90a2edf0cf 100644 --- a/berkeley-function-call-leaderboard/data/gorilla_openfunctions_v1_test_javascript.json +++ b/berkeley-function-call-leaderboard/data/gorilla_openfunctions_v1_test_javascript.json @@ -12,7 +12,7 @@ {"question": "How can I sort a list of items myItemList alphabetically and ascendingly, but place items with a status of 'urgent' at the top, assuming the list is an array of objects with 'name' and 'status' properties?", "function": {"name": "prioritizeAndSort", "description": "This function sorts an array of objects based on their 'name' property, while prioritizing items based on a specified status.", "parameters": {"type": "dict", "properties": {"items": {"type": "array", "items": {"type": "String"}, "description": "The array of objects to be sorted."}, "priorityStatus": {"type": "String", "description": "The status value that should be given priority in the sorting."}, "ascending": {"type": "Boolean", "description": "A flag indicating whether the sorting should be in ascending (true) or descending (false) order, excluding priority items."}}, "required": ["items", "priorityStatus", "ascending"]}}} {"question": "How can I implement a 'dataFetch' operation with an API endpoint URL of 'https://api.example.com/data', expecting the response to be a JSON object containing '{\"key\": \"value\"}', given a request configuration object '{\"method\": \"GET\"}'?", "function": {"name": "performDataFetch", "description": "This function fetches data from a specified API endpoint using the provided request configuration, checks the response against an expected JSON object, and handles any potential errors. It supports various request methods like GET or POST.", "parameters": {"type": "dict", "properties": {"apiEndpoint": {"type": "String", "description": "The URL of the API endpoint from which the data will be fetched."}, "requestConfig": {"type": "dict", "properties": {"method": {"type": "String", "description": "The HTTP method to be used for the request."}, "headers": {"type": "dict", "description": "Any headers to be included in the request."}, "body": {"type": "String", "description": "The request payload, if needed for methods like POST."}}, "description": "The configuration object for the API request."}, "expectedResponse": {"type": "dict", "description": "The JSON object expected to be returned by the API call."}, "handleErrors": {"type": "Boolean", "description": "If true, the function will handle errors gracefully and provide appropriate feedback. Default false"}}, "required": ["apiEndpoint", "requestConfig", "expectedResponse"]}}} {"question": "How can I generate a dynamic chart with user-provided data `userDataArray` and apply a scaling factor of 3 for the axis values, linking it to a given dashboard `dashboardElement`?", "function": {"name": "DynamicChartGenerator", "description": "This function creates a dynamic chart based on user input, applies a scaling factor to the axis values, and integrates the chart into a specified dashboard for display.", "parameters": {"type": "dict", "properties": {"userData": {"type": "array", "items": {"type": "String"}, "description": "The data provided by the user to plot on the chart."}, "scalingFactor": {"type": "float", "description": "A scaling factor applied to the chart's axis values. Optional parameter."}, "dashboard": {"type": "any", "description": "The dashboard where the chart will be displayed."}, "options": {"type": "dict", "description": "Additional configuration options for the chart. Default empty dict"}}, "required": ["userData", "scalingFactor", "dashboard"]}}} -{"question": "How can I generate a data accessor for a chart component named 'BarChart', with a module name 'chartModule', in a data visualization library `visualizationLibrary`, to fetch and update its 'DataPoints' and 'Labels' through a configuration object named 'config'?", "function": {"name": "chartDataAccessorFactory", "description": "This function generates a data accessor for a specific chart component within a data visualization librar `. It provides the capability to fetch and update specific properties such as 'DataPoints' and 'Labels' of the chart through a configuration object.", "parameters": {"type": "dict", "properties": {"chart": {"type": "dict", "properties": {"nm": {"type": "String", "description": "The name of the chart component."}, "mn": {"type": "String", "description": "The module name of the chart component."}}, "required": ["nm", "mn"]}, "library": {"type": "any", "description": "The instance of the data visualization library where the chart component is defined."}, "configObject": {"type": "String", "description": "The name of the configuration object used to fetch and update the chart's properties."}}, "required": ["chart", "library", "configObject"]}}} +{"question": "How can I generate a data accessor for a chart component named 'BarChart', with a module name 'chartModule', in a data visualization library `visualizationLibrary`, to fetch and update its 'DataPoints' and 'Labels' through a configuration object named 'config'?", "function": {"name": "chartDataAccessorFactory", "description": "This function generates a data accessor for a specific chart component within a data visualization librar `. It provides the capability to fetch and update specific properties such as 'DataPoints' and 'Labels' of the chart through a configuration object.", "parameters": {"type": "dict", "properties": {"chart": {"type": "dict", "properties": {"nm": {"type": "String", "description": "The name of the chart component."}, "mn": {"type": "String", "description": "The module name of the chart component."}}, "description": "The details of the chart component.", "required": ["nm", "mn"]}, "library": {"type": "any", "description": "The instance of the data visualization library where the chart component is defined."}, "configObject": {"type": "String", "description": "The name of the configuration object used to fetch and update the chart's properties."}}, "required": ["chart", "library", "configObject"]}}} {"question": "How can I generate a new ChartSeries with initial settings including axis labels `axisLabelsArray`, data points `dataPointsArray`, and a default color scheme `defaultColor`, and then integrate it into a specific chart layout `chartLayoutObject`?", "function": {"name": "ChartSeriesGenerator", "description": "This function creates a new ChartSeries with customizable settings for axis labels, data points, and color schemes, and attaches it to a given chart layout.", "parameters": {"type": "dict", "properties": {"labels": {"type": "array", "items": {"type": "String"}, "description": "The labels for the chart's axis."}, "data": {"type": "array", "items": {"type": "String"}, "description": "The data points for the series."}, "color": {"type": "String", "description": "The default color for the series. Optional parameter."}, "chartLayout": {"type": "dict", "description": "The layout object of the chart where the series will be added."}}, "required": ["labels", "data", "chartLayout"]}}} {"question": "How do I compute the updated coordinates for a set of vertices (10, 15) and (20, 25) after rotating them around a pivot point (12, 17) by 30 degrees?", "function": {"name": "rotateVertices", "description": "This function computes the updated coordinates of a set of vertices after rotating them around a pivot point by a given angle.", "parameters": {"type": "dict", "properties": {"vertices": {"type": "array", "items": {"type": "float"}, "description": "An array of vertices to rotate, where each vertex is in the format [x, y]."}, "pivot": {"type": "array", "items": {"type": "float"}, "description": "The pivot point around which the vertices are to be rotated, in the format [x, y]."}, "angle": {"type": "float", "description": "The rotation angle in degrees."}}, "required": ["vertices", "pivot", "angle"]}}} {"question": "How can I generate a notification handler for an application `app` that filters messages based on priority level 3, linked to a messaging service 'messagingSvc', and categorized under notification type 2?", "function": {"name": "generateNotificationHandler", "description": "This function generates a notification handler for an application, which can filter incoming messages by priority level. It can also be linked to a specific messaging service and categorized under a certain notification type.", "parameters": {"type": "dict", "properties": {"app": {"type": "any", "description": "The application for which to generate the notification handler."}, "priorityLevel": {"type": "integer", "description": "The priority level to filter messages. A certain level (e.g., 3) may determine the filtering criteria."}, "messagingService": {"type": "any", "description": "The messaging service associated with the notification handler."}, "notificationType": {"type": "integer", "description": "The notification type category for the handler."}}, "required": ["app", "priorityLevel", "messagingService", "notificationType"]}}} @@ -43,7 +43,7 @@ {"question": "How can I create a task queue with a concurrency of 5, where tasks are functions that log a message to the console, and ensure that when the queue becomes saturated, it logs 'Queue is saturated', and when it becomes unsaturated, it logs 'Queue is unsaturated'?", "function": {"name": "B", "description": "This complex function initializes a task queue with customizable concurrency, task addition, and event handling capabilities. It allows for synchronous and asynchronous task execution, pausing and resuming the queue, and handling various queue events.", "parameters": {"type": "dict", "properties": {"e": {"type": "any", "description": "The initial task or an array of tasks to be added to the queue. Default null"}, "t": {"type": "float", "description": "The concurrency level of the task queue."}, "n": {"type": "float", "description": "The payload size for each task worker. Optional parameter. Default 0.0"}}, "required": ["t"]}}} {"question": "How can I execute a callback function named 'processResult' that handles an error 'null' and a result value of 'Operation successful'?", "function": {"name": "invokeCallback", "description": "This function invokes a callback with an error and a value. If the callback throws an error, it is caught and re-thrown asynchronously.", "parameters": {"type": "dict", "properties": {"callback": {"type": "any", "description": "The callback function to be invoked."}, "error": {"type": "any", "description": "The error to pass to the callback function. Can be 'null' if there is no error."}, "value": {"type": "any", "description": "The value to pass to the callback function."}}, "required": ["callback", "error", "value"]}}} {"question": "How can I execute a custom callback function named 'processNode' on a specific node named 'currentNode' with a state object 'nodeState' during a tree traversal?", "function": {"name": "skipThrough", "description": "This function allows for a custom operation to be performed on a node during a tree traversal by executing a callback function with the node and a state object as arguments.", "parameters": {"type": "dict", "properties": {"node": {"type": "any", "description": "The current node being processed in the tree traversal."}, "st": {"type": "any", "description": "The state object associated with the current node."}, "c": {"type": "any", "description": "The callback function to be executed on the current node and state object."}}, "required": ["node", "st", "c"]}}} -{"question": "How can I asynchronously retrieve a map of remote Git references and their corresponding commit hashes for a repository URL 'https://github.com/yarnpkg/berry' from a starting directory '/home/user/projects'?", "function": {"name": "Sde", "description": "This asynchronous function retrieves a map of remote Git references and their corresponding commit hashes for a given repository URL, using a specified starting directory.", "parameters": {"type": "dict", "properties": {"t": {"type": "String", "description": "The repository URL."}, "e": {"type": "dict", "properties": {"startingCwd": {"type": "String", "description": "The starting directory from which the Git command is executed."}, "configuration": {"type": "dict", "description": "Additional configuration for the Git command."}}, "required": ["startingCwd"]}}, "required": ["t", "e"]}}} +{"question": "How can I asynchronously retrieve a map of remote Git references and their corresponding commit hashes for a repository URL 'https://github.com/yarnpkg/berry' from a starting directory '/home/user/projects'?", "function": {"name": "Sde", "description": "This asynchronous function retrieves a map of remote Git references and their corresponding commit hashes for a given repository URL, using a specified starting directory.", "parameters": {"type": "dict", "properties": {"t": {"type": "String", "description": "The repository URL."}, "e": {"type": "dict", "properties": {"startingCwd": {"type": "String", "description": "The starting directory from which the Git command is executed."}, "configuration": {"type": "dict", "description": "Additional configuration for the Git command."}}, "description": "The execution context for the Git command.", "required": ["startingCwd"]}}, "required": ["t", "e"]}}} {"question": "How can I update the property 'version' of an object named 'packageInfo' to '1.2.3', ensuring the update only occurs if the new value differs from the existing one or if 'version' is not already a property of the object?", "function": {"name": "vOe", "description": "This function updates a property of an object to a new value, but only if the new value is different from the existing one or if the property does not already exist on the object.", "parameters": {"type": "dict", "properties": {"r": {"type": "any", "description": "The object to update."}, "e": {"type": "String", "description": "The property of the object to update."}, "t": {"type": "any", "description": "The new value to assign to the property."}}, "required": ["r", "e", "t"]}}} {"question": "How can I calculate the difference in days between the dates '2023-04-01' and '2023-04-15' using a specific time unit of 'days'?", "function": {"name": "sTe", "description": "This function calculates the difference between two dates in a specified time unit.", "parameters": {"type": "dict", "properties": {"r": {"type": "String", "description": "The start date for the calculation."}, "e": {"type": "String", "description": "The end date for the calculation."}, "t": {"type": "String", "description": "The unit of time to calculate the difference in. For example, 'days', 'hours', etc."}}, "required": ["r", "e", "t"]}}} {"question": "How can I update the DOM event listeners from an old virtual node oldVirtualNode to a new one newVirtualNode, considering the new virtual node has a click event that needs to be normalized and updated?", "function": {"name": "updateDOMListeners", "description": "This function updates the DOM event listeners from an old virtual node to a new one, ensuring that any changes in event listeners are properly handled and applied to the target element.", "parameters": {"type": "dict", "properties": {"oldVnode": {"type": "any", "description": "The old virtual node, containing data about previous event listeners."}, "vnode": {"type": "any", "description": "The new virtual node, containing data about current event listeners."}}, "required": ["oldVnode", "vnode"]}}} diff --git a/berkeley-function-call-leaderboard/model_handler/claude_fc_handler.py b/berkeley-function-call-leaderboard/model_handler/claude_fc_handler.py index 384e177fdb..c1b974796e 100644 --- a/berkeley-function-call-leaderboard/model_handler/claude_fc_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/claude_fc_handler.py @@ -30,7 +30,7 @@ def inference(self, prompt, functions, test_category): return handler.inference(prompt, functions, test_category) else: prompt = augment_prompt_by_languge(prompt, test_category) - functions = language_specific_pre_processing(functions, test_category, True) + functions = language_specific_pre_processing(functions, test_category) if type(functions) is not list: functions = [functions] claude_tool = convert_to_tool( diff --git a/berkeley-function-call-leaderboard/model_handler/claude_prompt_handler.py b/berkeley-function-call-leaderboard/model_handler/claude_prompt_handler.py index 04ab78ef23..40c88bf335 100644 --- a/berkeley-function-call-leaderboard/model_handler/claude_prompt_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/claude_prompt_handler.py @@ -82,7 +82,7 @@ def _get_claude_function_calling_response(self, prompt, functions, test_category def inference(self, prompt, functions, test_category): prompt = augment_prompt_by_languge(prompt, test_category) if "FC" in self.model_name: - functions = language_specific_pre_processing(functions, test_category, True) + functions = language_specific_pre_processing(functions, test_category) result, metadata = self._get_claude_function_calling_response( prompt, functions, test_category ) @@ -90,7 +90,7 @@ def inference(self, prompt, functions, test_category): else: start = time.time() functions = language_specific_pre_processing( - functions, test_category, False + functions, test_category ) response = self.client.messages.create( model=self.model_name, diff --git a/berkeley-function-call-leaderboard/model_handler/cohere_handler.py b/berkeley-function-call-leaderboard/model_handler/cohere_handler.py index 2bd5dc5dd5..b699575eac 100644 --- a/berkeley-function-call-leaderboard/model_handler/cohere_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/cohere_handler.py @@ -52,7 +52,7 @@ def inference(self, prompt, functions, test_category): if "FC" not in self.model_name: prompt = augment_prompt_by_languge(prompt, test_category) functions = language_specific_pre_processing( - functions, test_category, False + functions, test_category ) message = USER_PROMPT_FOR_CHAT_MODEL.format( user_prompt=prompt, functions=str(functions) @@ -69,7 +69,7 @@ def inference(self, prompt, functions, test_category): result = response.text else: prompt = augment_prompt_by_languge(prompt, test_category) - functions = language_specific_pre_processing(functions, test_category, True) + functions = language_specific_pre_processing(functions, test_category) if type(functions) is not list: functions = [functions] message = prompt diff --git a/berkeley-function-call-leaderboard/model_handler/databricks_handler.py b/berkeley-function-call-leaderboard/model_handler/databricks_handler.py index fa1201c6a4..53474ca911 100644 --- a/berkeley-function-call-leaderboard/model_handler/databricks_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/databricks_handler.py @@ -26,7 +26,7 @@ def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> Non ) def inference(self, prompt, functions, test_category): - functions = language_specific_pre_processing(functions, test_category, False) + functions = language_specific_pre_processing(functions, test_category) if type(functions) is not list: functions = [functions] message = [ diff --git a/berkeley-function-call-leaderboard/model_handler/firework_ai_handler.py b/berkeley-function-call-leaderboard/model_handler/firework_ai_handler.py index 74895ef73a..dbaa1592b4 100644 --- a/berkeley-function-call-leaderboard/model_handler/firework_ai_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/firework_ai_handler.py @@ -34,7 +34,7 @@ def write(self, result, file_to_open): f.write(json.dumps(result) + "\n") def inference(self, prompt, functions, test_category): - functions = language_specific_pre_processing(functions, test_category, True) + functions = language_specific_pre_processing(functions, test_category) if type(functions) is not list: functions = [functions] message = [{"role": "user", "content": prompt}] diff --git a/berkeley-function-call-leaderboard/model_handler/gemini_handler.py b/berkeley-function-call-leaderboard/model_handler/gemini_handler.py index f0ae39586b..8f21ebe02e 100644 --- a/berkeley-function-call-leaderboard/model_handler/gemini_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/gemini_handler.py @@ -95,7 +95,7 @@ def _query_gemini(self, user_query, functions): def inference(self, prompt, functions, test_category): prompt = augment_prompt_by_languge(prompt, test_category) - functions = language_specific_pre_processing(functions, test_category, True) + functions = language_specific_pre_processing(functions, test_category) gemini_tool = convert_to_tool( functions, GORILLA_TO_OPENAPI, self.model_style, test_category, True ) diff --git a/berkeley-function-call-leaderboard/model_handler/glm_handler.py b/berkeley-function-call-leaderboard/model_handler/glm_handler.py index c8f0481409..e8bdb63f69 100644 --- a/berkeley-function-call-leaderboard/model_handler/glm_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/glm_handler.py @@ -88,7 +88,7 @@ def inference(self, test_question, test_category, num_gpus): for line in test_question: prompt = augment_prompt_by_languge(line["question"], test_category) function = language_specific_pre_processing( - line["function"], test_category, False + line["function"], test_category ) chat_template_ques_jsons.append( self.apply_chat_template(prompt, function, test_category) diff --git a/berkeley-function-call-leaderboard/model_handler/gorilla_handler.py b/berkeley-function-call-leaderboard/model_handler/gorilla_handler.py index 70fe0e54a0..91a640c0a5 100644 --- a/berkeley-function-call-leaderboard/model_handler/gorilla_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/gorilla_handler.py @@ -43,7 +43,7 @@ def _get_gorilla_response(self, prompt, functions): def inference(self, prompt, functions, test_category): prompt = augment_prompt_by_languge(prompt, test_category) - functions = language_specific_pre_processing(functions, test_category, False) + functions = language_specific_pre_processing(functions, test_category) if type(functions) is not list: functions = [functions] try: diff --git a/berkeley-function-call-leaderboard/model_handler/gpt_handler.py b/berkeley-function-call-leaderboard/model_handler/gpt_handler.py index e4dd96cd80..3c1618a5be 100644 --- a/berkeley-function-call-leaderboard/model_handler/gpt_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/gpt_handler.py @@ -26,7 +26,7 @@ def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> Non def inference(self, prompt,functions,test_category): if "FC" not in self.model_name: prompt = augment_prompt_by_languge(prompt,test_category) - functions = language_specific_pre_processing(functions,test_category,False) + functions = language_specific_pre_processing(functions,test_category) message = [ { "role": "system", @@ -51,7 +51,7 @@ def inference(self, prompt,functions,test_category): result = response.choices[0].message.content else: prompt = augment_prompt_by_languge(prompt, test_category) - functions = language_specific_pre_processing(functions, test_category, True) + functions = language_specific_pre_processing(functions, test_category) if type(functions) is not list: functions = [functions] message = [{"role": "user", "content": prompt}] diff --git a/berkeley-function-call-leaderboard/model_handler/granite_handler.py b/berkeley-function-call-leaderboard/model_handler/granite_handler.py index f0a624d350..959f9ec563 100644 --- a/berkeley-function-call-leaderboard/model_handler/granite_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/granite_handler.py @@ -32,7 +32,7 @@ def _format_prompt(prompt, function, test_category): if language_specific_prompt_augmented_str.strip(): prompt = prompt.replace(language_specific_prompt_augmented_str, "") - functions = language_specific_pre_processing(function, test_category, False) + functions = language_specific_pre_processing(function, test_category) functions = convert_to_tool( functions, GORILLA_TO_OPENAPI, diff --git a/berkeley-function-call-leaderboard/model_handler/mistral_handler.py b/berkeley-function-call-leaderboard/model_handler/mistral_handler.py index 5283d63900..7fc263559b 100644 --- a/berkeley-function-call-leaderboard/model_handler/mistral_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/mistral_handler.py @@ -27,7 +27,7 @@ def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> Non def inference(self, prompt, functions, test_category): prompt = augment_prompt_by_languge(prompt, test_category) if "FC" in self.model_name: - functions = language_specific_pre_processing(functions, test_category, True) + functions = language_specific_pre_processing(functions, test_category) tool = convert_to_tool( functions, GORILLA_TO_OPENAPI, self.model_style, test_category, True ) @@ -57,7 +57,7 @@ def inference(self, prompt, functions, test_category): result = chat_response.choices[0].message.content else: functions = language_specific_pre_processing( - functions, test_category, False + functions, test_category ) message = [ ChatMessage(role="system", content=SYSTEM_PROMPT_FOR_CHAT_MODEL), diff --git a/berkeley-function-call-leaderboard/model_handler/nexus_handler.py b/berkeley-function-call-leaderboard/model_handler/nexus_handler.py index 5dfa8ecdbe..18b538c233 100644 --- a/berkeley-function-call-leaderboard/model_handler/nexus_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/nexus_handler.py @@ -137,7 +137,7 @@ def query(payload): def inference(self, prompt, functions, test_category): prompt = augment_prompt_by_languge(prompt, test_category) - functions = language_specific_pre_processing(functions, test_category, False) + functions = language_specific_pre_processing(functions, test_category) raven_prompt = self._format_raven_function(prompt, functions) result, metadata = self._query_raven(raven_prompt) return result, metadata diff --git a/berkeley-function-call-leaderboard/model_handler/nvidia_handler.py b/berkeley-function-call-leaderboard/model_handler/nvidia_handler.py index dc49b794b8..99b271b451 100644 --- a/berkeley-function-call-leaderboard/model_handler/nvidia_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/nvidia_handler.py @@ -25,7 +25,7 @@ def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> Non ) def inference(self, prompt, functions, test_category): prompt = augment_prompt_by_languge(prompt,test_category) - functions = language_specific_pre_processing(functions,test_category,False) + functions = language_specific_pre_processing(functions,test_category) message = [ { "role": "system", diff --git a/berkeley-function-call-leaderboard/model_handler/oss_handler.py b/berkeley-function-call-leaderboard/model_handler/oss_handler.py index defe98af9b..f49328c53c 100644 --- a/berkeley-function-call-leaderboard/model_handler/oss_handler.py +++ b/berkeley-function-call-leaderboard/model_handler/oss_handler.py @@ -59,7 +59,7 @@ def _batch_generate( ques_json = line prompt = augment_prompt_by_languge(ques_json["question"], test_category) functions = language_specific_pre_processing( - ques_json["function"], test_category, False + ques_json["function"], test_category ) prompts.append(format_prompt_func(prompt, functions, test_category)) ans_id = shortuuid.uuid() diff --git a/berkeley-function-call-leaderboard/model_handler/utils.py b/berkeley-function-call-leaderboard/model_handler/utils.py index 4844f9fccb..0cc6ff1d45 100644 --- a/berkeley-function-call-leaderboard/model_handler/utils.py +++ b/berkeley-function-call-leaderboard/model_handler/utils.py @@ -335,13 +335,13 @@ def augment_prompt_by_languge(prompt, test_category): if test_category == "java": prompt = prompt + "\n Note that the provided function is in Java 8 SDK syntax." elif test_category == "javascript": - prompt = prompt + "\n Note that the provided function is in JavaScript." + prompt = prompt + "\n Note that the provided function is in JavaScript syntax." else: - prompt = prompt + "\n Note that the provided function is in Python." + prompt = prompt + "\n Note that the provided function is in Python 3 syntax." return prompt -def language_specific_pre_processing(function, test_category, string_param): +def language_specific_pre_processing(function, test_category): if type(function) is dict: function = [function] if len(function) == 0: @@ -350,27 +350,28 @@ def language_specific_pre_processing(function, test_category, string_param): properties = item["parameters"]["properties"] if test_category == "java": for key, value in properties.items(): - if value["type"] == "Any" or value["type"] == "any": - properties[key][ - "description" - ] += "This parameter can be of any type of Java object." + if value["type"] == "any": properties[key]["description"] += ( - "This is Java" + value["type"] + " in string representation." + " This parameter can be of any type of Java object in string representation." ) + else: + value["description"] += ( + f" This is Java {value['type']} in string representation." + ) + value["type"] = "string" + elif test_category == "javascript": for key, value in properties.items(): - if value["type"] == "Any" or value["type"] == "any": - properties[key][ - "description" - ] += "This parameter can be of any type of Javascript object." - else: - if "description" not in properties[key]: - properties[key]["description"] = "" + if value["type"] == "any": properties[key]["description"] += ( - "This is Javascript " - + value["type"] - + " in string representation." + " This parameter can be of any type of JavaScript object." + ) + else: + value["description"] += ( + f" This is JavaScript {value['type']} in string representation." ) + value["type"] = "string" + return function