Chunked prefill support #797

Merged · 13 commits · Jul 29, 2024
Changes from 9 commits

20 changes: 10 additions & 10 deletions python/sglang/srt/managers/controller/schedule_heuristic.py
@@ -38,24 +38,24 @@ def __init__(
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, forward_queue):
+    def get_priority_queue(self, waiting_queue):
         if self.schedule_heuristic == "lpm":
             # longest prefix match
-            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return forward_queue
+            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            return waiting_queue
         elif self.schedule_heuristic == "fcfs":
             # first come first serve
-            return forward_queue
+            return waiting_queue
         elif self.schedule_heuristic == "lof":
             # longest output first
-            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return forward_queue
+            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return waiting_queue
         elif self.schedule_heuristic == "random":
-            random.shuffle(forward_queue)
-            return forward_queue
+            random.shuffle(waiting_queue)
+            return waiting_queue
         elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
-            for req in forward_queue:
+            for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
             node_to_weight = defaultdict(int)
@@ -67,7 +67,7 @@ def get_priority_queue(self, forward_queue):
             self.get_dfs_priority(
                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(q) == len(forward_queue)
+            assert len(q) == len(waiting_queue)
             return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")

115 changes: 92 additions & 23 deletions python/sglang/srt/managers/controller/tp_worker.py
@@ -77,6 +77,10 @@ def __init__(
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -157,7 +161,7 @@ def __init__(
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
         # Init running status
-        self.forward_queue: List[Req] = []
+        self.waiting_queue: List[Req] = []
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
@@ -220,6 +224,7 @@ def forward_step(self):
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)
+            self.filter_out_inflight(new_batch)
 
             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -261,7 +266,7 @@ def print_stats(self):
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {throughput:.2f}, "
-            f"#queue-req: {len(self.forward_queue)}"
+            f"#queue-req: {len(self.waiting_queue)}"
         )
 
     def check_memory(self):
@@ -328,7 +333,7 @@ def handle_generate_request(
             ),
             self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        self.forward_queue.append(req)
+        self.waiting_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
         running_bs = (
@@ -338,7 +343,7 @@ def get_new_prefill_batch(self) -> Optional[Batch]:
             return
 
         # Compute matched prefix length
-        for req in self.forward_queue:
+        for req in self.waiting_queue:
            req.input_ids = req.origin_input_ids + req.output_ids
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
@@ -348,7 +353,7 @@ def get_new_prefill_batch(self) -> Optional[Batch]:
            req.last_node = last_node
 
        # Get priority queue
-        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
 
        # Add requests if there is available space
        can_run_list = []
@@ -367,7 +372,33 @@ def get_new_prefill_batch(self) -> Optional[Batch]:
                 ]
             )
 
-        for req in self.forward_queue:
+        # Handle the current inflight request
+        take_inflight = 0
+        if self.current_inflight_req:
+            take_inflight = 1
+            r = self.current_inflight_req
+            r.input_ids = r.origin_input_ids + r.output_ids
+            truncated = (
+                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+            )
+            r.extend_input_len = min(
+                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+            )
+            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
+            can_run_list.append(r)
+
+            if not truncated:
+                # Finish inflight
+                self.current_inflight_req = None
+                new_batch_total_tokens += (
+                    r.extend_input_len + r.sampling_params.max_new_tokens
+                )
+                new_batch_input_tokens += r.extend_input_len
+            else:
+                new_batch_total_tokens += r.extend_input_len
+                new_batch_input_tokens += r.extend_input_len
+
+        for req in self.waiting_queue:
             if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
@@ -409,11 +440,36 @@ def get_new_prefill_batch(self) -> Optional[Batch]:
                     break
                 else:
                     # Add this request to the running batch
-                    can_run_list.append(req)
-                    new_batch_total_tokens += (
-                        req.extend_input_len + req.sampling_params.max_new_tokens
-                    )
-                    new_batch_input_tokens += req.extend_input_len
+                    if (
+                        new_batch_input_tokens + req.extend_input_len
+                        <= self.chunked_prefill_size
+                        or (
+                            req.return_logprob and req.normalized_prompt_logprob is None
+                        )
+                    ):
+                        can_run_list.append(req)
+                        new_batch_total_tokens += (
+                            req.extend_input_len + req.sampling_params.max_new_tokens
+                        )
+                        new_batch_input_tokens += req.extend_input_len
+                    else:
+                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
+
+                        if trunc_len <= 0:
+                            # Undo locking
+                            delta = self.tree_cache.dec_lock_ref(req.last_node)
+                            available_size += delta
+                            break
+
+                        req.extend_input_len = trunc_len
+                        req.input_ids = req.input_ids[
+                            : len(req.prefix_indices) + req.extend_input_len
+                        ]
+                        can_run_list.append(req)
+                        self.current_inflight_req = req
+                        new_batch_input_tokens += req.extend_input_len
+                        new_batch_total_tokens += req.extend_input_len
+                        break
             else:
                 break
 
@@ -440,7 +496,7 @@ def get_new_prefill_batch(self) -> Optional[Batch]:
                 f"#cached-token: {hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
             )
 
         # Return the new batch
@@ -450,7 +506,7 @@ def get_new_prefill_batch(self) -> Optional[Batch]:
             self.token_to_kv_pool,
             self.tree_cache,
         )
-        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch
 
     def forward_prefill_batch(self, batch: Batch):
@@ -482,9 +538,10 @@ def forward_prefill_batch(self, batch: Batch):
         # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_token_ids[i])
-            req.check_finished()
+            if req is not self.current_inflight_req:
+                req.completion_tokens_wo_jump_forward += 1
+                req.output_ids.append(next_token_ids[i])
+                req.check_finished()
 
             if req.return_logprob:
                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
@@ -545,7 +602,7 @@ def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
@@ -566,7 +623,7 @@ def forward_decode_batch(self, batch: Batch):
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self.forward_queue.extend(retracted_reqs)
+            self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -576,7 +633,7 @@ def forward_decode_batch(self, batch: Batch):
         if not self.disable_regex_jump_forward:
             # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-            self.forward_queue.extend(jump_forward_reqs)
+            self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
 
@@ -711,8 +768,20 @@ def handle_finished_requests(self, batch: Batch):
         else:
             batch.reqs = []
 
+    def filter_out_inflight(self, batch: Batch):
+        if self.current_inflight_req is None:
+            return
+
+        unfinished_indices = list(range(len(batch.reqs)))
+        for i, req in enumerate(batch.reqs):
+            if req is self.current_inflight_req:
+                unfinished_indices.remove(i)
+                break
+
+        batch.filter_batch(unfinished_indices)
+
     def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
+        if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
@@ -725,20 +794,20 @@ def flush_cache(self):
         else:
             warnings.warn(
                 f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
+                f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
 
     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
         to_del = None
-        for i, req in enumerate(self.forward_queue):
+        for i, req in enumerate(self.waiting_queue):
             if req.rid == recv_req.rid:
                 to_del = i
                 break
 
         if to_del is not None:
-            del self.forward_queue[to_del]
+            del self.waiting_queue[to_del]
 
         # Delete requests in the running batch
         if self.running_batch:

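For readers skimming the diff, the scheduler change above boils down to this: each prefill round has a token budget of chunked_prefill_size, a request whose uncached prompt does not fit is truncated to the remaining budget and remembered as the single inflight request, and that request is resumed first in the next round (it does not start producing output until its final chunk). The standalone sketch below illustrates only that idea; Req, schedule_prefill, and CHUNK are simplified stand-ins invented for this illustration, not the classes and names used in tp_worker.py.

# Minimal, self-contained sketch of the chunked-prefill scheduling idea
# (illustration only; not the PR's actual data structures).
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Req:
    rid: str
    input_ids: List[int]   # full prompt token ids
    prefix_len: int = 0    # tokens already prefilled (or served from cache)


CHUNK = 4                        # stand-in for server_args.chunked_prefill_size
inflight: Optional[Req] = None   # at most one partially prefilled request


def schedule_prefill(waiting: List[Req]) -> List[Tuple[Req, int]]:
    """Build one prefill batch as (request, extend_len) pairs."""
    global inflight
    batch: List[Tuple[Req, int]] = []
    budget = CHUNK

    # 1. Resume the inflight request first, like the current_inflight_req branch above.
    if inflight is not None:
        take = min(len(inflight.input_ids) - inflight.prefix_len, budget)
        batch.append((inflight, take))
        inflight.prefix_len += take
        budget -= take
        if inflight.prefix_len == len(inflight.input_ids):
            inflight = None          # prompt fully prefilled; decoding can start

    # 2. Admit waiting requests until the chunk budget runs out.
    while waiting and budget > 0:
        req = waiting.pop(0)
        take = min(len(req.input_ids) - req.prefix_len, budget)
        batch.append((req, take))
        req.prefix_len += take
        budget -= take
        if req.prefix_len < len(req.input_ids):
            inflight = req           # carry the remainder into the next round
            break
    return batch


# Example: an 11-token prompt is prefilled over three rounds of at most 4 tokens.
queue = [Req("a", list(range(11)))]
while queue or inflight is not None:
    print([(r.rid, n) for r, n in schedule_prefill(queue)])
# -> [('a', 4)]  [('a', 4)]  [('a', 3)]
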
54 changes: 36 additions & 18 deletions python/sglang/srt/server.py
@@ -175,12 +175,46 @@ def _set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 256
 
 
+def set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        _set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     server_args.check_server_args()
+    server_args.post_server_args()
 
     """Launch an HTTP server."""
     global tokenizer_manager
@@ -190,16 +224,6 @@ def launch_server(
         format="%(message)s",
     )
 
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    set_ulimit()
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-    if server_args.disable_disk_cache:
-        disable_cache()
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
@@ -208,14 +232,8 @@ def launch_server(
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
 
+    set_envs_and_config(server_args)
+
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
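
On the server side, the change is plumbing: environment and config setup moves into set_envs_and_config, and launch_server now calls it once (plus the new server_args.post_server_args()) before allocating ports. A hypothetical caller-side sketch follows; the module paths, the example model_path, and the choice of 2048 for chunked_prefill_size are assumptions made for illustration, not values taken from this diff.

# Hypothetical usage sketch (not part of the PR): start the server with a
# chunk budget so that long prompts are prefilled in 2048-token slices.
from sglang.srt.server import launch_server        # assumed import path
from sglang.srt.server_args import ServerArgs      # assumed import path

if __name__ == "__main__":
    server_args = ServerArgs(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # assumed example model
        chunked_prefill_size=2048,                   # field read by tp_worker.py above
    )
    launch_server(server_args)  # blocks and serves HTTP until interrupted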