Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speculative decoding] Initial spec decode docs #5400

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ Documentation
models/engine_args
models/lora
models/vlm
models/spec_decode
models/performance

.. toctree::
Expand Down
75 changes: 75 additions & 0 deletions docs/source/models/spec_decode.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
.. _spec_decode:

Speculative decoding in vLLM
============================

.. warning::
Please note that speculative decoding in vLLM is not yet optimized and does
not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. The work
to optimize it is ongoing and can be followed in `this issue. <https://github.com/vllm-project/vllm/issues/4630>`_

This document shows how to use `Speculative Decoding <https://x.com/karpathy/status/1697318534555336961>`_ with vLLM.
Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference.

Speculating with a draft model
------------------------------

The following code configures vLLM to use speculative decoding with a draft model, speculating 5 tokens at a time.

.. code-block:: python
from vllm import LLM, SamplingParams
prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_model="facebook/opt-125m",
num_speculative_tokens=5,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Speculating by matching n-grams in the prompt
---------------------------------------------

The following code configures vLLM to use speculative decoding where proposals are generated by
matching n-grams in the prompt. For more information read `this thread. <https://x.com/joao_gante/status/1747322413006643259>`_

.. code-block:: python
from vllm import LLM, SamplingParams
prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_model="[ngram]",
num_speculative_tokens=5,
ngram_prompt_lookup_max=4,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Resources for vLLM contributors
-------------------------------
* `A Hacker's Guide to Speculative Decoding in vLLM <https://www.youtube.com/watch?v=9wNAgpX6z_4>`_
* `What is Lookahead Scheduling in vLLM? <https://docs.google.com/document/d/1Z9TvqzzBPnh5WHcRwjvK2UEeFeq5zMZb5mFE8jR0HCs/edit#heading=h.1fjfb0donq5a>`_
* `Information on batch expansion. <https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8>`_
* `Dynamic speculative decoding <https://github.com/vllm-project/vllm/issues/4565>`_
Loading