sampling: fused speculative sampling kernels #259

Merged
merged 3 commits on May 27, 2024
371 changes: 368 additions & 3 deletions include/flashinfer/sampling.cuh

Large diffs are not rendered by default.
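
Since the large sampling.cuh diff is not rendered here, the following is a minimal, unfused Python reference of the chain speculative sampling acceptance rule from the cited speculative sampling paper, which is what the new fused kernel implements in a single launch. The function name and the use of torch.multinomial for resampling are illustrative only, not the kernel's actual code.

import torch

def chain_speculative_sampling_reference(draft_probs, draft_token_ids,
                                         uniform_samples, target_probs):
    # Unfused reference for a single request (no batch dimension).
    # draft_probs:      (num_speculate_tokens, vocab_size)
    # draft_token_ids:  (num_speculate_tokens,)
    # uniform_samples:  (num_speculate_tokens + 1,), uniform in [0, 1)
    # target_probs:     (num_speculate_tokens + 1, vocab_size)
    num_speculate_tokens = draft_token_ids.shape[0]
    output = torch.full((num_speculate_tokens + 1,), -1, dtype=torch.int64)
    for i in range(num_speculate_tokens):
        token = draft_token_ids[i]
        # Accept the draft token with probability min(1, p_target / p_draft).
        accept_prob = torch.clamp(target_probs[i, token] / draft_probs[i, token], max=1.0)
        if uniform_samples[i] < accept_prob:
            output[i] = token
        else:
            # Reject: resample from the residual distribution max(0, target - draft),
            # renormalized, and leave the remaining slots padded with -1.
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
            output[i] = torch.multinomial(residual / residual.sum(), 1)
            return output
    # All draft tokens accepted: sample one extra "bonus" token from the target model.
    output[num_speculate_tokens] = torch.multinomial(target_probs[num_speculate_tokens], 1)
    return output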

8 changes: 4 additions & 4 deletions python/csrc/batch_prefill.cu
@@ -38,8 +38,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(

cudaError_t status =
handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim);
workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim);
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
}
@@ -166,8 +166,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(

cudaError_t status =
handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim);
workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim);
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
}
4 changes: 4 additions & 0 deletions python/csrc/flashinfer_ops.cu
@@ -34,6 +34,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Top-k sampling from probabilities");
m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
"Top-p sampling from probabilities");
m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask");
m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask");
m.def("chain_speculative_sampling", &chain_speculative_sampling,
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
"BatchDecodeWithPagedKVCachePyTorchWrapper")
10 changes: 8 additions & 2 deletions python/csrc/flashinfer_ops.h
@@ -62,6 +62,13 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
unsigned int top_k);

torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);

torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);

torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
torch::Tensor uniform_samples, torch::Tensor target_probs);

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);

class BatchDecodeWithPagedKVCachePyTorchWrapper {
@@ -83,8 +90,7 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
unsigned int max_workspace_size_in_bytes)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(
std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}
handler_(std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}

protected:
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
82 changes: 82 additions & 0 deletions python/csrc/sampling.cu
@@ -96,3 +96,85 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,

return {samples, success};
}

torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps) {
CHECK_INPUT(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = probs.size(0);
unsigned int vocab_size = probs.size(1);
probs = probs.to(torch::kFloat32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto renorm_probs =
torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(probs.device()));

cudaError_t status = sampling::TopPRenormProb<float>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()), top_p,
eps, batch_size, vocab_size, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
return renorm_probs;
}

torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps) {
CHECK_INPUT(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = probs.size(0);
unsigned int vocab_size = probs.size(1);
probs = probs.to(torch::kFloat32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto renorm_probs =
torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(probs.device()));

cudaError_t status = sampling::TopKRenormProb<float>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()), top_k,
eps, batch_size, vocab_size, torch_current_stream);

TORCH_CHECK(status == cudaSuccess,
"TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
return renorm_probs;
}

torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
torch::Tensor uniform_samples,
torch::Tensor target_probs) {
CHECK_INPUT(draft_probs);
CHECK_INPUT(draft_token_ids);
CHECK_INPUT(uniform_samples);
CHECK_INPUT(target_probs);
CHECK_DIM(3, draft_probs); // draft_probs: (batch_size, num_speculate_tokens, vocab_size)
CHECK_DIM(2, draft_token_ids); // draft_token_ids: (batch_size, num_speculate_tokens)
CHECK_DIM(2, uniform_samples); // uniform_samples: (batch_size, num_speculate_tokens + 1)
CHECK_DIM(3, target_probs); // target_probs: (batch_size, num_speculate_tokens + 1, vocab_size)
unsigned int batch_size = draft_probs.size(0);
unsigned int num_speculate_tokens = draft_probs.size(1);
unsigned int vocab_size = draft_probs.size(2);
CHECK_EQ(batch_size, draft_token_ids.size(0));
CHECK_EQ(batch_size, uniform_samples.size(0));
CHECK_EQ(batch_size, target_probs.size(0));
CHECK_EQ(num_speculate_tokens + 1, uniform_samples.size(1));
CHECK_EQ(num_speculate_tokens + 1, target_probs.size(1));
CHECK_EQ(vocab_size, target_probs.size(2));

draft_probs = draft_probs.to(torch::kFloat32);
draft_token_ids = draft_token_ids.to(torch::kInt32);
uniform_samples = uniform_samples.to(torch::kFloat32);
target_probs = target_probs.to(torch::kFloat32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto output_token_ids =
torch::empty({batch_size, num_speculate_tokens + 1},
torch::dtype(torch::kInt32).device(draft_token_ids.device()));

cudaError_t status = sampling::ChainSpeculativeSampling<float, int>(
static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
static_cast<int*>(output_token_ids.data_ptr()), batch_size, num_speculate_tokens, vocab_size,
torch_current_stream);

TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
std::string(cudaGetErrorString(status)));

return output_token_ids;
}
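
For reference, the top-p renormalization exposed through TopPRenormProb can be written as the following unfused PyTorch sketch. The sorting-based formulation is for illustration only; the fused kernel avoids explicit sorting, and the helper name is hypothetical.

import torch

def top_p_renorm_prob_reference(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # probs: (batch_size, vocab_size). Keep the smallest set of highest-probability
    # entries whose cumulative mass reaches top_p, zero out the rest, renormalize.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # An entry is kept if the mass strictly before it is still below top_p.
    keep_sorted = (cumulative - sorted_probs) < top_p
    keep = torch.zeros_like(probs).scatter(-1, sorted_idx, keep_sorted.to(probs.dtype))
    masked = probs * keep
    return masked / masked.sum(dim=-1, keepdim=True)

Feeding the renormalized probabilities to sampling_from_probs should match top_p_sampling_from_probs in distribution, as the docstrings in python/flashinfer/sampling.py below note.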
3 changes: 3 additions & 0 deletions python/flashinfer/__init__.py
@@ -40,6 +40,9 @@
sampling_from_probs,
top_p_sampling_from_probs,
top_k_sampling_from_probs,
top_p_renorm_prob,
top_k_renorm_prob,
chain_speculative_sampling,
)
from .norm import rmsnorm

8 changes: 4 additions & 4 deletions python/flashinfer/decode.py
@@ -632,9 +632,9 @@ def forward_return_lse(


class CUDAGraphBatchDecodeWithPagedKVCacheWrapper:
r"""CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first
proposed in `vLLM <https://arxiv.org/abs/2309.06180>`_) for batch of requests.
r"""CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first
proposed in `vLLM <https://arxiv.org/abs/2309.06180>`_) for batch of requests.

Note that this wrapper may not be as efficient as :class:`BatchDecodeWithPagedKVCacheWrapper`
because we won't dispatch to different kernels for different batch sizes/sequence lengths/etc
to accommodate the CUDAGraph requirement.
@@ -673,7 +673,7 @@ def __init__(
during the lifecycle of this wrapper.
indices_buffer : torch.Tensor
The user reserved buffer on GPU to store the page indices of the paged kv cache,
should be large enough to store the maximum number of page indices
should be large enough to store the maximum number of page indices
(``max_num_pages``) during the lifecycle of this wrapper.
last_page_len_buffer : torch.Tensor
The user reserved buffer on GPU to store the number of entries in the last page,
109 changes: 104 additions & 5 deletions python/flashinfer/sampling.py
@@ -30,7 +30,7 @@


def sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor):
r"""Category sampling from probabilities.
r"""Fused GPU kernel for category sampling from probabilities.

Parameters
----------
@@ -75,8 +75,11 @@ def sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor):
def top_p_sampling_from_probs(
probs: torch.Tensor, uniform_samples: torch.Tensor, top_p: float
):
r"""Top-p sampling (nucleus sampling) from probabilities, this operator implements
GPU-based rejection sampling without explicit sorting.
r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.

The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.

Parameters
----------
@@ -134,8 +137,11 @@ def top_p_sampling_from_probs(
def top_k_sampling_from_probs(
probs: torch.Tensor, uniform_samples: torch.Tensor, top_k: int
):
r"""Top-k sampling from probabilities, this operator implements GPU-based rejection sampling
without explicit sorting.
r"""Fused GPU kernel for top-k sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.

The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.

Parameters
----------
@@ -188,3 +194,96 @@ def top_k_sampling_from_probs(
implementation usually uses much fewer rounds for rejection sampling because of early stopping.
"""
return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)


def top_p_renorm_prob(probs: torch.Tensor, top_p: float, eps: float = 1e-5):
r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding.

Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_p: float
The threshold for re-normalizing probabilities; should be in ``(0, 1)``.
We mask out probabilities below a threshold chosen such that the cumulative sum
of ``probs[probs >= threshold]`` equals ``top_p``, then renormalize the remaining probabilities.
eps: float
The epsilon value for numerical stability.

Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.

Note
----
This combination of ``top_p_renorm_prob`` and ``sampling_from_probs`` should be equivalent to
``top_p_sampling_from_probs``.
"""
return _kernels.top_p_renorm_prob(probs, top_p, eps)


def top_k_renorm_prob(probs: torch.Tensor, top_k: int, eps: float = 1e-5):
r"""Fused GPU kernel for renormalizing probabilities by top-k thresholding.

Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_k: int
The threshold for re-normalizing probabilities; should be in ``(0, num_classes)``.
We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
eps: float
The epsilon value for numerical stability.

Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.

Note
----
This combination of ``top_k_renorm_prob`` and ``sampling_from_probs`` should be equivalent to
``top_k_sampling_from_probs``.
"""
return _kernels.top_k_renorm_prob(probs, top_k, eps)


def chain_speculative_sampling(
draft_probs,
draft_token_ids,
uniform_samples,
target_probs,
):
r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in
paper `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_),
where the draft model generates a sequence(chain) of tokens for each request.

Parameters
----------
draft_probs: torch.Tensor
The probability over vocabulary generated by the draft model.
Shape: ``(batch_size, num_speculate_tokens, vocab_size)``
draft_token_ids: torch.Tensor
The draft model's generated token indices.
Shape: ``(batch_size, num_speculate_tokens)``
uniform_samples: torch.Tensor
The uniform samples used as needles for sampling, shape ``(batch_size, num_speculate_tokens + 1)``.
Expected to be uniformly distributed in ``[0, 1)``.
target_probs: torch.Tensor
The probability over vocabulary generated by the target model.
Compared to input :attr:`draft_probs`, the target model's probability has an additional
slot at the end because the target model will generate one more token than the draft model.
Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``

Returns
-------
output_token_ids: torch.Tensor
The output token indices verified by the target model; rejected samples are
padded with ``-1``.
Compared to input :attr:`draft_token_ids`, the output tensor has an additional
token index at the end for the final token; if all previous tokens are accepted,
an extra "bonus" token is sampled from the target model's probability.
Shape: ``(batch_size, num_speculate_tokens + 1)``
"""
return _kernels.chain_speculative_sampling(
draft_probs, draft_token_ids, uniform_samples, target_probs
)
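
A small usage sketch of the new Python APIs added in this PR; tensor shapes follow the docstrings above, and a CUDA device plus an installed flashinfer build are assumed.

import torch
import flashinfer

batch_size, vocab_size, num_speculate_tokens = 4, 32000, 3
device = "cuda:0"

# Top-k / top-p renormalization followed by plain sampling.
probs = torch.softmax(torch.randn(batch_size, vocab_size, device=device), dim=-1)
uniform = torch.rand(batch_size, device=device)
renorm = flashinfer.top_k_renorm_prob(probs, top_k=50)
renorm = flashinfer.top_p_renorm_prob(renorm, top_p=0.9)
samples = flashinfer.sampling_from_probs(renorm, uniform)

# Chain speculative sampling: verify draft tokens against the target model.
draft_probs = torch.softmax(
    torch.randn(batch_size, num_speculate_tokens, vocab_size, device=device), dim=-1)
draft_token_ids = torch.multinomial(
    draft_probs.view(-1, vocab_size), 1).view(batch_size, num_speculate_tokens).int()
target_probs = torch.softmax(
    torch.randn(batch_size, num_speculate_tokens + 1, vocab_size, device=device), dim=-1)
uniform_samples = torch.rand(batch_size, num_speculate_tokens + 1, device=device)
output_token_ids = flashinfer.chain_speculative_sampling(
    draft_probs, draft_token_ids, uniform_samples, target_probs)
# output_token_ids: (batch_size, num_speculate_tokens + 1), rejected slots are -1.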
6 changes: 3 additions & 3 deletions python/setup.py
@@ -64,14 +64,14 @@ def get_instantiation_cu() -> List[str]:
(root / prefix).mkdir(parents=True, exist_ok=True)

group_sizes = os.environ.get("FLASHINFER_GROUP_SIZES", "1,4,6,8").split(",")
page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1,16,32").split(",")
head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1").split(",")
head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "128,256").split(",")
kv_layouts = os.environ.get("FLASHINFER_KV_LAYOUTS", "0,1").split(",")
pos_encoding_modes = os.environ.get("FLASHINFER_POS_ENCODING_MODES", "0,1,2").split(
","
)
allow_fp16_qk_reduction_options = os.environ.get(
"FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0,1"
"FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0"
).split(",")
causal_options = os.environ.get("FLASHINFER_CAUSAL_OPTIONS", "0,1").split(",")
# dispatch.inc