From 8f4b1559e796bd37cf43d6fa61a8fa7e191eb872 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Sat, 20 Jul 2024 00:51:05 -0700
Subject: [PATCH] Temporary fix invalid sample results (#668)

---
 python/sglang/srt/managers/controller/infer_batch.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 5eed985d3e1..3a909617073 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -673,6 +673,16 @@ def sample(self, logits: torch.Tensor):
             batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
                 probs, uniform_samples, self.top_ks, self.top_ps
             )
+
+            # FIXME: This is a temporary fix for the illegal token ids in sampling.
+            illegal_mask = torch.logical_or(
+                batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
+            )
+            if torch.any(illegal_mask):
+                warnings.warn("Illegal token ids in sampling.")
+                batch_next_token_ids = torch.where(
+                    illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
+                )
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
             batch_next_token_ids = torch.argmax(probs, dim=-1)
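
For reference, a minimal standalone sketch (not part of the patch) of the fallback this change applies: sampled token ids outside [0, vocab_size) are replaced with the argmax of the probability distribution. The tensors and values below are invented for illustration and do not come from SGLang.

import warnings

import torch

# Toy per-request probabilities (vocab size 3) and sampled ids; the second
# sampled id (7) is out of range and triggers the fallback.
probs = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
batch_next_token_ids = torch.tensor([1, 7])

illegal_mask = torch.logical_or(
    batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
)
if torch.any(illegal_mask):
    warnings.warn("Illegal token ids in sampling.")
    # Keep the sampled id where it is legal, otherwise fall back to the
    # most likely token for that request.
    batch_next_token_ids = torch.where(
        illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
    )

print(batch_next_token_ids)  # tensor([1, 0])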