Skip to content

Commit

Permalink
Temporary fix invalid sample results (#668)
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 authored Jul 20, 2024
1 parent e3046ea commit 8f4b155
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/sglang/srt/managers/controller/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,16 @@ def sample(self, logits: torch.Tensor):
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
probs, uniform_samples, self.top_ks, self.top_ps
)

# FIXME: This is a temporary fix for the illegal token ids in sampling.
illegal_mask = (
batch_next_token_ids < 0 or batch_next_token_ids >= probs.shape[-1]
)
if torch.any(illegal_mask):
warnings.warn("Illegal token ids in sampling.")
batch_next_token_ids = torch.where(
illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
)
except RuntimeError as e:
warnings.warn(f"Ignore errors in sampling: {e}")
batch_next_token_ids = torch.argmax(probs, dim=-1)
Expand Down

0 comments on commit 8f4b155

Please sign in to comment.