
Commit

[BFCL] Fix Hanging Inference for OSS Models on GPU Platforms (ShishirPatil#663)

This PR addresses issues encountered when running locally-hosted models
on GPU-renting platforms (e.g., Lambda Cloud). Specifically, `vllm`
output was not displayed properly because these models are launched in a
subprocess. Additionally, some multi-turn functions (such as `xargs`)
rely on subprocesses themselves, which caused inference on certain test
entries (such as `multi_turn_36`) to hang indefinitely and halt the
pipeline.

To fix this, the terminal logging logic has been updated to use
dedicated threads that read from the subprocess pipes and print to the
terminal.
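
The pattern behind the fix is one reader thread per pipe that keeps echoing the child's output until it is told to stop, so the pipes are continuously drained. Below is a minimal, standalone sketch of that pattern; the `ping` command and the `drain_pipe` helper are placeholders for illustration, not code from the BFCL repository.

```python
import subprocess
import threading


def drain_pipe(pipe, stop_event):
    # Echo lines until the pipe reaches EOF or we are told to stop.
    for line in iter(pipe.readline, ""):
        if stop_event.is_set():
            break
        print(line, end="")
    pipe.close()


stop_event = threading.Event()
process = subprocess.Popen(
    ["ping", "-c", "3", "127.0.0.1"],  # placeholder long-running command (Linux/macOS)
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
)

# One reader thread per pipe, so neither stream goes unread.
readers = [
    threading.Thread(target=drain_pipe, args=(pipe, stop_event), daemon=True)
    for pipe in (process.stdout, process.stderr)
]
for t in readers:
    t.start()

process.wait()    # ... do the real work here, then shut down ...
stop_event.set()  # tell the readers to stop echoing
for t in readers:
    t.join()
```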

Also, for readability, the `_format_prompt` function has been moved to
the `Prompting methods` section; this does not change the leaderboard
score.
HuanzhiMao authored and VishnuSuresh27 committed Nov 11, 2024
1 parent d9c0835 commit e110fbc
Showing 1 changed file with 38 additions and 10 deletions.
@@ -1,4 +1,5 @@
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor

@@ -35,11 +36,6 @@ def inference(self, test_entry: dict, include_debugging_log: bool):
"OSS Models should call the batch_inference method instead."
)

def _format_prompt(self, messages, function):
raise NotImplementedError(
"OSS Models should implement their own prompt formatting."
)

def decode_ast(self, result, language="Python"):
return default_decode_ast_prompting(result, language)

@@ -77,6 +73,30 @@ def batch_inference(
            text=True,  # To get the output as text instead of bytes
        )

        stop_event = (
            threading.Event()
        )  # Event to signal threads to stop; no need to see vllm logs after server is ready

        def log_subprocess_output(pipe, stop_event):
            # Read lines until stop event is set
            for line in iter(pipe.readline, ""):
                if stop_event.is_set():
                    break
                else:
                    print(line, end="")
            pipe.close()
            print("vllm server log tracking thread stopped successfully.")

        # Start threads to read and print stdout and stderr
        stdout_thread = threading.Thread(
            target=log_subprocess_output, args=(process.stdout, stop_event)
        )
        stderr_thread = threading.Thread(
            target=log_subprocess_output, args=(process.stderr, stop_event)
        )
        stdout_thread.start()
        stderr_thread.start()

        try:
            # Wait for the server to be ready
            server_ready = False
@@ -100,11 +120,8 @@
                    # If the connection is not ready, wait and try again
                    time.sleep(1)

            # After the server is ready, stop capturing the output, otherwise the terminal looks messy
            process.stdout.close()
            process.stderr.close()
            process.stdout = subprocess.DEVNULL
            process.stderr = subprocess.DEVNULL
            # Signal threads to stop reading output
            stop_event.set()

            # Once the server is ready, make the completion requests
            futures = []
@@ -124,6 +141,7 @@
                        self.write(result)
                        pbar.update()


        except Exception as e:
            raise e

@@ -139,6 +157,11 @@ def batch_inference(
                process.wait()  # Wait again to ensure it's fully terminated
                print("Process killed.")

            # Wait for the output threads to finish
            stop_event.set()
            stdout_thread.join()
            stderr_thread.join()

    def _multi_threaded_inference(self, test_case, include_debugging_log):
        """
        This is a wrapper function to make sure that, if an error occurs during inference, the process does not stop.
@@ -167,6 +190,11 @@ def _multi_threaded_inference(self, test_case, include_debugging_log):

    #### Prompting methods ####

    def _format_prompt(self, messages, function):
        raise NotImplementedError(
            "OSS Models should implement their own prompt formatting."
        )

    def _query_prompting(self, inference_data: dict):
        # We use the OpenAI Completions API with vLLM
        function: list[dict] = inference_data["function"]
