[Speculative decoding] [Performance]: Re-enable bonus tokens #4212

Closed
cadedaniel opened this issue Apr 19, 2024 · 12 comments · Fixed by #5765
Labels
performance Performance-related issues

Comments

@cadedaniel
Collaborator

cadedaniel commented Apr 19, 2024

Proposal to improve performance

In #3951 we disabled bonus tokens (the token sampled from the verifier model when all proposal tokens are accepted) because the draft model never generates KV for the bonus token. We can fix this by "prefilling" the KV of bonus tokens in the draft model. Note that for proposal methods that do not require KV (e.g. prompt lookup), we can re-enable bonus tokens and get a speedup there.

The impact of this performance improvement depends on the speculation length. For low K, e.g. 1, where the probability of accepting the single spec token is high (~= how aligned the draft model and target model are on the sequence), it has high impact because accepting 1 token allows us to emit 2 tokens (1 speculative and 1 bonus). With bonus tokens disabled, we can only emit 1 token (the accepted speculative one).

For higher K the impact is less as the likelihood of accepting all speculative tokens is exponentially lower.
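To make the dependence on K concrete, here is a rough sketch. It assumes every speculative token is accepted independently with the same probability p, which is a simplification and not something measured in this issue; the function name is made up for illustration. Under that assumption the bonus token is emitted only when all K tokens are accepted, so it adds p^K expected tokens per step.

```python
def bonus_token_gain(p: float, k: int) -> float:
    """Expected extra tokens per step from the bonus token.

    Assumes i.i.d. acceptance with probability p (illustrative model only).
    The bonus token is emitted only if all k speculative tokens are accepted.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if k < 1:
        raise ValueError("k must be >= 1")
    return p ** k


if __name__ == "__main__":
    # The gain shrinks geometrically as the speculation length grows.
    for k in (1, 2, 4, 8):
        print(f"K={k}: expected extra tokens per step = {bonus_token_gain(0.8, k):.3f}")
```

The value 0.8 is an arbitrary example acceptance probability, not a measurement.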

# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1
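
For illustration, the effect of the last line of that snippet can be mimicked with plain Python lists. This is only a sketch: it assumes each row holds the K speculative tokens followed by the bonus token, and that -1 marks a slot with no token; the helper name is made up.

```python
from typing import List

NO_TOKEN = -1  # assumed placeholder for "no token emitted in this slot"


def disable_bonus_tokens(output_with_bonus_tokens: List[List[int]]) -> List[List[int]]:
    """Return a copy where the last column (the bonus token) is masked out."""
    return [row[:-1] + [NO_TOKEN] for row in output_with_bonus_tokens]


# Two sequences, K=2 speculative tokens plus one bonus token each.
tokens = [[11, 12, 13], [21, 22, 23]]
masked = disable_bonus_tokens(tokens)
print(masked)
```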

@sroy745
Collaborator

sroy745 commented Jun 14, 2024

Hi @cadedaniel, is anyone working on this currently? If not, I would like to look into it. Please let me know.

@cadedaniel
Collaborator Author

cadedaniel commented Jun 14, 2024

that would be awesome. let's chat more in vllm slack.

@cadedaniel
Collaborator Author

cadedaniel commented Jun 14, 2024

Discussed with @sroy745 and @LiuXiaoxuanPKU on best approach to fix this

Goal: decide whether to use batch expansion or chunked prefill
Decision criterion: which takes more time when there are bonus tokens

How fast can we do batch expansion?
- We can hit 0.5ms or less, because the batch expansion is small (only 1 per sequence, in the worst case)

How fast can we do the chunked prefill fwd pass?
- Expected to be slower because no CUDA graphs
- Expected to be slower because attention is slower in triton (Lily says it's flash, so not a concern)

How to measure "prefill" time in chunked prefill?
- Measure fwd pass time for small batch sizes, with varying bonus tokens (1..BS), using the chunked prefill kernel
	this is the best measurement because it's exactly what we will run
	but it takes time to set up properly

	this measures prefill computation + lack of cuda graphs

- Measure fwd pass time of decode batch, no cuda graph (no chunked prefill). Small batch sizes.
	this ignores any overhead in the fwd pass of chunked prefill
	BUT it captures 80+% of the overhead, which we intuit is due to lack of cuda graphs
	
	this measures lack of cuda graphs

Worker.execute_model
	- model_runner.execute_model
		# - prepare_inputs
		- fwd pass
		# - sampling
	- return

--- 68m model with TP1. Measure fwd pass time with and without cuda graph.
JackFram/llama-68m

https://pytorch.org/docs/stable/generated/torch.cuda.Event.html

torch.cuda.Event(enable_timing=False, blocking=False, interprocess=False)


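import torch

def execute_model(model, sampler, inputs):  # sketch of model_runner.execute_model
	...
	# CUDA events measure GPU time without blocking the host until we sync.
	start_event = torch.cuda.Event(enable_timing=True)
	end_event = torch.cuda.Event(enable_timing=True)

	start_event.record()  # records on the current stream; must be the model's stream
	outputs = model.fwd_pass(inputs)  # placeholder for the actual forward call
	end_event.record()  # same stream as the model

	end_event.synchronize()  # wait for the forward pass before reading the timer
	elapsed_ms = start_event.elapsed_time(end_event)
	print(f"fwd pass: {elapsed_ms:.3f} ms")

	sampled = sampler.sample(outputs)
	return sampled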

@sroy745
Collaborator

sroy745 commented Jun 16, 2024

Hi @cadedaniel and @LiuXiaoxuanPKU

Here is a pr that I used for doing some measurements.

I ran the tests with JackFram/llama-68m with TP 1 on an A100. Without CUDA graphs the decode time is ~0.89 ms and ~0.87 ms at batch sizes 5 and 10, respectively. This is greater than the 0.5 ms expected for batch expansion. Given these numbers, should we go with batch expansion?

For batch size 5
Without Cuda Graph
prefill time - 1.16 ms
decode time - 0.89 ms
With Cuda Graph
prefill time - 1.04 ms
decode time - 0.23 ms

For batch size 10
Without Cuda Graph
prefill time - 1.1 ms
decode time - 0.87 ms
With Cuda Graph
prefill time - 1.0 ms
decode time - 0.22 ms
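
The comparison behind this question can be written out as arithmetic. The sketch below uses only numbers from this thread: the 0.5 ms upper estimate for batch expansion and the measured decode times, with decode without CUDA graphs standing in for the chunked prefill cost as suggested in the earlier notes. Variable and function names are made up for illustration.

```python
BATCH_EXPANSION_ESTIMATE_MS = 0.5  # upper estimate quoted earlier in the thread

# Measured decode times in ms, keyed by batch size (from the comment above).
DECODE_MS = {
    5: {"no_cuda_graph": 0.89, "cuda_graph": 0.23},
    10: {"no_cuda_graph": 0.87, "cuda_graph": 0.22},
}


def cuda_graph_overhead_ms(batch_size: int) -> float:
    """Extra decode time paid when CUDA graphs are not used."""
    times = DECODE_MS[batch_size]
    return times["no_cuda_graph"] - times["cuda_graph"]


def batch_expansion_is_faster(batch_size: int) -> bool:
    """Compare the batch expansion estimate with the no-CUDA-graph decode time."""
    return BATCH_EXPANSION_ESTIMATE_MS < DECODE_MS[batch_size]["no_cuda_graph"]


for bs in DECODE_MS:
    print(bs, round(cuda_graph_overhead_ms(bs), 2), batch_expansion_is_faster(bs))
```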

@cadedaniel
Collaborator Author

Sounds good to me!

@sroy745
Collaborator

sroy745 commented Jun 27, 2024

This PR implements the logic for enabling bonus tokens. For this feature, the SpecDecodeWorker maintains state across multiple forward passes of a sequence to determine whether it was assigned a bonus token. If so, it backfills the KV cache for the penultimate token in the next forward pass. This state-tracking logic is implemented in the SpecDecodeWorker.

In the current implementation, the SpecDecodeWorker maintains a list of the sequence_ids that were assigned bonus tokens in their last forward pass. If the sequence is not assigned a bonus token in its current pass, it is removed from the list if it was there. However, if the generation is terminated for a sequence that was part of this list, it is never removed. Hence, over time, we will accumulate sequence_ids in this list which are no longer active. Therefore, we need a way to remove such sequence_ids from this list.
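
The leak described above can be reproduced with a small stand-in. This is a sketch only: BonusTokenTracker and its method are hypothetical names, not the SpecDecodeWorker code, and a set is used in place of the list for simplicity.

```python
from typing import Dict, Set


class BonusTokenTracker:
    """Hypothetical stand-in for the bonus-token state kept by the worker."""

    def __init__(self) -> None:
        self.seq_ids_with_bonus_token: Set[int] = set()

    def update(self, bonus_assigned: Dict[int, bool]) -> None:
        """Record, for each sequence in this pass, whether it got a bonus token."""
        for seq_id, got_bonus in bonus_assigned.items():
            if got_bonus:
                self.seq_ids_with_bonus_token.add(seq_id)
            else:
                self.seq_ids_with_bonus_token.discard(seq_id)


tracker = BonusTokenTracker()
tracker.update({1: True, 2: True})  # both sequences received a bonus token
tracker.update({1: False})          # sequence 2 finished and never shows up again
print(tracker.seq_ids_with_bonus_token)  # sequence 2 is still tracked
```

Because a finished sequence never appears in a later pass, nothing ever triggers its removal.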

One way to implement this would be the following:

  1. Add a new method to the ExecutorBase and WorkerBase that can be invoked to communicate to the Executor and through the Executor to the Worker about the termination of a sequence.
  2. Pass a reference to the ExecutorBase to the SequenceGroupOutputProcessor. In the SequenceGroupOutputProcessor, whenever the sequence terminates, it will invoke the method in ExecutorBase to inform about sequence generation termination.

class ExecutorBase:
    ...

    def process_terminated_sequences(self, sequence_ids: List[int]):
        """
        Pass a list of sequence_ids for which generation has been stopped
        for processing by the Executor.
        """
        return self.driver_worker.process_terminated_sequences(sequence_ids)

    ...


class WorkerBase:
    ...

    def process_terminated_sequences(self, seq_ids: List[int]):
        """
        Pass a list of sequence_ids for which generation has been stopped
        for processing by the Worker.
        """
        ...


class SequenceGroupOutputProcessor(ABC):

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: "StopChecker",
        executor: ExecutorBase,  # new: reference to the executor
    ):
        ...

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        ...
        # Invoke executor.process_terminated_sequences for seq_ids
        # whose generation has been stopped.
        ...
@cadedaniel can you please take a look at this proposal and let me know if this would work?

@cadedaniel
Collaborator Author

cadedaniel commented Jun 27, 2024

The proposal looks good. To simplify the interaction between scheduler and worker, we should embed the finished seq ids in the ExecuteModelRequest. This is better than adding a new method because in the future the worker procs could run forever in a loop; it is also better than coupling the OutputProcessor with the worker as the OutputProcessors will live in their own process in the near future.

by the way, the folks implementing Jamba support ran into the exact same issue. See the changes to ExecuteModelRequest in this PR https://github.com/vllm-project/vllm/pull/4115/files.
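
A minimal sketch of the suggestion above (carrying the finished sequence ids on the request itself), assuming hypothetical names: ExecuteModelRequestSketch, finished_seq_ids, and SpecDecodeWorkerSketch are illustrative and are not the actual vLLM classes or fields.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class ExecuteModelRequestSketch:
    """Hypothetical request; the real ExecuteModelRequest has other fields."""
    seq_ids: List[int]
    # Ids of sequences that finished since the previous step (assumed field).
    finished_seq_ids: List[int] = field(default_factory=list)


class SpecDecodeWorkerSketch:
    """Hypothetical worker that tracks sequences holding a bonus token."""

    def __init__(self) -> None:
        self.seq_ids_with_bonus_token: Set[int] = set()

    def execute_model(self, request: ExecuteModelRequestSketch) -> None:
        # Drop state for finished sequences before doing any other work.
        self.seq_ids_with_bonus_token.difference_update(request.finished_seq_ids)
        # ... proposal, scoring and verification would happen here ...


worker = SpecDecodeWorkerSketch()
worker.seq_ids_with_bonus_token.update({1, 2})
worker.execute_model(ExecuteModelRequestSketch(seq_ids=[1], finished_seq_ids=[2]))
print(worker.seq_ids_with_bonus_token)
```

With this shape the worker needs no extra method and no reference from the output processor; the cleanup rides along with the request it already receives.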

@sroy745
Collaborator

sroy745 commented Jun 28, 2024

Thanks for the pointer. Since that PR addresses the same problem, I will wait for it to be merged.

@llsj14
Contributor

llsj14 commented Oct 25, 2024

I hope that either the _disable_bonus_tokens or use_flashinfer option will be set to True by default.
Since they are not enabled by default, we experienced a sharp drop in acceptance rates when K(num_speculative_tokens) is not 1, and it was difficult to identify the reason.


I found that the recent version does not have issues with these options. However, the PRs (#5765, #7244, #8701) still haven't fully resolved the problem with bonus tokens. We continued to experience a sharp drop in acceptance rates when K is not 1, regardless of whether flashinfer was used. This drop does not occur if bonus tokens are disabled. We tested versions v0.5.4 and v0.6.2.

(cc. @jeongin601)

@sroy745
Collaborator

sroy745 commented Oct 25, 2024

My understanding is that the bonus_token logic is enabled by default. Can you point me to where you see it's not being set?

@llsj14
Contributor

llsj14 commented Oct 26, 2024

@sroy745
Yes, in our second check, we found that the default options are not an issue. (It took some time to review the history related to this issue.)

The problem is that when we don't use _disable_bonus_tokens, the acceptance rate drops when K (num_speculative_tokens) is not 1.
For example, when K=1, the acceptance rate is 80%, but when K=4, it drops to 65%.
This didn't happen when we forced the _disable_bonus_tokens option, so we're checking whether there might still be a bug in bonus token handling.

For reference, we tested versions v0.5.4 and v0.6.2, with and without the use_flashinfer option enabled.

@llsj14
Contributor

llsj14 commented Oct 27, 2024

I made a PR related to this issue
#9730

3 participants