From d3257b21c240e63a7075e8f9abe39a77cba4c3cc Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Wed, 30 Oct 2024 16:58:19 +0100
Subject: [PATCH] Fix performance of top_p and top_k calculations (#449)

This change fixes a performance issue I introduced in PR #414: because
`torch.where` evaluates both of its argument tensors eagerly, both
sampling functions were always called. Now only the selected one runs.
---
 vllm/model_executor/layers/sampler.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 74c0416e4b379..1b6bc2b1848c1 100755
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -267,12 +267,13 @@ def forward(
 
         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
             # If we have a scalar p and k, we can use the optimized version.
-            logits = torch.where(
-                self._scalar_p_and_k,
-                self._apply_top_k_top_p_opt(logits, self._top_p_scalar,
-                                            self._top_k_scalar),
-                _apply_top_k_top_p(logits, sampling_tensors.top_ps,
-                                   sampling_tensors.top_ks))
+            if self._scalar_p_and_k.any():
+                logits = self._apply_top_k_top_p_opt(logits,
+                                                     self._top_p_scalar.item(),
+                                                     self._top_k_scalar.item())
+            else:
+                logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
+                                            sampling_tensors.top_ks)
 
         if do_min_p:
             logits = _apply_min_p(logits, sampling_tensors.min_ps)
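
Note: a minimal standalone sketch of the pitfall this patch fixes. The
helpers `cheap_path` and `expensive_path` are hypothetical stand-ins for
`_apply_top_k_top_p_opt` and `_apply_top_k_top_p`; this is not vLLM code,
just an illustration of `torch.where` evaluating both branches:

    import torch

    def cheap_path(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the optimized scalar-p/k path.
        print("cheap path evaluated")
        return x * 2

    def expensive_path(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the general per-request top-p/top-k path.
        print("expensive path evaluated")
        return x.sort(dim=-1, descending=True).values

    x = torch.randn(4, 8)
    cond = torch.tensor(True)

    # torch.where materializes BOTH argument tensors before selecting,
    # so both prints fire even though only one result is kept.
    y = torch.where(cond, cheap_path(x), expensive_path(x))

    # The patched form branches first, so only the selected path runs
    # (at the cost of a host-side sync when reading the boolean tensor).
    if cond.any():
        y = cheap_path(x)
    else:
        y = expensive_path(x)

The trade-off in the branching form is that reading the boolean on the
host forces a device synchronization, which is worthwhile here because
it avoids computing a full unused sampling pass on every step.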