From 2517225cdc6fb1cbf843a33311e649870c9c6b27 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Sat, 20 Jul 2024 02:05:11 -0700
Subject: [PATCH] Revert "Temporary fix invalid sample results (#668)"

This reverts commit 8f4b1559e796bd37cf43d6fa61a8fa7e191eb872.
---
 python/sglang/srt/managers/controller/infer_batch.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 3a909617073..5eed985d3e1 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -673,16 +673,6 @@ def sample(self, logits: torch.Tensor):
             batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
                 probs, uniform_samples, self.top_ks, self.top_ps
             )
-
-            # FIXME: This is a temporary fix for the illegal token ids in sampling.
-            illegal_mask = (
-                batch_next_token_ids < 0 or batch_next_token_ids >= probs.shape[-1]
-            )
-            if torch.any(illegal_mask):
-                warnings.warn("Illegal token ids in sampling.")
-                batch_next_token_ids = torch.where(
-                    illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
-                )
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
             batch_next_token_ids = torch.argmax(probs, dim=-1)
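
Note (reviewer sketch, not part of the patch): the guard being reverted
cannot run as written, because it combines two boolean tensors with
Python `or`, which raises "Boolean value of Tensor with more than one
element is ambiguous" for any batch larger than one. Below is a minimal
elementwise version of the same idea, assuming the tensor names and
shapes implied by the diff (`probs` is (batch, vocab), and
`batch_next_token_ids` is (batch,)). It is illustrative only, not the
fix adopted upstream; this patch simply removes the workaround.

    import warnings

    import torch

    # Stand-ins for the tensors in the sample() method shown in the diff:
    # a (batch, vocab) probability tensor and a (batch,) tensor of sampled
    # token ids, seeded here with deliberately out-of-range values.
    probs = torch.softmax(torch.randn(4, 8), dim=-1)
    batch_next_token_ids = torch.tensor([1, -1, 3, 9])

    # Elementwise `|` is the tensor-safe replacement for Python `or`.
    illegal_mask = (batch_next_token_ids < 0) | (
        batch_next_token_ids >= probs.shape[-1]
    )
    if torch.any(illegal_mask):
        warnings.warn("Illegal token ids in sampling.")
        # Fall back to the greedy (argmax) token wherever the sampled
        # id is out of range, keeping valid samples untouched.
        batch_next_token_ids = torch.where(
            illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
        )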