Simplify mem state #623

Merged · 7 commits · Jul 15, 2024
3 changes: 2 additions & 1 deletion benchmark/latency_throughput/bench_serving.py
@@ -297,7 +297,8 @@ def main(args: argparse.Namespace):
benchmark_time = benchmark_end_time - benchmark_start_time

# Compute the statistics.
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
latencies = [latency for _, _, latency in REQUEST_LATENCY]
avg_latency = np.mean(latencies)
avg_per_token_latency = np.mean(
[
latency / (prompt_len + output_len)
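The hunk above only extracts the latency list into a named variable before averaging; the statistics themselves are unchanged. As a quick illustration, here is a minimal sketch of the same computation, where `REQUEST_LATENCY` is replaced by made-up `(prompt_len, output_len, latency)` tuples:

```python
import numpy as np

# Hypothetical sample data mirroring the structure of REQUEST_LATENCY
# in bench_serving.py: (prompt_len, output_len, latency_in_seconds).
REQUEST_LATENCY = [(128, 32, 0.84), (256, 64, 1.51), (512, 16, 1.12)]

latencies = [latency for _, _, latency in REQUEST_LATENCY]
avg_latency = np.mean(latencies)
avg_per_token_latency = np.mean(
    [latency / (prompt_len + output_len) for prompt_len, output_len, latency in REQUEST_LATENCY]
)
print(f"avg latency: {avg_latency:.3f}s, "
      f"avg per-token latency: {avg_per_token_latency * 1000:.2f}ms")
```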
3 changes: 2 additions & 1 deletion python/sglang/global_config.py
@@ -25,7 +25,8 @@ def __init__(self):
# This can improve the speed for large batch sizes during prefill.
self.layer_sync_threshold = 8192

# Runtime constants: Flashinfer
# Runtime constants: others
self.num_continue_decode_steps = 10
self.flashinfer_workspace_size = 192 * 1024 * 1024

# Output tokenization configs
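`num_continue_decode_steps` lifts the hard-coded `10` out of the decode loop in `tp_worker.forward_step` (changed further below) into the global config, so the number of back-to-back decode iterations per scheduling round lives in one place. A rough sketch of how such a constant is consumed, with the worker loop reduced to a hypothetical stub:

```python
class GlobalConfig:
    """Trimmed-down stand-in for sglang.global_config.GlobalConfig."""

    def __init__(self):
        # Runtime constants: others
        self.num_continue_decode_steps = 10
        self.flashinfer_workspace_size = 192 * 1024 * 1024


global_config = GlobalConfig()


def forward_step(running_batch, decode_fn):
    # Run several decode iterations back-to-back to amortize scheduling
    # overhead (hypothetical stub of TpWorker.forward_step's decode path).
    for _ in range(global_config.num_continue_decode_steps):
        decode_fn(running_batch)
```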
12 changes: 4 additions & 8 deletions python/sglang/srt/managers/controller/infer_batch.py
@@ -174,9 +174,6 @@ def detokenize_incrementally(self, inplace: bool = True):

return False, ""

def max_new_tokens(self):
return self.sampling_params.max_new_tokens

def check_finished(self):
if self.finished():
return
@@ -352,7 +349,7 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

if out_cache_loc is None:
@@ -422,7 +419,7 @@ def check_decode_mem(self):
if self.token_to_kv_pool.available_size() >= bs:
return True

self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
self.tree_cache.evict(bs, self.token_to_kv_pool.free)

if self.token_to_kv_pool.available_size() >= bs:
return True
@@ -453,7 +450,7 @@ def retract_decode(self):
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.dec_refs(token_indices)
self.token_to_kv_pool.free(token_indices)

# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
@@ -596,8 +593,7 @@ def filter_batch(self, unfinished_indices: List[int]):
"logit_bias",
]:
self_val = getattr(self, item, None)
# logit_bias can be None
if self_val is not None:
if self_val is not None: # logit_bias can be None
setattr(self, item, self_val[new_indices])

def merge(self, other: "Batch"):
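The call sites changed above all follow the same alloc-evict-retry pattern: try to allocate KV slots, and if the pool is exhausted, evict unreferenced prefixes from the radix cache, passing the pool's `free` method (previously `dec_refs`) as the release callback, then retry once. A stripped-down sketch of that pattern, with `token_to_kv_pool` and `tree_cache` assumed to expose the interfaces used in this diff:

```python
def alloc_with_eviction(token_to_kv_pool, tree_cache, extend_num_tokens):
    """Hypothetical condensation of the allocation path in Batch.prepare_for_extend."""
    out_cache_loc = token_to_kv_pool.alloc(extend_num_tokens)
    if out_cache_loc is None:
        # The cache releases evicted KV slots through the pool's free().
        tree_cache.evict(extend_num_tokens, token_to_kv_pool.free)
        out_cache_loc = token_to_kv_pool.alloc(extend_num_tokens)
    if out_cache_loc is None:
        raise RuntimeError("Out of KV cache memory even after eviction")
    return out_cache_loc
```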
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/controller/radix_cache.py
@@ -82,12 +82,12 @@ def cache_req(

if self.disable:
if del_in_memory_pool:
self.token_to_kv_pool.dec_refs(indices)
self.token_to_kv_pool.free(indices)
else:
return torch.tensor([], dtype=torch.int64), self.root_node

# Radix Cache takes one ref in memory pool
self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])

if del_in_memory_pool:
self.req_to_token_pool.free(req_pool_idx)
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/controller/schedule_heuristic.py
@@ -13,6 +13,10 @@ def __init__(
max_total_num_tokens,
tree_cache,
):
if tree_cache.disable and schedule_heuristic == "lpm":
# LPM is meaningless when the tree cache is disabled.
schedule_heuristic = "fcfs"

self.schedule_heuristic = schedule_heuristic
self.max_running_seqs = max_running_seqs
self.max_prefill_num_tokens = max_prefill_num_tokens
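LPM (longest-prefix-match) scheduling only pays off when the radix cache can actually reuse matched prefixes; with the cache disabled every request matches a zero-length prefix, so the ordering adds nothing and the scheduler falls back to FCFS. A toy illustration of the two orderings (the `Req` record and its fields are assumptions for the example, not the real request class):

```python
from dataclasses import dataclass, field
from itertools import count

_arrival = count()


@dataclass
class Req:
    prefix_len: int  # matched prefix length in the radix cache
    arrival: int = field(default_factory=lambda: next(_arrival))


def order(reqs, heuristic):
    if heuristic == "lpm":  # longest prefix match first
        return sorted(reqs, key=lambda r: -r.prefix_len)
    return sorted(reqs, key=lambda r: r.arrival)  # fcfs


reqs = [Req(prefix_len=0), Req(prefix_len=512), Req(prefix_len=64)]
print([r.prefix_len for r in order(reqs, "lpm")])  # [512, 64, 0]
# With the tree cache disabled every prefix_len is 0, so "lpm" degenerates
# to an arbitrary order and "fcfs" is the sensible default.
```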
56 changes: 26 additions & 30 deletions python/sglang/srt/managers/controller/tp_worker.py
@@ -98,7 +98,7 @@ def __init__(
)
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = (
8192
16384
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
)
@@ -222,30 +222,29 @@ def forward_step(self):
# Run decode batch
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(10):
for _ in range(global_config.num_continue_decode_steps):
self.num_generated_tokens += len(self.running_batch.reqs)
self.forward_decode_batch(self.running_batch)

# Print stats
if self.tp_rank == 0:
if self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu_id={self.gpu_id}] Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu_id={self.gpu_id}] Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)

if self.running_batch.is_empty():
self.running_batch = None
@@ -344,7 +343,7 @@ def get_new_fill_batch(self) -> Optional[Batch]:
if self.running_batch:
available_size -= sum(
[
(r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
(r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
for r in self.running_batch.reqs
]
)
@@ -358,15 +357,15 @@ def get_new_fill_batch(self) -> Optional[Batch]:
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.extend_input_len == 0 and req.max_new_tokens() > 0:
if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
# Need at least one token to compute logits
req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
req.image_offset += 1

if (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
< available_size
and (
req.extend_input_len + new_batch_input_tokens
@@ -378,7 +377,7 @@
available_size += delta

if not (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
< available_size
):
# Undo locking
@@ -389,7 +388,7 @@
# Add this request to the running batch
can_run_list.append(req)
new_batch_total_tokens += (
req.extend_input_len + req.max_new_tokens()
req.extend_input_len + req.sampling_params.max_new_tokens
)
new_batch_input_tokens += req.extend_input_len
else:
@@ -403,9 +402,6 @@

# Print stats
if self.tp_rank == 0:
running_req = (
0 if self.running_batch is None else len(self.running_batch.reqs)
)
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
self.tree_cache_metrics["total"] += (
hit_tokens + new_batch_input_tokens
@@ -420,7 +416,7 @@
f"#new-token: {new_batch_input_tokens}, "
f"#cached-token: {hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#running-req: {running_req}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
)
# logger.debug(
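With the `Req.max_new_tokens()` helper removed, the admission loop above reads the decode budget directly from `req.sampling_params.max_new_tokens`: a request joins the prefill batch only if its extend tokens plus its worst-case decode tokens fit the remaining KV budget, and its extend tokens fit the per-batch prefill limit. A condensed sketch of that check (function name and exact structure are assumptions, not the real code):

```python
def can_admit(req, available_size, new_batch_total_tokens,
              new_batch_input_tokens, max_prefill_tokens):
    """Hypothetical condensation of the admission test in get_new_fill_batch."""
    max_new = req.sampling_params.max_new_tokens
    fits_kv_budget = (
        req.extend_input_len + max_new + new_batch_total_tokens < available_size
    )
    fits_prefill_budget = (
        req.extend_input_len + new_batch_input_tokens < max_prefill_tokens
    )
    return fits_kv_budget and fits_prefill_budget
```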
45 changes: 21 additions & 24 deletions python/sglang/srt/memory_pool.py
@@ -8,45 +8,45 @@


class ReqToTokenPool:
def __init__(self, size, max_context_len):
"""A memory pool that maps a request to its token locations."""

def __init__(self, size: int, max_context_len: int):
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
self.can_use_mem_size = size
self.req_to_token = torch.empty(
(size, max_context_len), dtype=torch.int32, device="cuda"
)
self.can_use_mem_size = size

def alloc(self, need_size):
def alloc(self, need_size: int):
if need_size > self.can_use_mem_size:
return None

select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
self.mem_state[select_index] = False
self.can_use_mem_size -= need_size

return select_index.to(torch.int32)
return select_index

def free(self, free_index):
def free(self, free_index: int):
self.mem_state[free_index] = True
if isinstance(free_index, (int,)):
self.can_use_mem_size += 1
else:
self.can_use_mem_size += free_index.shape[0]

self.mem_state[free_index] = True

def clear(self):
self.mem_state.fill_(True)
self.can_use_mem_size = len(self.mem_state)


class TokenToKVPool:
"""A memory pool that maps a token to its kv cache locations"""

def __init__(self, size, dtype, head_num, head_dim, layer_num):
self.size = size

# This can be promised:
# assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
self.can_use_mem_size = self.size

# [size, key/value, head_num, head_dim] for each layer
self.kv_data = [
@@ -58,6 +58,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num):
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
self.prefetch_chunk_size = 512

self.can_use_mem_size = self.size
self.clear()

def get_key_buffer(self, layer_id):
@@ -66,6 +67,9 @@ def get_key_buffer(self, layer_id):
def get_value_buffer(self, layer_id):
return self.kv_data[layer_id][:, 1]

def available_size(self):
return self.can_use_mem_size + len(self.prefetch_buffer)

def alloc(self, need_size):
buffer_len = len(self.prefetch_buffer)
if need_size <= buffer_len:
@@ -75,30 +79,23 @@ def alloc(self, need_size):

addition_size = need_size - buffer_len
alloc_size = max(addition_size, self.prefetch_chunk_size)
select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size]
select_index = select_index.to(torch.int32)
select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)

if select_index.shape[0] < addition_size:
return None

self.add_refs(select_index)
self.mem_state[select_index] = False
self.can_use_mem_size -= len(select_index)

self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
ret_index = self.prefetch_buffer[:need_size]
self.prefetch_buffer = self.prefetch_buffer[need_size:]

return ret_index

def available_size(self):
return self.can_use_mem_size + len(self.prefetch_buffer)

def add_refs(self, token_index: torch.Tensor):
self.can_use_mem_size -= len(token_index)
self.mem_state[token_index] = False

def dec_refs(self, token_index: torch.Tensor):
self.can_use_mem_size += len(token_index)
self.mem_state[token_index] = True
def free(self, free_index: torch.Tensor):
self.mem_state[free_index] = True
self.can_use_mem_size += len(free_index)

def clear(self):
self.mem_state.fill_(True)
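The net effect of this file is the PR title: `mem_state` becomes a plain boolean free/used mask, and the reference-counting pair `add_refs`/`dec_refs` collapses into `alloc`/`free`. A self-contained CPU sketch of the simplified pool, assuming the same alloc/free contract but omitting the per-layer KV buffers and the prefetch buffer:

```python
import torch


class SimpleTokenPool:
    """Minimal boolean-mask pool (illustrative, not the real TokenToKVPool)."""

    def __init__(self, size: int):
        self.mem_state = torch.ones((size,), dtype=torch.bool)  # True = free slot
        self.can_use_mem_size = size

    def available_size(self) -> int:
        return self.can_use_mem_size

    def alloc(self, need_size: int):
        if need_size > self.can_use_mem_size:
            return None
        # Take the first `need_size` free slots and mark them used.
        select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
        self.mem_state[select_index] = False
        self.can_use_mem_size -= need_size
        return select_index

    def free(self, free_index: torch.Tensor):
        self.mem_state[free_index] = True
        self.can_use_mem_size += len(free_index)


pool = SimpleTokenPool(8)
idx = pool.alloc(3)
pool.free(idx)
assert pool.available_size() == 8
```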