From 98cc898c9df0a6fc704c59f428b8784b47cb5647 Mon Sep 17 00:00:00 2001 From: Mingyi Date: Mon, 15 Jul 2024 01:18:47 -0700 Subject: [PATCH 1/7] simplify memory pool --- .../srt/managers/controller/infer_batch.py | 6 +-- .../srt/managers/controller/radix_cache.py | 4 +- python/sglang/srt/memory_pool.py | 45 +++++++++---------- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 375ec6eeb22..c473c729d57 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -352,7 +352,7 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor extend_num_tokens = seq_lens.sum() - prefix_lens.sum() out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens) if out_cache_loc is None: - self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs) + self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free) out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens) if out_cache_loc is None: @@ -422,7 +422,7 @@ def check_decode_mem(self): if self.token_to_kv_pool.available_size() >= bs: return True - self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs) + self.tree_cache.evict(bs, self.token_to_kv_pool.free) if self.token_to_kv_pool.available_size() >= bs: return True @@ -453,7 +453,7 @@ def retract_decode(self): token_indices = self.req_to_token_pool.req_to_token[ req_pool_indices_cpu[idx] ][last_uncached_pos : seq_lens_cpu[idx]] - self.token_to_kv_pool.dec_refs(token_indices) + self.token_to_kv_pool.free(token_indices) # release the last node self.tree_cache.dec_lock_ref(req.last_node) diff --git a/python/sglang/srt/managers/controller/radix_cache.py b/python/sglang/srt/managers/controller/radix_cache.py index bc7b758dd49..c06f52f473b 100644 --- a/python/sglang/srt/managers/controller/radix_cache.py +++ b/python/sglang/srt/managers/controller/radix_cache.py @@ -82,12 +82,12 @@ def cache_req( if self.disable: if del_in_memory_pool: - self.token_to_kv_pool.dec_refs(indices) + self.token_to_kv_pool.free(indices) else: return torch.tensor([], dtype=torch.int64), self.root_node # Radix Cache takes one ref in memory pool - self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len]) + self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len]) if del_in_memory_pool: self.req_to_token_pool.free(req_pool_idx) diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py index 647be2810c9..6b5b4111523 100644 --- a/python/sglang/srt/memory_pool.py +++ b/python/sglang/srt/memory_pool.py @@ -8,45 +8,45 @@ class ReqToTokenPool: - def __init__(self, size, max_context_len): + """A memory pool that maps a request to its token locations.""" + + def __init__(self, size: int, max_context_len: int): self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda") - self.can_use_mem_size = size self.req_to_token = torch.empty( (size, max_context_len), dtype=torch.int32, device="cuda" ) + self.can_use_mem_size = size - def alloc(self, need_size): + def alloc(self, need_size: int): if need_size > self.can_use_mem_size: return None - select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size] + select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32) self.mem_state[select_index] = False self.can_use_mem_size -= need_size - return select_index.to(torch.int32) + return select_index - def free(self, 
free_index): + def free(self, free_index: int): + self.mem_state[free_index] = True if isinstance(free_index, (int,)): self.can_use_mem_size += 1 else: self.can_use_mem_size += free_index.shape[0] - self.mem_state[free_index] = True - def clear(self): self.mem_state.fill_(True) self.can_use_mem_size = len(self.mem_state) class TokenToKVPool: + """A memory pool that maps a token to its kv cache locations""" + def __init__(self, size, dtype, head_num, head_dim, layer_num): self.size = size - # This can be promised: - # assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0) # We also add one slot. This slot is used for writing dummy output from padded tokens. self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda") - self.can_use_mem_size = self.size # [size, key/value, head_num, head_dim] for each layer self.kv_data = [ @@ -58,6 +58,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num): self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32) self.prefetch_chunk_size = 512 + self.can_use_mem_size = self.size self.clear() def get_key_buffer(self, layer_id): @@ -66,6 +67,9 @@ def get_key_buffer(self, layer_id): def get_value_buffer(self, layer_id): return self.kv_data[layer_id][:, 1] + def available_size(self): + return self.can_use_mem_size + len(self.prefetch_buffer) + def alloc(self, need_size): buffer_len = len(self.prefetch_buffer) if need_size <= buffer_len: @@ -75,13 +79,13 @@ def alloc(self, need_size): addition_size = need_size - buffer_len alloc_size = max(addition_size, self.prefetch_chunk_size) - select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size] - select_index = select_index.to(torch.int32) + select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32) if select_index.shape[0] < addition_size: return None - self.add_refs(select_index) + self.mem_state[select_index] = False + self.can_use_mem_size -= len(select_index) self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index)) ret_index = self.prefetch_buffer[:need_size] @@ -89,16 +93,9 @@ def alloc(self, need_size): return ret_index - def available_size(self): - return self.can_use_mem_size + len(self.prefetch_buffer) - - def add_refs(self, token_index: torch.Tensor): - self.can_use_mem_size -= len(token_index) - self.mem_state[token_index] = False - - def dec_refs(self, token_index: torch.Tensor): - self.can_use_mem_size += len(token_index) - self.mem_state[token_index] = True + def free(self, free_index: torch.Tensor): + self.mem_state[free_index] = True + self.can_use_mem_size += len(free_index) def clear(self): self.mem_state.fill_(True) From 253208e1a814e32bdc3005f2fbfb14dcaaee1b7e Mon Sep 17 00:00:00 2001 From: Mingyi Date: Mon, 15 Jul 2024 01:23:54 -0700 Subject: [PATCH 2/7] simplify constants --- python/sglang/global_config.py | 3 ++- python/sglang/srt/managers/controller/tp_worker.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index ba2895a9d49..61a79adaa01 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -25,7 +25,8 @@ def __init__(self): # This can improve the speed for large batch sizes during prefill. 
self.layer_sync_threshold = 8192 - # Runtime constants: Flashinfer + # Runtime constants: others + self.num_continue_decode_steps = 8 self.flashinfer_workspace_size = 192 * 1024 * 1024 # Output tokenization configs diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 12c278fd586..96969095c1a 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -222,7 +222,7 @@ def forward_step(self): # Run decode batch if self.running_batch is not None: # Run a few decode batches continuously for reducing overhead - for _ in range(10): + for _ in range(global_config.num_continue_decode_steps): self.num_generated_tokens += len(self.running_batch.reqs) self.forward_decode_batch(self.running_batch) From 1191719fe5b1f89194636ae822e4904903f8fdf9 Mon Sep 17 00:00:00 2001 From: Mingyi Date: Mon, 15 Jul 2024 01:39:41 -0700 Subject: [PATCH 3/7] simplify --- python/sglang/global_config.py | 2 +- .../srt/managers/controller/infer_batch.py | 6 +-- .../managers/controller/schedule_heuristic.py | 4 ++ .../srt/managers/controller/tp_worker.py | 54 +++++++++---------- 4 files changed, 31 insertions(+), 35 deletions(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 61a79adaa01..629af6a2a06 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -26,7 +26,7 @@ def __init__(self): self.layer_sync_threshold = 8192 # Runtime constants: others - self.num_continue_decode_steps = 8 + self.num_continue_decode_steps = 10 self.flashinfer_workspace_size = 192 * 1024 * 1024 # Output tokenization configs diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index c473c729d57..387d8f471f4 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -174,9 +174,6 @@ def detokenize_incrementally(self, inplace: bool = True): return False, "" - def max_new_tokens(self): - return self.sampling_params.max_new_tokens - def check_finished(self): if self.finished(): return @@ -596,8 +593,7 @@ def filter_batch(self, unfinished_indices: List[int]): "logit_bias", ]: self_val = getattr(self, item, None) - # logit_bias can be None - if self_val is not None: + if self_val is not None: # logit_bias can be None setattr(self, item, self_val[new_indices]) def merge(self, other: "Batch"): diff --git a/python/sglang/srt/managers/controller/schedule_heuristic.py b/python/sglang/srt/managers/controller/schedule_heuristic.py index 4ae1a7069fd..aae6cfb86fa 100644 --- a/python/sglang/srt/managers/controller/schedule_heuristic.py +++ b/python/sglang/srt/managers/controller/schedule_heuristic.py @@ -13,6 +13,10 @@ def __init__( max_total_num_tokens, tree_cache, ): + if tree_cache.disable and schedule_heuristic == "lpm": + # LMP is not meaningless when tree cache is disabled. 
+ schedule_heuristic = "fcfs" + self.schedule_heuristic = schedule_heuristic self.max_running_seqs = max_running_seqs self.max_prefill_num_tokens = max_prefill_num_tokens diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 96969095c1a..6a06891d3c2 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -227,31 +227,30 @@ def forward_step(self): self.forward_decode_batch(self.running_batch) # Print stats - if self.tp_rank == 0: - if self.decode_forward_ct % 40 == 0: - num_used = self.max_total_num_tokens - ( - self.token_to_kv_pool.available_size() - + self.tree_cache.evictable_size() - ) - throughput = self.num_generated_tokens / ( - time.time() - self.last_stats_tic - ) - self.num_generated_tokens = 0 - self.last_stats_tic = time.time() - logger.info( - f"[gpu_id={self.gpu_id}] Decode batch. " - f"#running-req: {len(self.running_batch.reqs)}, " - f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {throughput:.2f}, " - f"#queue-req: {len(self.forward_queue)}" - ) + if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: + num_used = self.max_total_num_tokens - ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + ) + throughput = self.num_generated_tokens / ( + time.time() - self.last_stats_tic + ) + self.num_generated_tokens = 0 + self.last_stats_tic = time.time() + logger.info( + f"[gpu_id={self.gpu_id}] Decode batch. " + f"#running-req: {len(self.running_batch.reqs)}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"gen throughput (token/s): {throughput:.2f}, " + f"#queue-req: {len(self.forward_queue)}" + ) if self.running_batch.is_empty(): self.running_batch = None break - if self.out_pyobjs and self.running_batch.has_stream(): + if self.out_pyobjs: break else: # Check the available size @@ -344,7 +343,7 @@ def get_new_fill_batch(self) -> Optional[Batch]: if self.running_batch: available_size -= sum( [ - (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio + (r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio for r in self.running_batch.reqs ] ) @@ -358,7 +357,7 @@ def get_new_fill_batch(self) -> Optional[Batch]: req.prefix_indices = req.prefix_indices[:-delta] if req.image_offset is not None: req.image_offset += delta - if req.extend_input_len == 0 and req.max_new_tokens() > 0: + if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0: # Need at least one token to compute logits req.extend_input_len = 1 req.prefix_indices = req.prefix_indices[:-1] @@ -366,7 +365,7 @@ def get_new_fill_batch(self) -> Optional[Batch]: req.image_offset += 1 if ( - req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens + req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens < available_size and ( req.extend_input_len + new_batch_input_tokens @@ -378,7 +377,7 @@ def get_new_fill_batch(self) -> Optional[Batch]: available_size += delta if not ( - req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens + req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens < available_size ): # Undo locking @@ -389,7 +388,7 @@ def get_new_fill_batch(self) -> Optional[Batch]: # Add this request to the running batch can_run_list.append(req) new_batch_total_tokens += ( - req.extend_input_len + req.max_new_tokens() + 
req.extend_input_len + req.sampling_params.max_new_tokens ) new_batch_input_tokens += req.extend_input_len else: @@ -403,9 +402,6 @@ def get_new_fill_batch(self) -> Optional[Batch]: # Print stats if self.tp_rank == 0: - running_req = ( - 0 if self.running_batch is None else len(self.running_batch.reqs) - ) hit_tokens = sum(len(x.prefix_indices) for x in can_run_list) self.tree_cache_metrics["total"] += ( hit_tokens + new_batch_input_tokens @@ -420,7 +416,7 @@ def get_new_fill_batch(self) -> Optional[Batch]: f"#new-token: {new_batch_input_tokens}, " f"#cached-token: {hit_tokens}, " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " - f"#running-req: {running_req}, " + f"#running-req: {running_bs}, " f"#queue-req: {len(self.forward_queue) - len(can_run_list)}" ) # logger.debug( From 0837035e74f222e9bf802306452edff54c0d5afd Mon Sep 17 00:00:00 2001 From: Mingyi Date: Mon, 15 Jul 2024 01:42:39 -0700 Subject: [PATCH 4/7] update --- benchmark/latency_throughput/bench_serving.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmark/latency_throughput/bench_serving.py b/benchmark/latency_throughput/bench_serving.py index 1adb78958cc..23e8245f231 100644 --- a/benchmark/latency_throughput/bench_serving.py +++ b/benchmark/latency_throughput/bench_serving.py @@ -297,7 +297,8 @@ def main(args: argparse.Namespace): benchmark_time = benchmark_end_time - benchmark_start_time # Compute the statistics. - avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY]) + latencies = [latency for _, _, latency in REQUEST_LATENCY] + avg_latency = np.mean(latencies) avg_per_token_latency = np.mean( [ latency / (prompt_len + output_len) From 305790c433dc054101fe61ee192d60c76b4c5f05 Mon Sep 17 00:00:00 2001 From: Mingyi Date: Mon, 15 Jul 2024 01:49:19 -0700 Subject: [PATCH 5/7] update --- python/sglang/srt/managers/controller/tp_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 6a06891d3c2..35282eaf57a 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -98,7 +98,7 @@ def __init__( ) self.max_total_num_tokens = self.model_runner.max_total_num_tokens self.max_prefill_tokens = ( - 8192 + 16384 if server_args.max_prefill_tokens is None else server_args.max_prefill_tokens ) From 102320d4b78233950a8c87e756d458f9c144ce01 Mon Sep 17 00:00:00 2001 From: Mingyi Date: Mon, 15 Jul 2024 01:54:17 -0700 Subject: [PATCH 6/7] fix --- python/sglang/global_config.py | 2 +- python/sglang/srt/managers/controller/tp_worker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 629af6a2a06..61a79adaa01 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -26,7 +26,7 @@ def __init__(self): self.layer_sync_threshold = 8192 # Runtime constants: others - self.num_continue_decode_steps = 10 + self.num_continue_decode_steps = 8 self.flashinfer_workspace_size = 192 * 1024 * 1024 # Output tokenization configs diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 35282eaf57a..1d22dfdf171 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -250,7 +250,7 @@ def forward_step(self): self.running_batch = None break - if self.out_pyobjs: + if self.out_pyobjs 
and self.running_batch.has_stream(): break else: # Check the available size From c9592e518934d66d3539fea17df0ae6f0d370b3d Mon Sep 17 00:00:00 2001 From: Mingyi Date: Mon, 15 Jul 2024 01:59:39 -0700 Subject: [PATCH 7/7] fix --- python/sglang/global_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 61a79adaa01..629af6a2a06 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -26,7 +26,7 @@ def __init__(self): self.layer_sync_threshold = 8192 # Runtime constants: others - self.num_continue_decode_steps = 8 + self.num_continue_decode_steps = 10 self.flashinfer_workspace_size = 192 * 1024 * 1024 # Output tokenization configs
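
Appendix, not part of the patches above: a minimal CPU-only sketch of the allocation scheme that PATCH 1/7 moves TokenToKVPool to, namely a boolean free list with chunked prefetching in place of reference counting. The class name, the sizes, and the usage at the bottom are illustrative assumptions for readers of this series rather than sglang code; the real pool keeps its tensors on CUDA, returns int32 indices, and is driven by the radix cache and batch code shown in the diffs.

import torch


class SimpleTokenPool:
    """Boolean free list with chunked prefetching; a slot is either used or free."""

    def __init__(self, size: int, prefetch_chunk_size: int = 512):
        self.size = size
        self.mem_state = torch.ones((size,), dtype=torch.bool)  # True means the slot is free
        self.can_use_mem_size = size
        self.prefetch_buffer = torch.empty(0, dtype=torch.int64)
        self.prefetch_chunk_size = prefetch_chunk_size

    def available_size(self) -> int:
        # Free slots still in mem_state plus slots already pulled into the prefetch buffer.
        return self.can_use_mem_size + len(self.prefetch_buffer)

    def alloc(self, need_size: int):
        if need_size <= len(self.prefetch_buffer):
            # Serve the request from slots prefetched by an earlier call.
            ret_index = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return ret_index

        # Pull at least one full chunk so the nonzero() scan is amortized over many allocs.
        addition_size = need_size - len(self.prefetch_buffer)
        alloc_size = max(addition_size, self.prefetch_chunk_size)
        select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size]
        if select_index.shape[0] < addition_size:
            return None  # Out of slots; the caller can evict from its cache and retry.

        self.mem_state[select_index] = False
        self.can_use_mem_size -= len(select_index)
        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))

        ret_index = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return ret_index

    def free(self, free_index: torch.Tensor):
        # No reference counts to decrement: freeing just flips the slots back to free.
        self.mem_state[free_index] = True
        self.can_use_mem_size += len(free_index)


if __name__ == "__main__":
    pool = SimpleTokenPool(size=16, prefetch_chunk_size=4)
    a = pool.alloc(3)                    # scans once, prefetches a chunk of 4, returns 3 slots
    assert pool.available_size() == 13   # 12 untouched free slots + 1 slot left in the prefetch buffer
    pool.free(a)
    assert pool.available_size() == 16

Compared with the add_refs/dec_refs pair this series removes, the only invariant is that a slot index lives in exactly one place at a time: mem_state (as True), the prefetch buffer, or the caller's hands. That is what lets free() be a plain boolean store instead of a counter update.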