Support HQT on VLLM #59

Closed
wants to merge 13 commits into from
1 change: 1 addition & 0 deletions vllm/attention/backends/habana_attn.py
@@ -252,6 +252,7 @@ def forward(
self.kv_matmul,
self.key_cache.fetch_from_cache,
self.value_cache.fetch_from_cache,
self.key_cache.permute_cache,
)

# Reshape the output tensor.
2 changes: 2 additions & 0 deletions vllm/attention/ops/habana_paged_attn.py
@@ -90,6 +90,7 @@ def forward_decode(
kv_op=torch.matmul,
keys_fetch=ops.fetch_from_cache,
values_fetch=ops.fetch_from_cache,
keys_permute=ops.permute_cache,
) -> torch.Tensor:
block_size = value_cache.shape[1]
return ops.paged_attention_v1(
@@ -108,6 +109,7 @@ def forward_decode(
kv_op,
keys_fetch,
values_fetch,
keys_permute,
)

@staticmethod
12 changes: 10 additions & 2 deletions vllm/hpu/ops.py
@@ -37,8 +37,13 @@ def fetch_from_cache(cache, blocks, permutations):
return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]


def permute_cache(cache, permutations):
return [v.permute(permutations) for v in cache]


def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None,


Review comment: why is @hpu_utils.with_mark_steps removed here?


Reply: We want all of the conversions to/from hf8 to end up in the same graph, so we removed this mark step and added one outside the transformer block, in line with the OHF version.
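
For illustration, a minimal sketch of what a single graph break per transformer block could look like, assuming habana_frameworks.torch is available; run_decoder_layer, decoder_layer, and hidden_states are hypothetical names, not identifiers from this PR.

import habana_frameworks.torch as htorch

def run_decoder_layer(decoder_layer, hidden_states, *args, **kwargs):
    # Execute the whole block first, then insert a single graph break so that
    # any hf8 casts inside the block stay within the same HPU graph.
    hidden_states = decoder_layer(hidden_states, *args, **kwargs)
    htorch.core.mark_step()
    return hidden_states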

qk_matmul_op=torch.matmul, softmax_op=torch.softmax, kv_matmul_op=torch.matmul, keys_fetch_func=fetch_from_cache, values_fetch_func=fetch_from_cache) -> None:
qk_matmul_op=torch.matmul, softmax_op=torch.softmax, kv_matmul_op=torch.matmul, keys_fetch_func=fetch_from_cache, values_fetch_func=fetch_from_cache,
keys_permute=permute_cache) -> None:
seq_len = block_tables.size(1)
batch_size, query_heads, _ = query.shape
_, _, kv_heads, _ = key_cache.shape
@@ -50,11 +55,14 @@ def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block
.view(batch_size, 1, 1, -1))
query.mul_(scale)
query = query.unsqueeze(-2)
keys = keys_fetch_func(key_cache, block_tables, (0, 2, 3, 1))
keys = keys_fetch_func(key_cache, block_tables, (0, 2, 1, 3))
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
keys = keys_permute(keys, (0, 1, 2, 4, 3))
mask = mask.unsqueeze(2)
else:
keys = keys_permute(keys, (0, 1, 3, 2))
attn_weights = [qk_matmul_op(query, k) for k in keys]
attn_weights = softmax_op(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf),
dim=-1)
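To make the layout change concrete, here is a small shape walkthrough (a sketch with made-up sizes, not code from this PR) that replays fetch_from_cache with the new (0, 2, 1, 3) permutation and then permute_cache on the equal-heads path:

import torch

# Hypothetical sizes for illustration only.
num_blocks, block_size, kv_heads, head_size = 8, 16, 4, 64
batch, blocks_per_seq = 2, 3
key_cache = torch.randn(num_blocks, block_size, kv_heads, head_size)
block_tables = torch.randint(0, num_blocks, (batch, blocks_per_seq))

# fetch_from_cache with the new permutation: (batch, kv_heads, block_size, head_size)
keys = [key_cache.index_select(0, block_tables[:, i]).permute((0, 2, 1, 3))
        for i in range(block_tables.size(1))]
assert keys[0].shape == (batch, kv_heads, block_size, head_size)

# permute_cache then transposes the last two dims just before the QK matmul:
# (batch, kv_heads, head_size, block_size)
keys = [k.permute((0, 1, 3, 2)) for k in keys]
assert keys[0].shape == (batch, kv_heads, head_size, block_size)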
3 changes: 3 additions & 0 deletions vllm/hpu/utils.py
@@ -127,3 +127,6 @@ def forward(self, input, cache, block_indices, block_offset):

def fetch_from_cache(self, cache, blocks, permutations):
return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]

def permute_cache(self, cache, permutations):
return [v.permute(permutations) for v in cache]
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/layernorm.py
@@ -61,8 +61,8 @@ def forward(
orig_shape = x.shape
residual += x.view(residual.shape)
# Note: FusedRMSNorm requires 3D tensors as inputs
x = FusedRMSNorm.apply(residual.float(), self.weight.float(), self.variance_epsilon)
return x.to(orig_dtype).view(orig_shape), residual
x = FusedRMSNorm.apply(residual, self.weight, self.variance_epsilon)
return x.view(orig_shape), residual
ops.fused_add_rms_norm(
x,
residual,
@@ -72,8 +72,8 @@ def forward(
return x, residual
if x.device.type == "hpu" and FusedRMSNorm:
orig_dtype = x.dtype


Review comment: orig_dtype not used, here and also in line 60

x = FusedRMSNorm.apply(x.float(), self.weight.float(), self.variance_epsilon)
return x.to(orig_dtype)
x = FusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
return x
out = torch.empty_like(x)
ops.rms_norm(
out,
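
For reference, the plain RMSNorm math behind the calls above, written as a sketch that stays in the input dtype like the updated HPU path (the fused kernel may still manage precision internally; this is not the FusedRMSNorm implementation):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Reference RMSNorm: no .float() upcast and no .to(orig_dtype) downcast,
    # mirroring the change in the HPU branch above.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight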