Speculative decoding with lookahead #2790

Open

wants to merge 2 commits into base: main

Conversation

jjjjohnson
Contributor

@jjjjohnson jjjjohnson commented Jan 8, 2025

Motivation

N-gram-based speculative decoding is very effective in retrieval-augmented generation (RAG). The cost of generating draft tokens is low compared to EAGLE, so it has great potential for accelerating token generation in RAG. Ant Group has proposed a Trie-based retrieval and verification mechanism. They report that their Lookahead implementation on top of vLLM achieves about 1.6x acceleration in a real-life single-query scenario. I want to bring Lookahead to SGLang.

Related resources

Lookahead: An Inference Acceleration Framework for Large Language Model with Lossless Generation Accuracy

Overall workflow

(workflow diagram)

Features

  • No draft model needs to be trained.
  • The trie is updated with both prompt tokens and output tokens.
  • Draft tokens are generated via a frequency-based ranking over the current prompt tokens and ALL historical output tokens (with eviction); see the sketch after this list.
  • Both single-branch and multi-branch drafting are supported.
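
For intuition, here is a minimal sketch of the trie-based, frequency-ranked retrieval described above (single-branch case). The names (NgramTrie, put, get_one_branch) are hypothetical and eviction is omitted; this is not the PR's actual lookahead cache implementation.

from collections import defaultdict

class _Node:
    def __init__(self):
        self.children = defaultdict(_Node)  # token id -> child node
        self.freq = 0                       # how often this path was observed

class NgramTrie:
    def __init__(self, branch_length=8):
        self.root = _Node()
        self.branch_length = branch_length  # max n-gram depth stored

    def put(self, token_ids):
        # Insert every n-gram (up to branch_length) from prompt/output tokens.
        for i in range(len(token_ids)):
            node = self.root
            for tok in token_ids[i:i + self.branch_length]:
                node = node.children[tok]
                node.freq += 1

    def get_one_branch(self, prefix, num_draft_tokens):
        # Single-branch draft: follow the most frequent child at each step.
        node = self.root
        for tok in prefix:                  # match the tail of the current sequence
            if tok not in node.children:
                return []                   # no match -> no draft tokens this step
            node = node.children[tok]
        drafts = []
        while node.children and len(drafts) < num_draft_tokens:
            tok, node = max(node.children.items(), key=lambda kv: kv[1].freq)
            drafts.append(tok)
        return drafts

# Usage: update with prompt/output tokens, then retrieve draft tokens for a suffix.
trie = NgramTrie()
trie.put([1, 2, 3, 4, 2, 3, 5])
print(trie.get_one_branch([2, 3], num_draft_tokens=4))  # -> [4, 2, 3, 5]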

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@jjjjohnson
Contributor Author

import sglang as sgl
import time

def main():
    # Sample prompt (Qwen2 chat template; the user turn asks "Who are you?" in Chinese).
    prompts = [
        '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n'
    ]

    sampling_params = {
        "temperature": 0.7,
        "repetition_penalty": 1,
        "max_new_tokens": 256,
        "top_k": 1,
        "stop_token_ids": [151645, 151644, 151643],
    }

    model_path = "Qwen/Qwen2-7B-Instruct"

    # Create an LLM engine with LOOKAHEAD speculative decoding enabled.
    llm = sgl.Engine(
        model_path=model_path,
        speculative_one_branch=True,
        disable_cuda_graph=False,
        speculative_num_draft_tokens=4,
        speculative_algorithm="LOOKAHEAD",
        mem_fraction_static=0.60,
        watchdog_timeout=1e8,
        log_level="info",
    )

    for idx in range(5):
        start = time.time()
        outputs = llm.generate(prompts, sampling_params)
        elapsed = time.time() - start
        completion_tokens = 0
        # Print the outputs and count completion tokens.
        for prompt, output in zip(prompts, outputs):
            completion_tokens += output["meta_info"]["completion_tokens"]
            print(f"{output['text']}")
            print("======================")
        print(f"{idx=} tps = {completion_tokens / elapsed}\n\n")

if __name__ == "__main__":
    main()
(output screenshot)

@zhyncs
Member

zhyncs commented Jan 11, 2025

Hi @jjjjohnson Could you help resolve the conflicts? Thanks.

@jjjjohnson
Contributor Author

> Hi @jjjjohnson Could you help resolve the conflicts? Thanks.

Done

@merrymercy
Contributor

Could you share any performance results?

@merrymercy merrymercy mentioned this pull request Jan 15, 2025
@jjjjohnson
Contributor Author

jjjjohnson commented Jan 16, 2025

> Could you share any performance results?

Sure!
Since Lookahead speculative decoding caches both input and output tokens, I ran sglang.bench_serving for 2 turns and disabled random.shuffle(dataset) so that the requests are identical across both turns, then compared the performance against normal decoding.
Note: Lookahead speculative decoding is turned off when batch size > 4, so I limit the max-concurrency and request-rate.

(benchmark comparison screenshot)

Start Server:

Normal decode:

python -m sglang.launch_server --model-path /mnt/workspace/model_hub/Qwen2-7B-Instruct --trust-remote-code --tp 1

Lookahead speculative decode:

python -m sglang.launch_server --model-path /mnt/workspace/model_hub/Qwen2-7B-Instruct \
      --trust-remote-code --tp 1 --speculative-num-draft-tokens 4 --speculative-algorithm LOOKAHEAD --speculative-one-branch

Benchmark:

python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --dataset-path /oss/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 500 --max-concurrency 3 --request-rate 2

Result:

Normal decode, first run: (benchmark output screenshot)

Normal decode, second run: (benchmark output screenshot)

Lookahead speculative decode, first run: (benchmark output screenshot)

Lookahead speculative decode, second run: (benchmark output screenshot)

parser.add_argument(
"--speculative-lookahead-path",
type=str,
help="The path of the lookahead ",
Collaborator

A more detailed description is needed here. The current description is somewhat confusing as to what this parameter does.

Contributor Author

Done

@@ -220,6 +220,17 @@ def __init__(
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_lookahead():
Collaborator

Is it more appropriate to use a factory pattern to create different speculative workers? cc @merrymercy

Contributor Author

I just followed what EAGLE-2 does in the SpeculativeAlgorithm class...
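
For context, the pattern in question is roughly an enum with per-algorithm predicate helpers. A simplified, hypothetical sketch (not SGLang's actual SpeculativeAlgorithm class) would look like:

from enum import IntEnum, auto

class SpeculativeAlgorithmSketch(IntEnum):
    # Simplified stand-in for the pattern; the real class differs in details.
    NONE = auto()
    EAGLE = auto()
    LOOKAHEAD = auto()

    def is_eagle(self):
        return self == SpeculativeAlgorithmSketch.EAGLE

    def is_lookahead(self):
        return self == SpeculativeAlgorithmSketch.LOOKAHEAD

    @staticmethod
    def from_string(name):
        return SpeculativeAlgorithmSketch[name] if name else SpeculativeAlgorithmSketch.NONE

The scheduler can then branch on spec_algorithm.is_lookahead() to pick the corresponding worker, as the diff above does.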

@@ -666,6 +667,10 @@ def init_cuda_graphs(self):
tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
if self.spec_algorithm.is_lookahead():
# in case look_ahead failed to match any draft token, fallback to normal cuda graph decode
Collaborator

Why can’t the same cuda graph runner be reused here?

Contributor Author

@jjjjohnson jjjjohnson Jan 27, 2025


Because there are two cases that use a cuda graph:

  • Normal decode, where each request in the batch corresponds to 1 token per decode step;
  • Lookahead spec decode, where each request in the batch corresponds to more than 1 token per decode step.

These two cases cannot be unified, so I need a flag to differentiate them.
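
As an illustration of why a flag plus a second runner is needed, here is a hypothetical, self-contained sketch (not the PR's actual ModelRunner/CudaGraphRunner code): a lookahead graph is captured for (1 + num_draft_tokens) positions per request, and the plain 1-token graph is kept as a fallback for steps where no draft token matched.

class _NormalGraphRunner:
    # Stands in for a CUDA graph captured for 1 token per request per decode step.
    def replay(self, batch):
        return f"normal graph: {len(batch['tokens_per_req'])} reqs x 1 token"

class _LookaheadGraphRunner:
    # Stands in for a CUDA graph captured for (1 + num_draft_tokens) positions per request.
    def __init__(self, num_draft_tokens):
        self.num_draft_tokens = num_draft_tokens

    def replay(self, batch):
        return f"lookahead graph: {len(batch['tokens_per_req'])} reqs x {1 + self.num_draft_tokens} tokens"

class ModelRunnerSketch:
    def __init__(self, use_lookahead, num_draft_tokens=4):
        self.normal_runner = _NormalGraphRunner()
        # The normal runner is kept as a fallback for decode steps with no matched drafts.
        self.lookahead_runner = _LookaheadGraphRunner(num_draft_tokens) if use_lookahead else None

    def decode_step(self, batch):
        has_drafts = any(n > 1 for n in batch["tokens_per_req"])
        if self.lookahead_runner is not None and has_drafts:
            return self.lookahead_runner.replay(batch)
        return self.normal_runner.replay(batch)

runner = ModelRunnerSketch(use_lookahead=True)
print(runner.decode_step({"tokens_per_req": [5, 5]}))  # uses the lookahead graph
print(runner.decode_step({"tokens_per_req": [1, 1]}))  # falls back to the normal graph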

Collaborator

But when the lookahead algorithm is enabled, that runner will always be the one used. Why do we need to create two cuda graph runners? https://github.com/sgl-project/sglang/pull/2790/files#diff-65c6ac2c41977f68e460f18e35053b97089631f88a9958b0796343fccee78a67R719

Collaborator

Have you profiled what proportion of the time the lookahead_cache part consumes? I'm curious about the performance of these functions implemented in Python. (Of course this doesn't need to be solved before merging this PR; I'm just curious.)

Contributor Author

Generating draft tokens with lookahead_cache is quite fast, costing only about 0.001s for 8 tokens.
(profiling screenshot)
But updating lookahead_cache with the context tokens after prefill sometimes takes a long time, especially when the context is very long, probably because lookahead_cache.put triggers Python dict resizing, which is slow when the dict is very large.
(profiling screenshot)

Collaborator

Maybe you could change lookahead_cache.put into an async call so that it overlaps with model computation on the GPU, which may help this module perform better.
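
One way that suggestion could look in practice (a sketch, assuming the scheduler can tolerate the trie being updated slightly later; lookahead_cache.put follows the discussion above, the executor wiring is hypothetical):

from concurrent.futures import ThreadPoolExecutor

_cache_executor = ThreadPoolExecutor(max_workers=1)  # single worker keeps updates ordered
_pending_put = None

def schedule_cache_put(lookahead_cache, token_ids):
    # Run the pure-Python trie update off the critical path so it can overlap
    # with GPU work while the main thread sits inside GIL-releasing CUDA calls.
    global _pending_put
    _pending_put = _cache_executor.submit(lookahead_cache.put, token_ids)

def wait_cache_put():
    # Call before the next draft retrieval so the trie is in a consistent state.
    if _pending_put is not None:
        _pending_put.result()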

Labels: enhancement (New feature or request), high priority