Simplify mem state (#623)
wisclmy0611 authored Jul 15, 2024
1 parent bae9541 commit 5ac8b80
Showing 7 changed files with 61 additions and 66 deletions.
3 changes: 2 additions & 1 deletion benchmark/latency_throughput/bench_serving.py
@@ -297,7 +297,8 @@ def main(args: argparse.Namespace):
benchmark_time = benchmark_end_time - benchmark_start_time

# Compute the statistics.
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
latencies = [latency for _, _, latency in REQUEST_LATENCY]
avg_latency = np.mean(latencies)
avg_per_token_latency = np.mean(
[
latency / (prompt_len + output_len)
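
The hunk above only names the latency list before averaging it, so avg_latency and the per-token average read off the same data. A minimal standalone sketch of those statistics, assuming REQUEST_LATENCY holds (prompt_len, output_len, latency) tuples as the comprehensions above imply (the sample values here are made up):

    import numpy as np

    # Made-up sample in the same (prompt_len, output_len, latency) shape used above.
    REQUEST_LATENCY = [(128, 32, 0.92), (256, 64, 1.41), (64, 16, 0.48)]

    latencies = [latency for _, _, latency in REQUEST_LATENCY]
    avg_latency = np.mean(latencies)
    avg_per_token_latency = np.mean(
        [latency / (prompt_len + output_len) for prompt_len, output_len, latency in REQUEST_LATENCY]
    )
    print(f"avg latency: {avg_latency:.3f} s, avg per-token latency: {avg_per_token_latency:.5f} s")
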
3 changes: 2 additions & 1 deletion python/sglang/global_config.py
@@ -25,7 +25,8 @@ def __init__(self):
# This can improve the speed for large batch sizes during prefill.
self.layer_sync_threshold = 8192

# Runtime constants: Flashinfer
# Runtime constants: others
self.num_continue_decode_steps = 10
self.flashinfer_workspace_size = 192 * 1024 * 1024

# Output tokenization configs
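
The new num_continue_decode_steps constant replaces a hard-coded 10 in the decode loop of forward_step() in tp_worker.py further down. A minimal sketch of how such a constant is consumed, with run_continuous_decode, forward_decode_batch, and running_batch as hypothetical stand-ins for the real worker attributes:

    def run_continuous_decode(global_config, running_batch, forward_decode_batch):
        # Run several decode iterations back-to-back to amortize scheduling
        # overhead, mirroring the loop in forward_step() below.
        for _ in range(global_config.num_continue_decode_steps):
            forward_decode_batch(running_batch)
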
12 changes: 4 additions & 8 deletions python/sglang/srt/managers/controller/infer_batch.py
@@ -174,9 +174,6 @@ def detokenize_incrementally(self, inplace: bool = True):

return False, ""

def max_new_tokens(self):
return self.sampling_params.max_new_tokens

def check_finished(self):
if self.finished():
return
@@ -352,7 +349,7 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

if out_cache_loc is None:
@@ -422,7 +419,7 @@ def check_decode_mem(self):
if self.token_to_kv_pool.available_size() >= bs:
return True

self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
self.tree_cache.evict(bs, self.token_to_kv_pool.free)

if self.token_to_kv_pool.available_size() >= bs:
return True
@@ -453,7 +450,7 @@ def retract_decode(self):
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.dec_refs(token_indices)
self.token_to_kv_pool.free(token_indices)

# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
@@ -596,8 +593,7 @@ def filter_batch(self, unfinished_indices: List[int]):
"logit_bias",
]:
self_val = getattr(self, item, None)
# logit_bias can be None
if self_val is not None:
if self_val is not None: # logit_bias can be None
setattr(self, item, self_val[new_indices])

def merge(self, other: "Batch"):
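
Besides dropping the max_new_tokens() helper, this file renames the pool method dec_refs to free and keeps passing it to tree_cache.evict as the callback that returns evicted KV slots to the pool. A minimal sketch of that allocate-evict-retry pattern, with token_to_kv_pool and tree_cache as stand-ins for the real objects:

    def alloc_with_eviction(token_to_kv_pool, tree_cache, extend_num_tokens):
        # First try a plain allocation.
        out_cache_loc = token_to_kv_pool.alloc(extend_num_tokens)
        if out_cache_loc is None:
            # Evict cached prefixes, freeing their KV slots back to the pool,
            # then retry once, mirroring prepare_for_extend() above.
            tree_cache.evict(extend_num_tokens, token_to_kv_pool.free)
            out_cache_loc = token_to_kv_pool.alloc(extend_num_tokens)
        return out_cache_loc  # still None if memory is truly exhausted
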
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/controller/radix_cache.py
@@ -82,12 +82,12 @@ def cache_req(

if self.disable:
if del_in_memory_pool:
self.token_to_kv_pool.dec_refs(indices)
self.token_to_kv_pool.free(indices)
else:
return torch.tensor([], dtype=torch.int64), self.root_node

# Radix Cache takes one ref in memory pool
self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])

if del_in_memory_pool:
self.req_to_token_pool.free(req_pool_idx)
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/controller/schedule_heuristic.py
@@ -13,6 +13,10 @@ def __init__(
max_total_num_tokens,
tree_cache,
):
if tree_cache.disable and schedule_heuristic == "lpm":
# LPM is meaningless when the tree cache is disabled.
schedule_heuristic = "fcfs"

self.schedule_heuristic = schedule_heuristic
self.max_running_seqs = max_running_seqs
self.max_prefill_num_tokens = max_prefill_num_tokens
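
The new guard falls back from longest-prefix-match scheduling ("lpm") to first-come-first-served when the radix tree cache is disabled, since there are then no cached prefixes to rank requests by. A minimal sketch of the same fallback, assuming a tree_cache object with a disable flag:

    def resolve_schedule_heuristic(schedule_heuristic: str, tree_cache) -> str:
        # LPM ranks waiting requests by matched prefix length, which carries no
        # information when the tree cache is disabled, so fall back to FCFS.
        if tree_cache.disable and schedule_heuristic == "lpm":
            return "fcfs"
        return schedule_heuristic
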
56 changes: 26 additions & 30 deletions python/sglang/srt/managers/controller/tp_worker.py
@@ -98,7 +98,7 @@ def __init__(
)
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = (
8192
16384
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
)
@@ -222,30 +222,29 @@ def forward_step(self):
# Run decode batch
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(10):
for _ in range(global_config.num_continue_decode_steps):
self.num_generated_tokens += len(self.running_batch.reqs)
self.forward_decode_batch(self.running_batch)

# Print stats
if self.tp_rank == 0:
if self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu_id={self.gpu_id}] Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu_id={self.gpu_id}] Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)

if self.running_batch.is_empty():
self.running_batch = None
@@ -344,7 +343,7 @@ def get_new_fill_batch(self) -> Optional[Batch]:
if self.running_batch:
available_size -= sum(
[
(r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
(r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
for r in self.running_batch.reqs
]
)
@@ -358,15 +357,15 @@ def get_new_fill_batch(self) -> Optional[Batch]:
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.extend_input_len == 0 and req.max_new_tokens() > 0:
if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
# Need at least one token to compute logits
req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
req.image_offset += 1

if (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
< available_size
and (
req.extend_input_len + new_batch_input_tokens
@@ -378,7 +377,7 @@ def get_new_fill_batch(self) -> Optional[Batch]:
available_size += delta

if not (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
< available_size
):
# Undo locking
@@ -389,7 +388,7 @@ def get_new_fill_batch(self) -> Optional[Batch]:
# Add this request to the running batch
can_run_list.append(req)
new_batch_total_tokens += (
req.extend_input_len + req.max_new_tokens()
req.extend_input_len + req.sampling_params.max_new_tokens
)
new_batch_input_tokens += req.extend_input_len
else:
@@ -403,9 +402,6 @@ def get_new_fill_batch(self) -> Optional[Batch]:

# Print stats
if self.tp_rank == 0:
running_req = (
0 if self.running_batch is None else len(self.running_batch.reqs)
)
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
self.tree_cache_metrics["total"] += (
hit_tokens + new_batch_input_tokens
@@ -420,7 +416,7 @@ def get_new_fill_batch(self) -> Optional[Batch]:
f"#new-token: {new_batch_input_tokens}, "
f"#cached-token: {hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#running-req: {running_req}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
)
# logger.debug(
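
With the Req.max_new_tokens() helper removed, the admission loop reads sampling_params.max_new_tokens directly when deciding whether a waiting request fits the remaining token budget. A minimal sketch of the fully visible part of that check (the prefill-token bound is truncated in this hunk and omitted here); fits_kv_budget is a hypothetical name:

    def fits_kv_budget(req, available_size, new_batch_total_tokens) -> bool:
        # Reserve room for the prompt extension plus the request's entire
        # decode budget, as get_new_fill_batch() does above.
        return (
            req.extend_input_len
            + req.sampling_params.max_new_tokens
            + new_batch_total_tokens
            < available_size
        )
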
45 changes: 21 additions & 24 deletions python/sglang/srt/memory_pool.py
@@ -8,45 +8,45 @@


class ReqToTokenPool:
def __init__(self, size, max_context_len):
"""A memory pool that maps a request to its token locations."""

def __init__(self, size: int, max_context_len: int):
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
self.can_use_mem_size = size
self.req_to_token = torch.empty(
(size, max_context_len), dtype=torch.int32, device="cuda"
)
self.can_use_mem_size = size

def alloc(self, need_size):
def alloc(self, need_size: int):
if need_size > self.can_use_mem_size:
return None

select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
self.mem_state[select_index] = False
self.can_use_mem_size -= need_size

return select_index.to(torch.int32)
return select_index

def free(self, free_index):
def free(self, free_index: int):
self.mem_state[free_index] = True
if isinstance(free_index, (int,)):
self.can_use_mem_size += 1
else:
self.can_use_mem_size += free_index.shape[0]

self.mem_state[free_index] = True

def clear(self):
self.mem_state.fill_(True)
self.can_use_mem_size = len(self.mem_state)


class TokenToKVPool:
"""A memory pool that maps a token to its kv cache locations"""

def __init__(self, size, dtype, head_num, head_dim, layer_num):
self.size = size

# This can be promised:
# assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
self.can_use_mem_size = self.size

# [size, key/value, head_num, head_dim] for each layer
self.kv_data = [
@@ -58,6 +58,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num):
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
self.prefetch_chunk_size = 512

self.can_use_mem_size = self.size
self.clear()

def get_key_buffer(self, layer_id):
@@ -66,6 +67,9 @@ def get_key_buffer(self, layer_id):
def get_value_buffer(self, layer_id):
return self.kv_data[layer_id][:, 1]

def available_size(self):
return self.can_use_mem_size + len(self.prefetch_buffer)

def alloc(self, need_size):
buffer_len = len(self.prefetch_buffer)
if need_size <= buffer_len:
@@ -75,30 +79,23 @@ def alloc(self, need_size):

addition_size = need_size - buffer_len
alloc_size = max(addition_size, self.prefetch_chunk_size)
select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size]
select_index = select_index.to(torch.int32)
select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)

if select_index.shape[0] < addition_size:
return None

self.add_refs(select_index)
self.mem_state[select_index] = False
self.can_use_mem_size -= len(select_index)

self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
ret_index = self.prefetch_buffer[:need_size]
self.prefetch_buffer = self.prefetch_buffer[need_size:]

return ret_index

def available_size(self):
return self.can_use_mem_size + len(self.prefetch_buffer)

def add_refs(self, token_index: torch.Tensor):
self.can_use_mem_size -= len(token_index)
self.mem_state[token_index] = False

def dec_refs(self, token_index: torch.Tensor):
self.can_use_mem_size += len(token_index)
self.mem_state[token_index] = True
def free(self, free_index: torch.Tensor):
self.mem_state[free_index] = True
self.can_use_mem_size += len(free_index)

def clear(self):
self.mem_state.fill_(True)
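
This file is the heart of the commit: TokenToKVPool drops its reference-count style API (add_refs/dec_refs) and treats mem_state as a plain boolean occupancy mask, with free() simply marking slots available again, presumably because each KV slot now has at most one owner. A minimal standalone sketch of such a boolean-mask allocator (BoolSlotPool is a toy name; the per-layer KV buffers and the prefetch buffer are left out):

    import torch

    class BoolSlotPool:
        """Toy allocator mirroring the simplified mem_state: True means free."""

        def __init__(self, size: int, device: str = "cpu"):
            self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
            self.can_use_mem_size = size

        def available_size(self) -> int:
            return self.can_use_mem_size

        def alloc(self, need_size: int):
            # Indices of free slots, truncated to the requested count.
            select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
            if select_index.shape[0] < need_size:
                return None
            self.mem_state[select_index] = False
            self.can_use_mem_size -= need_size
            return select_index  # the real pool casts this to int32

        def free(self, free_index: torch.Tensor):
            # Mark slots free again; no refcount to decrement.
            self.mem_state[free_index] = True
            self.can_use_mem_size += len(free_index)

        def clear(self):
            self.mem_state.fill_(True)
            self.can_use_mem_size = len(self.mem_state)

    pool = BoolSlotPool(8)
    idx = pool.alloc(3)       # tensor of 3 free slot indices
    pool.free(idx)
    assert pool.available_size() == 8
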
