Skip to content

Commit

Permalink
[SpecDec][Misc] Cleanup, remove bonus token logic. (vllm-project#8701)
Browse files: browse the repository at this point in the history
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
LiuXiaoxuanPKU authored and Alvant committed Oct 26, 2024
1 parent 2d0e736 commit d6e878e
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 115 deletions.
30 changes: 7 additions & 23 deletions tests/samplers/test_rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,13 @@ def mock_causal_accepted_tensor(
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
disable_bonus_tokens: bool, device: str,
use_flashinfer: bool):
device: str, use_flashinfer: bool):
"""Verify the output has correct format given predetermined accepted matrix.
"""
if use_flashinfer and disable_bonus_tokens:
pytest.skip("Flashinfer rejection sampler must enable bonus token.")

set_random_seed(seed)
torch.set_default_device(device)

Expand Down Expand Up @@ -88,9 +83,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size=(batch_size, 1),
dtype=torch.int64)

rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
Expand All @@ -100,10 +93,6 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
)

expected_bonus_token_ids = bonus_token_ids.clone()
# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
if disable_bonus_tokens:
expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1

if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
Expand Down Expand Up @@ -143,8 +132,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
Expand Down Expand Up @@ -177,8 +165,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int, device: str,
use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
Expand Down Expand Up @@ -251,8 +238,7 @@ def get_seeded_seqs():
}

for use_flashinfer in [True, False]:
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
Expand Down Expand Up @@ -282,8 +268,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
vocab_size = 30_000
torch.set_default_device(device)

rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer,
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device)

Expand Down Expand Up @@ -359,8 +344,7 @@ def test_rejection_sampling_approximates_target_distribution(
set_random_seed(seed)
helper = _CorrectnessTestHelper(
vocab_size=10,
rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer),
rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
)

draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
Expand Down
79 changes: 20 additions & 59 deletions tests/samplers/test_typical_acceptance_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,13 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
def get_acceptance_sampler(
posterior_threshold: float = 0.03,
posterior_alpha: float = 0.9,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
disable_bonus_tokens, strict_mode)
strict_mode)


@pytest.mark.parametrize("k", list(range(1, 6)))
Expand Down Expand Up @@ -154,29 +153,25 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
seed: int, disable_bonus_tokens: bool, device: str):
seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
entropy of the uniform target distribution being high should lead to all
draft tokens being accepted. The test also ensures that the behavior
regarding bonus tokens is consistent with the `disable_bonus_tokens`
flag.
draft tokens being accepted.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
Expand All @@ -200,21 +195,15 @@ def test_uniform_target_distribution_accepts_all_tokens(
# should lead to all draft tokens being accepted. Verify that.
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())

assert torch.all(output_token_ids[:, :k] == draft_token_ids)


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int,
disable_bonus_tokens: bool,
device: str):
def test_temperature_zero_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
Expand All @@ -232,8 +221,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size = 30_000
torch.set_default_device(device)

typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
Expand Down Expand Up @@ -267,11 +255,9 @@ def test_temperature_zero_target_distribution(seed: int,


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
device: str):
def test_mixed_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
Expand All @@ -285,16 +271,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
accepted.
- When `disable_bonus_tokens` is False, the bonus tokens are also accepted
for sequences with a uniform distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
Expand Down Expand Up @@ -328,21 +311,16 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
0]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
# if disable_bonus_tokens is false then we also accept the bonus tokens.
# we also accept the bonus tokens.
assert torch.all(
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
if disable_bonus_tokens:
assert torch.all(output_token_ids[[1, 3], -1] == -1)
else:
assert torch.all(output_token_ids[[1, 3], -1] != -1)
assert torch.all(output_token_ids[[1, 3], -1] != -1)


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
device: str):
def test_accept_tokens_partially(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
Expand All @@ -362,8 +340,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
Expand All @@ -384,10 +361,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
Expand All @@ -408,12 +382,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,


@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int,
disable_bonus_tokens: bool,
device: str):
def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
Expand All @@ -425,8 +396,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
Expand Down Expand Up @@ -457,10 +427,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True,
disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=0.0,
posterior_alpha=0.0)
strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
output_token_ids = typical_acceptance_sampler(
target_probs,
Expand All @@ -470,18 +437,13 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
device: str):
def test_replacement_token_ids(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
Expand All @@ -497,8 +459,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones(
Expand Down
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_medusa_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5

Expand Down
9 changes: 1 addition & 8 deletions vllm/model_executor/layers/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""

def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False,
use_flashinfer: Optional[bool] = None):
"""Create a rejection sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Required when bonus tokens would corrupt the KV cache for
proposal methods that require a KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
Expand All @@ -48,17 +44,14 @@ def __init__(self,
None, we will use the default value from the environment variable.
This parameter is only used for testing purposes.
"""
super().__init__(disable_bonus_tokens=disable_bonus_tokens,
strict_mode=strict_mode)
super().__init__(strict_mode=strict_mode)
if use_flashinfer is None:
self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
chain_speculative_sampling is not None)
else:
self.use_flashinfer = use_flashinfer

if self.use_flashinfer:
assert not disable_bonus_tokens, \
"flashinfer will enable bonus token by default"
logger.info("Use flashinfer for rejection sampling.")
else:
logger.info("Use pytorch for rejection sampling.")
Expand Down
15 changes: 1 addition & 14 deletions vllm/model_executor/layers/spec_decode_base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
step.
"""

def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
def __init__(self, strict_mode: bool = False):
"""Base class constructor.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Required when bonus tokens would corrupt the KV cache for
proposal methods that require a KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode

# NOTE: A "bonus token" is accepted iff all proposal tokens are
Expand Down Expand Up @@ -111,13 +105,6 @@ def _create_output(
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)

# We disable bonus tokens because they corrupt the KV cache for
# proposal methods that require a KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix:
# https://github.com/vllm-project/vllm/issues/4212
if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1

# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
substitute_token_ids.mul(after_false_mask))
Expand Down
Loading

0 comments on commit d6e878e

Please sign in to comment.