[BFCL] Speed Up Locally-hosted Model Inference Process (ShishirPatil#671)
Fix ShishirPatil#649 

Instead of sending requests to the vLLM server one by one in sequence, we should send all requests at once so that vLLM can utilize its batching and optimization benefits.

Tested on 8 x A100 (40G) with Llama 3.1 70B. The inference speed on single-turn entries is roughly the same (within a 1-minute difference) as when using `llm.generate` before the BFCL V3 release in ShishirPatil#644. The multi-turn entries still take around 2 hours to complete, but that is largely due to the nature of the multi-turn dataset; this is still much faster than before, when it took around 2 days to finish.

This PR **will not** affect the leaderboard score.
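
For illustration, here is a minimal sketch of the pattern this PR adopts (not the handler code itself): all requests are submitted to a thread pool up front so the vLLM OpenAI-compatible server can batch them internally, and results are collected in submission order. The endpoint URL, model name, and request parameters below are placeholder assumptions.

```python
from concurrent.futures import ThreadPoolExecutor

import requests

# Placeholder values; adjust to match how the vLLM server was launched.
VLLM_URL = "http://localhost:8000/v1/completions"
MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct"


def query_vllm(prompt: str) -> str:
    """Send one completion request; return the error text instead of raising,
    so a single failure does not abort the whole batch."""
    payload = {"model": MODEL, "prompt": prompt, "max_tokens": 512, "temperature": 0}
    try:
        response = requests.post(VLLM_URL, json=payload, timeout=3600)
        response.raise_for_status()
        return response.json()["choices"][0]["text"]
    except Exception as e:
        return f"Error during inference: {e}"


def run_batch(prompts: list[str]) -> list[str]:
    # Submit all requests at once so vLLM can batch them internally,
    # then resolve the futures in submission order to keep the output ordered.
    with ThreadPoolExecutor(max_workers=100) as executor:
        futures = [executor.submit(query_vllm, p) for p in prompts]
        return [future.result() for future in futures]
```

Resolving `future.result()` in submission order trades a little latency for deterministic output order, which is the same design choice the handler change below makes when writing results.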
HuanzhiMao authored and VishnuSuresh27 committed Nov 11, 2024
1 parent 2abbfd6 commit d9c0835
Showing 2 changed files with 50 additions and 20 deletions.
1 change: 1 addition & 0 deletions berkeley-function-call-leaderboard/CHANGELOG.md
@@ -2,6 +2,7 @@

All notable changes to the Berkeley Function Calling Leaderboard will be documented in this file.

+- [Oct 4, 2024] [#671](https://github.com/ShishirPatil/gorilla/pull/671): Speed up locally-hosted model's inference process by parallelizing the inference requests.
- [Sept 27, 2024] [#640](https://github.com/ShishirPatil/gorilla/pull/640): Add the following new models to the leaderboard:
- `microsoft/Phi-3.5-mini-instruct`
- `microsoft/Phi-3-medium-128k-instruct`
@@ -1,5 +1,6 @@
import subprocess
import time
+from concurrent.futures import ThreadPoolExecutor

import requests
from bfcl.model_handler.base_handler import BaseHandler
@@ -46,12 +47,16 @@ def decode_execute(self, result):
        return default_decode_execute_prompting(result)

    def batch_inference(
-        self, test_entries: list[dict], num_gpus: int, gpu_memory_utilization: float, include_debugging_log: bool
+        self,
+        test_entries: list[dict],
+        num_gpus: int,
+        gpu_memory_utilization: float,
+        include_debugging_log: bool,
    ):
        """
        Batch inference for OSS models.
        """

        process = subprocess.Popen(
            [
                "vllm",
@@ -102,24 +107,22 @@ def batch_inference
            process.stderr = subprocess.DEVNULL

            # Once the server is ready, make the completion requests
-            for test_entry in tqdm(test_entries, desc="Generating results"):
-                try:
-                    if "multi_turn" in test_entry["id"]:
-                        model_responses, metadata = self.inference_multi_turn_prompting(test_entry, include_debugging_log)
-                    else:
-                        model_responses, metadata = self.inference_single_turn_prompting(test_entry, include_debugging_log)
-                except Exception as e:
-                    print(f"Error during inference for test entry {test_entry['id']}: {str(e)}")
-                    model_responses = f"Error during inference: {str(e)}"
-                    metadata = {}
-
-                result_to_write = {
-                    "id": test_entry["id"],
-                    "result": model_responses,
-                }
-                result_to_write.update(metadata)
-
-                self.write(result_to_write)
+            futures = []
+            with ThreadPoolExecutor(max_workers=100) as executor:
+                with tqdm(
+                    total=len(test_entries),
+                    desc=f"Generating results for {self.model_name}",
+                ) as pbar:
+
+                    for test_case in test_entries:
+                        future = executor.submit(self._multi_threaded_inference, test_case, include_debugging_log)
+                        futures.append(future)
+
+                    for future in futures:
+                        # This will wait for the task to complete, so that we are always writing in order
+                        result = future.result()
+                        self.write(result)
+                        pbar.update()

        except Exception as e:
            raise e
@@ -136,6 +139,32 @@ def batch_inference
process.wait() # Wait again to ensure it's fully terminated
print("Process killed.")

+    def _multi_threaded_inference(self, test_case, include_debugging_log):
+        """
+        This is a wrapper function to make sure that, if an error occurs during inference, the process does not stop.
+        """
+        assert type(test_case["function"]) is list
+
+        try:
+            if "multi_turn" in test_case["id"]:
+                model_responses, metadata = self.inference_multi_turn_prompting(test_case, include_debugging_log)
+            else:
+                model_responses, metadata = self.inference_single_turn_prompting(test_case, include_debugging_log)
+        except Exception as e:
+            print("-" * 100)
+            print(
+                "❗️❗️ Error occurred during inference. Maximum retries reached for rate limit or other error. Continuing to next test case."
+            )
+            print(f"❗️❗️ Test case ID: {test_case['id']}, Error: {str(e)}")
+            print("-" * 100)
+
+            model_responses = f"Error during inference: {str(e)}"
+
+        return {
+            "id": test_case["id"],
+            "result": model_responses,
+        }
+
    #### Prompting methods ####

    def _query_prompting(self, inference_data: dict):
