Skip to content

Commit

Permalink
Use a single workspace for flashinfer (#1077)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Aug 15, 2024
1 parent 6767e22 commit 326df4b
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 18 deletions.
2 changes: 1 addition & 1 deletion benchmark/gsm8k/bench_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def main(args):
@sgl.function
def few_shot_gsm8k(s, question):
s += few_shot_examples + question
s += sgl.gen("answer", max_tokens=512, stop="Question")
s += sgl.gen("answer", max_tokens=512, stop=["Question", "Assistant:"])

#####################################
########## SGL Program End ##########
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self):
# Runtime constants: others
self.num_continue_decode_steps = 10
self.retract_decode_steps = 20
self.flashinfer_workspace_size = 192 * 1024 * 1024
self.flashinfer_workspace_size = 384 * 1024 * 1024

# Output tokenization configs
self.skip_special_tokens_in_output = True
Expand Down
12 changes: 6 additions & 6 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
)
if model_runner.sliding_window_size is None:
self.flashinfer_workspace_buffer = (
self.model_runner.flashinfer_workspace_buffers[0]
self.model_runner.flashinfer_workspace_buffer
)
else:
self.flashinfer_workspace_buffers = [
self.model_runner.flashinfer_workspace_buffers[0],
self.model_runner.flashinfer_workspace_buffers[2],
]
self.flashinfer_workspace_buffer = (
self.model_runner.flashinfer_workspace_buffer
)

self.flashinfer_kv_indptr = [
self.flashinfer_kv_indptr,
self.flashinfer_kv_indptr.clone(),
Expand Down Expand Up @@ -200,7 +200,7 @@ def capture_one_batch_size(self, bs, forward):
for i in range(2):
flashinfer_decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[i],
self.flashinfer_workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=use_tensor_cores,
Expand Down
16 changes: 7 additions & 9 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,28 +318,26 @@ def init_flashinfer(self):
use_tensor_cores = False

if self.sliding_window_size is None:
self.flashinfer_workspace_buffers = torch.empty(
2,
self.flashinfer_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
)
self.flashinfer_prefill_wrapper_ragged = (
BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_workspace_buffers[0], "NHD"
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[1], "NHD"
self.flashinfer_workspace_buffer, "NHD"
)
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[0],
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
)
else:
self.flashinfer_workspace_buffers = torch.empty(
4,
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
Expand All @@ -350,17 +348,17 @@ def init_flashinfer(self):
for i in range(2):
self.flashinfer_prefill_wrapper_ragged.append(
BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_prefill_wrapper_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[2 * i + 0],
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
)
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if not server_args.disable_flashinfer:
assert_pkg_version(
"flashinfer",
"0.1.4",
"0.1.5",
"Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
Expand Down

0 comments on commit 326df4b

Please sign in to comment.