From 3de2f30a27b1d9ffef6dfddcdcc7877c2a2dc857 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Wed, 17 Jul 2024 13:24:43 -0700
Subject: [PATCH] Flashinfer sample kernel (#617)

Replace the hand-written top-k/top-p sampling path (_top_p_top_k plus
torch.multinomial) with flashinfer's fused top_k_top_p_sampling_from_probs
kernel. Batch.sample() now returns only the sampled token ids, top_ps and
top_ks are stored as flat 1-D tensors, sampling failures fall back to
greedy argmax, and the required flashinfer version is bumped to 0.1.0.
---
 python/sglang/bench_latency.py              |  4 +-
 .../srt/managers/controller/infer_batch.py  | 37 ++++++-------------
 .../srt/managers/controller/tp_worker.py    |  4 +-
 python/sglang/srt/server.py                 |  2 +-
 4 files changed, 17 insertions(+), 30 deletions(-)

diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 45d23b63da4..c2eb93a241d 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -156,14 +156,14 @@ def extend(reqs, model_runner):
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits, batch
 
 
 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids.cpu().numpy())
     output = model_runner.forward(batch, ForwardMode.DECODE)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits
 
 
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 387d8f471f4..6f0a08f379a 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
@@ -398,10 +399,10 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,
@@ -659,20 +660,17 @@ def sample(self, logits: torch.Tensor):
 
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
         try:
-            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(
-                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
-            )
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
-        )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
+            batch_next_token_ids = torch.argmax(probs, dim=-1)
 
         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
@@ -682,18 +680,7 @@
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )
 
-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids
 
 
 @dataclass
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 897cab140ed..80b0516445f 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -451,7 +451,7 @@ def forward_prefill_batch(self, batch: Batch):
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids, _ = batch.sample(output.next_token_logits)
+            next_token_ids = batch.sample(output.next_token_logits)
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -574,7 +574,7 @@ def forward_decode_batch(self, batch: Batch):
 
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(output.next_token_logits)
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 57862c42c51..3e52bfcddbf 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -154,7 +154,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.0.8",
+            "0.1.0",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
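
Background for readers without a FlashInfer build: the fused kernel consumes
the now-flat per-request top_ks/top_ps tensors directly and takes a pre-drawn
(32, batch_size) tensor of uniform samples for its internal rejection-sampling
rounds; it returns the sampled token ids together with a second per-row
success value, which this patch discards. Below is a minimal PyTorch-only
sketch of the sampling rule the kernel implements, adapted from the
_top_p_top_k helper the patch removes. The name reference_top_k_top_p_sample
is illustrative (it is not part of sglang or flashinfer), and the
sort-mask-multinomial approach is a readable reference, not the kernel's
actual rejection algorithm.

```python
import torch


def reference_top_k_top_p_sample(
    probs: torch.Tensor,   # (batch, vocab); each row sums to 1
    top_ks: torch.Tensor,  # (batch,) int, as stored by prepare_for_extend
    top_ps: torch.Tensor,  # (batch,) float
) -> torch.Tensor:
    # Sort descending so the top-k / top-p masks become positional.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Top-p: drop tokens whose preceding cumulative mass already exceeds top_p.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # Top-k: drop everything past the k-th sorted position.
    ranks = torch.arange(probs.shape[-1], device=probs.device).view(1, -1)
    probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
    # Draw one token per row from the truncated distribution
    # (torch.multinomial renormalizes the weights itself),
    # then map sorted positions back to vocabulary ids.
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    probs = torch.softmax(torch.randn(4, 128), dim=-1)
    top_ks = torch.full((4,), 40, dtype=torch.int)
    top_ps = torch.full((4,), 0.9)
    print(reference_top_k_top_p_sample(probs, top_ks, top_ps))
```

Like the patched Batch.sample, a caller pairing this with the fused kernel
would fall back to torch.argmax(probs, dim=-1) if sampling raises a
RuntimeError.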