Flashinfer sample kernel (#617)
hnyls2002 authored Jul 17, 2024
1 parent 4efcc59 commit 3de2f30
Showing 4 changed files with 17 additions and 30 deletions.
python/sglang/bench_latency.py (4 changes: 2 additions & 2 deletions)
@@ -156,14 +156,14 @@ def extend(reqs, model_runner):
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids.cpu().numpy())
     output = model_runner.forward(batch, ForwardMode.DECODE)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits


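Note that batch.sample now returns only the sampled token ids (the per-token probabilities are no longer returned), so the two benchmark helpers above unpack a single value. A minimal usage sketch of the updated helpers; `reqs` and `model_runner` are assumed to be built as elsewhere in bench_latency.py, and the loop count is arbitrary:

# Illustrative only: one prefill (extend) step followed by a few decode steps,
# using the single-return-value batch.sample() introduced in this commit.
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
for _ in range(8):  # decode 8 tokens; any small number works for a latency probe
    next_token_ids, next_token_logits = decode(next_token_ids, batch, model_runner)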
python/sglang/srt/managers/controller/infer_batch.py (37 changes: 12 additions & 25 deletions)
@@ -7,6 +7,7 @@

 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs

 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
@@ -398,10 +399,10 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,
@@ -659,20 +660,17 @@ def sample(self, logits: torch.Tensor):

         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
         try:
-            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(
-                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
-            )
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
-        )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
+            batch_next_token_ids = torch.argmax(probs, dim=-1)

         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
@@ -682,18 +680,7 @@ def sample(self, logits: torch.Tensor):
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )

-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids


 @dataclass
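For readers unfamiliar with the FlashInfer API, here is a minimal, self-contained sketch of the new sampling path, mirroring the call pattern in the diff above. It assumes flashinfer >= 0.1.0, where top_k_top_p_sampling_from_probs takes per-request 1-D top_ks/top_ps tensors (consistent with the .view(-1, 1) reshapes being dropped in prepare_for_extend) plus a batch of pre-drawn uniform numbers for its internal rejection sampling; the exact signature may differ in other versions, and the wrapper name sample_next_tokens is purely illustrative.

import warnings

import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs


def sample_next_tokens(logits, top_ks, top_ps, max_top_k_round=32):
    # logits: (batch_size, vocab_size) float tensor on GPU.
    # top_ks: (batch_size,) int tensor; top_ps: (batch_size,) float tensor.
    probs = torch.softmax(logits, dim=-1)
    batch_size = probs.shape[0]
    try:
        # The FlashInfer kernel rejection-samples, consuming up to
        # max_top_k_round pre-drawn uniform numbers per request.
        uniform_samples = torch.rand(
            (max_top_k_round, batch_size), device=probs.device
        )
        next_token_ids, _ = top_k_top_p_sampling_from_probs(
            probs, uniform_samples, top_ks, top_ps
        )
    except RuntimeError as e:
        # Same fallback as the commit: warn and degrade to greedy decoding.
        warnings.warn(f"Ignore errors in sampling: {e}")
        next_token_ids = torch.argmax(probs, dim=-1)
    return next_token_ids

Compared with the removed _top_p_top_k helper, which sorted the full vocabulary, masked it with cumulative-sum and rank thresholds, and then called torch.multinomial, the fused kernel does the filtering and sampling in one pass, which is presumably the motivation for this change.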
python/sglang/srt/managers/controller/tp_worker.py (4 changes: 2 additions & 2 deletions)
@@ -451,7 +451,7 @@ def forward_prefill_batch(self, batch: Batch):
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids, _ = batch.sample(output.next_token_logits)
+            next_token_ids = batch.sample(output.next_token_logits)

             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -574,7 +574,7 @@ def forward_decode_batch(self, batch: Batch):

         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(output.next_token_logits)

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
python/sglang/srt/server.py (2 changes: 1 addition & 1 deletion)
@@ -154,7 +154,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.0.8",
+            "0.1.0",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
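The required flashinfer version is bumped from 0.0.8 to 0.1.0 to match the sampling API used above. As a rough illustration of what a guard like assert_pkg_version does, a hypothetical re-implementation (not sglang's actual helper; it assumes the third-party packaging library) might look like this:

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def check_pkg_version(name: str, min_version: str, hint: str) -> None:
    # Hypothetical minimum-version guard; sglang's real assert_pkg_version
    # may differ in signature and behavior.
    try:
        installed = version(name)
    except PackageNotFoundError:
        raise RuntimeError(f"{name} is not installed. {hint}")
    if Version(installed) < Version(min_version):
        raise RuntimeError(
            f"{name} {installed} is installed, but >= {min_version} is required. {hint}"
        )


check_pkg_version(
    "flashinfer",
    "0.1.0",
    "Please uninstall the old version and reinstall the latest version by "
    "following the instructions at https://docs.flashinfer.ai/installation.html.",
)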
