[V1][WIP] Hybrid allocator for full attention & sliding window attention interleaved models (Reference PR, do not merge) #11938

Draft · wants to merge 15 commits into base: main
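For orientation before the diff: the point of a hybrid allocator is that a sliding-window layer only ever reads the KV entries of the most recent `sliding_window` tokens, so its cache blocks can be accounted for separately from those of full-attention layers instead of every layer paying the full-context cost. The sketch below is illustrative arithmetic only, not code from this PR; every name in it is an assumption.

```python
# Illustrative sketch (not part of this PR): per-layer KV-cache block
# accounting for a model that interleaves full-attention and sliding-window
# layers, e.g. Gemma 2 with a 4k window.
import math
from typing import Optional


def blocks_needed(num_tokens: int, block_size: int,
                  sliding_window: Optional[int]) -> int:
    """Number of KV-cache blocks one layer needs for `num_tokens` of context."""
    if sliding_window is None:
        # Full-attention layer: every token's KV entry must stay resident.
        return math.ceil(num_tokens / block_size)
    # Sliding-window layer: only the last `sliding_window` tokens are read,
    # so a hybrid allocator can drop (or never allocate) the older blocks.
    return math.ceil(min(num_tokens, sliding_window) / block_size)


# A 10k-token prompt with 16-token blocks and a 4k sliding window:
print(blocks_needed(10_000, 16, None))   # 625 blocks for a full-attention layer
print(blocks_needed(10_000, 16, 4096))   # 256 blocks for a sliding-window layer
```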
64 changes: 64 additions & 0 deletions examples/offline_sliding.py
@@ -0,0 +1,64 @@
import random
from typing import List
from vllm import LLM, SamplingParams


def prep_prompts(batch_size: int):
"""
Generate prompts with a bunch of assignments, then ask for the value of
one of them. The prompt is just under 10k tokens; the sliding window is 4k,
so the answer falls outside the sliding window but should still be recalled
correctly.
"""
prompts: List[str] = []
answer: List[int] = []
indices: List[int] = []
random.seed(1)
for _ in range(batch_size):
idx = random.randint(30, 90)
indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n"
ln = random.randint(600, 800)
for k in range(30, ln):
v = random.randint(10, 99)
if k == idx:
answer.append(v)
prompt += f"x{k} = {v}\n"
prompt += f"# Now, we check the value of x{idx}:\n"
prompt += f"assert x{idx} == "
prompts.append(prompt)
return prompts, answer, indices


def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
answer2 = [int(text[0:2].strip()) for text in outputs]
print(list(zip(indices, zip(answer, answer2))))
numok = 0
for a1, a2 in zip(answer, answer2):
if a1 == a2:
numok += 1
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok > 0.7


# Sample prompts.
prompts, answer, indices = prep_prompts(1)

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)

# Create an LLM.
llm = LLM(model="google/gemma-2-9b-it", enforce_eager=True)
# llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
check_answers(indices, answer,
[response.outputs[0].text for response in outputs])
28 changes: 27 additions & 1 deletion tests/test_utils.py
@@ -8,7 +8,7 @@

from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)
register_kv_cache, supports_kw)

from .utils import error_on_warning, fork_new_process_for_each_test

@@ -306,3 +306,29 @@ def test_memory_profiling():
del weights
lib.cudaFree(handle1)
lib.cudaFree(handle2)


def test_register_gpu_kv_cache():
from vllm.attention import Attention
from vllm.config import LayerForwardContext

# Example: attention layer names from Jamba with pipeline parallelism (PP=2)
ctx = {
'model.layers.20.attn':
LayerForwardContext(
attn_module=Attention(32, 128, 0.1),
kv_cache=None,
),
'model.layers.28.attn':
LayerForwardContext(
attn_module=Attention(32, 128, 0.1),
kv_cache=None,
)
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
]
register_kv_cache(ctx, kv_cache)
assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0]
assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1]
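The `register_kv_cache` helper itself is not shown in this hunk; judging from the assertions above, it binds the i-th allocated cache tensor to the i-th layer context, replacing the placeholder installed at `Attention.__init__` time. A minimal sketch consistent with the test (the real `vllm.utils.register_kv_cache` may differ) would be:

```python
from typing import Dict, List

import torch

from vllm.config import LayerForwardContext


def register_kv_cache_sketch(ctx: Dict[str, LayerForwardContext],
                             kv_caches: List[torch.Tensor]) -> None:
    # Attach the i-th KV-cache tensor to the i-th attention layer, replacing
    # the placeholder tensor that Attention.__init__ registered for the layer.
    for layer_ctx, kv_cache in zip(ctx.values(), kv_caches):
        layer_ctx.kv_cache = kv_cache
```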
79 changes: 79 additions & 0 deletions tests/v1/e2e/test_correctness_sliding_window.py
@@ -0,0 +1,79 @@
import random
from typing import List
from vllm import LLM, SamplingParams
import pytest


@pytest.mark.parametrize("model", ["bigcode/starcoder2-3b"])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(monkeypatch, model, batch_size, seed):
"""
The test builds prompts with a long run of assignments ("x1 = 10\nx2 = 33\n...")
and then asks for the value of one of them, which lies outside the sliding
window. Because the prompt states upfront which variable will be asked about,
the model answers correctly (mostly).
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

llm = LLM(model=model)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)

prompts, answer, indices = prep_prompts(batch_size)

responses = llm.generate(prompts, sampling_params)
check_answers(indices, answer,
[response.outputs[0].text for response in responses])


def prep_prompts(batch_size: int):
"""
Generate prompts with a bunch of assignments, then ask for the value of
one of them. The prompt is just under 10k tokens; the sliding window is 4k,
so the answer falls outside the sliding window but should still be recalled
correctly.
"""
prompts: List[str] = []
answer: List[int] = []
indices: List[int] = []
random.seed(1)
for _ in range(batch_size):
idx = random.randint(30, 90)
indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n"
ln = random.randint(800, 1100)
for k in range(30, ln):
v = random.randint(10, 99)
if k == idx:
answer.append(v)
prompt += f"x{k} = {v}\n"
prompt += f"# Now, we check the value of x{idx}:\n"
prompt += f"assert x{idx} == "
prompts.append(prompt)
return prompts, answer, indices


def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
answer2 = [int(text[0:2].strip()) for text in outputs]
print(list(zip(indices, zip(answer, answer2))))
numok = 0
for a1, a2 in zip(answer, answer2):
if a1 == a2:
numok += 1
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok > 0.7


def check_window(prompts: List[str]):

def inner(llm: LLM):
sliding_window = llm.llm_engine.model_config.get_sliding_window()
assert sliding_window and sliding_window > 0
assert any(
len(llm.get_tokenizer().tokenize(prompt)) > sliding_window
for prompt in prompts)

return inner
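Note that `check_window` is defined but never called in the test above; it returns a closure over the prompts. If it were wired in, the intended use is presumably along these lines (hypothetical placement):

```python
# Hypothetical: after building the LLM inside test_sliding_window_retrival,
# verify the model really has a sliding window and that at least one prompt
# is longer than it.
check_window(prompts)(llm)
```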
1 change: 1 addition & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -300,6 +300,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
layer_name: str = "",
) -> None:
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
42 changes: 24 additions & 18 deletions vllm/attention/layer.py
@@ -7,7 +7,8 @@

from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config import (CacheConfig, LayerForwardContext,
get_current_vllm_config)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -100,7 +101,9 @@ def __init__(
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype

# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
@@ -117,7 +120,10 @@
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# Use a placeholder kv cache tensor during init; it is replaced with the
# real tensor once the kv cache has been allocated.
compilation_config.static_forward_context[
prefix] = LayerForwardContext(self, torch.tensor([]))
self.layer_name = prefix

def forward(
@@ -152,13 +158,11 @@ def forward(
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, kv_cache, attn_type,
self.layer_name)
query, key, value, output, attn_type, self.layer_name)
return output.view(-1, hidden_size)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, attn_type,
self.layer_name)
attn_type, self.layer_name)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
@@ -235,17 +239,19 @@ def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
ctx = forward_context.layers[layer_name]
self = ctx.attn_module
return self.impl.forward(query,
key,
value,
kv_cache,
ctx.kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
@@ -256,7 +262,6 @@ def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
@@ -266,7 +271,7 @@
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
@@ -277,17 +282,19 @@ def unified_attention_with_output(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
ctx = forward_context.layers[layer_name]
self = ctx.attn_module
self.impl.forward(query,
key,
value,
kv_cache,
ctx.kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
@@ -300,7 +307,6 @@ def unified_attention_with_output_fake(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> None:
@@ -310,7 +316,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
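A note on the signature changes in this file: each op now resolves its KV cache through the per-layer `LayerForwardContext` at call time, so the tensor no longer appears as an op argument and therefore drops out of `mutates_args` (it is still written to, just not through an op input). One thing this per-layer binding could allow, shown below with invented shapes as a rough sketch rather than the PR's actual cache layout, is giving interleaved layers caches of different sizes:

```python
# Illustrative only: per-layer contexts let a sliding-window layer keep a
# smaller resident cache than a neighbouring full-attention layer.
# The tensor shapes here are invented for the example.
import torch

from vllm.config import LayerForwardContext

layers = {
    # Full-attention layer: all blocks of a ~10k-token prompt stay resident.
    "model.layers.0.attn": LayerForwardContext(
        attn_module=None, kv_cache=torch.zeros(2, 625, 16, 8, 128)),
    # Sliding-window layer: only the last 4k tokens' worth of blocks.
    "model.layers.1.attn": LayerForwardContext(
        attn_module=None, kv_cache=torch.zeros(2, 256, 16, 8, 128)),
}
```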
14 changes: 10 additions & 4 deletions vllm/config.py
@@ -992,7 +992,7 @@ def _verify_prefix_caching(self) -> None:
if not self.enable_prefix_caching:
return

if self.sliding_window is not None:
if not envs.VLLM_USE_V1 and self.sliding_window is not None:
raise NotImplementedError(
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching.")
@@ -2564,6 +2564,12 @@ class CompilationLevel:
PIECEWISE = 3


@dataclass
class LayerForwardContext:
attn_module: Any # vllm.attention.layer.Attention
kv_cache: Any # torch.Tensor


class CompilationConfig(BaseModel):
"""
Configuration for compilation.
@@ -2717,9 +2723,9 @@ def model_post_init(self, __context: Any) -> None:
inductor_hash_cache: Any = PrivateAttr

# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr
# Map from layer name to the layer's forward context, which stores the
# attention module and its kv_cache tensor
static_forward_context: Dict[str, LayerForwardContext] = PrivateAttr

def compute_hash(self) -> str:
"""