From e175f3a10dc8ae08d3076bf4d22a66141cfa36fe Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Sat, 13 Jul 2024 23:45:23 +0000
Subject: [PATCH 1/4] use flashinfer sample

---
 .../srt/managers/controller/infer_batch.py   | 35 +++++++------------
 1 file changed, 12 insertions(+), 23 deletions(-)

diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index d89e9786e2c..2beb7349870 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
@@ -401,10 +402,10 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,
@@ -663,18 +664,17 @@ def sample(self, logits: torch.Tensor):
 
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
         try:
-            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
-        )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
+            # TODO: handle this error
 
         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
@@ -684,18 +684,7 @@
                 req.regex_fsm_state, batch_next_token_ids_cpu[i]
             )
 
-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids, _
 
 
 @dataclass
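Note on PATCH 1/4: top_k_top_p_sampling_from_probs is flashinfer's fused rejection sampler. It takes the per-request top_ks/top_ps as flat (batch_size,) tensors, which is why the .view(-1, 1) reshapes are dropped in prepare_for_extend, and it consumes one row of pre-drawn uniform noise per rejection round (up to max_top_k_round = 32 here). The second return value, discarded as _ in the patch, reports per-request success. A minimal sketch of the calling convention, assuming a CUDA device and a flashinfer build contemporary with this patch; the toy sizes and knob values are illustrative:

    import torch
    from flashinfer.sampling import top_k_top_p_sampling_from_probs

    batch_size, vocab_size, max_top_k_round = 4, 32000, 32
    probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)

    # Per-request sampling knobs are flat (batch_size,) tensors, not (batch_size, 1).
    top_ks = torch.full((batch_size,), 40, dtype=torch.int, device="cuda")
    top_ps = torch.full((batch_size,), 0.9, dtype=torch.float, device="cuda")

    # One row of uniform noise per rejection-sampling round.
    uniform_samples = torch.rand((max_top_k_round, batch_size), device="cuda")

    next_token_ids, success = top_k_top_p_sampling_from_probs(
        probs, uniform_samples, top_ks, top_ps
    )
    assert next_token_ids.shape == (batch_size,)  # one sampled token id per request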
From 6b5bfaf7c6867e54f7c21f242da47eabc0840b88 Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Sun, 14 Jul 2024 00:22:05 +0000
Subject: [PATCH 2/4] simplify

---
 python/sglang/srt/layers/radix_attention.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 73e122a758d..31293b89214 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -138,8 +138,5 @@ def forward(self, q, k, v, input_metadata: InputMetadata):
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
         value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        if input_metadata.out_cache_loc is not None:
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        else:
-            raise RuntimeError()
+        key_buffer[input_metadata.out_cache_loc] = cache_k
+        value_buffer[input_metadata.out_cache_loc] = cache_v

From d73a02be41cb9e654c47bb5bf78d67f69ebaf910 Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Wed, 17 Jul 2024 20:18:46 +0000
Subject: [PATCH 3/4] handle sample error

---
 python/sglang/srt/managers/controller/infer_batch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 9d229833ac6..5298eb30088 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -670,7 +670,7 @@ def sample(self, logits: torch.Tensor):
             )
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
-            # TODO: handle this error
+            batch_next_token_ids = torch.argmax(probs, dim=-1)
 
         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
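Note on PATCH 3/4: the except branch now degrades to greedy decoding instead of leaving batch_next_token_ids without a value. torch.argmax(probs, dim=-1) picks the highest-probability token per request and bypasses the top-k/top-p filters, trading sampling fidelity for availability on a failing batch. A CPU-runnable sketch of the same control flow; sample_with_fallback is an illustrative name, and torch.multinomial stands in for the fused flashinfer kernel:

    import warnings

    import torch

    def sample_with_fallback(probs: torch.Tensor) -> torch.Tensor:
        # probs: (batch_size, vocab_size), each row a probability distribution.
        try:
            # Stand-in for top_k_top_p_sampling_from_probs; may raise RuntimeError.
            next_ids = torch.multinomial(probs, num_samples=1).view(-1)
        except RuntimeError as e:
            warnings.warn(f"Ignore errors in sampling: {e}")
            # Greedy fallback: take the highest-probability token per request.
            next_ids = torch.argmax(probs, dim=-1)
        return next_ids

    probs = torch.softmax(torch.randn(4, 128), dim=-1)
    print(sample_with_fallback(probs).shape)  # torch.Size([4])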
From 312e02370ded00671a8c1985bb53d6a7664e80e7 Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Wed, 17 Jul 2024 20:23:46 +0000
Subject: [PATCH 4/4] remove useless probs return, which may cause errors when
 sampling fails

---
 python/sglang/bench_latency.py                       | 4 ++--
 python/sglang/srt/managers/controller/infer_batch.py | 2 +-
 python/sglang/srt/managers/controller/tp_worker.py   | 4 ++--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 45d23b63da4..c2eb93a241d 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -156,14 +156,14 @@ def extend(reqs, model_runner):
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits, batch
 
 
 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids.cpu().numpy())
     output = model_runner.forward(batch, ForwardMode.DECODE)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits
 
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 5298eb30088..6f0a08f379a 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -680,7 +680,7 @@ def sample(self, logits: torch.Tensor):
                 req.regex_fsm_state, batch_next_token_ids_cpu[i]
             )
 
-        return batch_next_token_ids, _
+        return batch_next_token_ids
 
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 897cab140ed..80b0516445f 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -451,7 +451,7 @@ def forward_prefill_batch(self, batch: Batch):
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids, _ = batch.sample(output.next_token_logits)
+            next_token_ids = batch.sample(output.next_token_logits)
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -574,7 +574,7 @@ def forward_decode_batch(self, batch: Batch):
 
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(output.next_token_logits)
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
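Note on PATCH 4/4: this removes the failure mode the subject line alludes to. After PATCH 3/4, the except branch rebinds only batch_next_token_ids, so on the failure path the _ from the unfinished top_k_top_p_sampling_from_probs call is never assigned, and the old "return batch_next_token_ids, _" would itself raise UnboundLocalError. Returning a single tensor, and dropping the tuple unpacking at every call site, eliminates that. A self-contained reproduction of the hazard, with toy lists standing in for tensors:

    def sample(fail: bool):
        try:
            if fail:
                raise RuntimeError("sampling kernel failed")
            ids, _ = [1, 2], [True, True]
        except RuntimeError:
            ids = [0, 0]   # the fallback rebinds only `ids`
        return ids, _      # `_` was never bound on the failure path

    print(sample(fail=False))  # ([1, 2], [True, True])
    print(sample(fail=True))   # raises UnboundLocalError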