diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py
index fc7f74c1f96fb..f60370146b7f2 100644
--- a/vllm/attention/backends/habana_attn.py
+++ b/vllm/attention/backends/habana_attn.py
@@ -142,6 +142,13 @@ def forward(
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+            block_indices = prefill_meta.block_indices
+            block_offsets = prefill_meta.block_offsets
+        if decode_meta := attn_metadata.decode_metadata:
+            block_indices = decode_meta.block_indices
+            block_offsets = decode_meta.block_offsets
         if kv_cache is not None:
             key_cache, value_cache = HabanaPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
@@ -149,9 +156,8 @@ def forward(
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            block_indices, block_offset = cache_ops.prepare_to_cache(key_cache, attn_metadata.slot_mapping)
-            key_cache = self.key_cache(key, key_cache, block_indices, block_offset)
-            value_cache = self.value_cache(value, value_cache, block_indices, block_offset)
+            key_cache = self.key_cache(key, key_cache, block_indices, block_offsets)
+            value_cache = self.value_cache(value, value_cache, block_indices, block_offsets)

         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py
index fa4dba165f3b3..051fcd8cca272 100644
--- a/vllm/attention/ops/habana_paged_attn.py
+++ b/vllm/attention/ops/habana_paged_attn.py
@@ -19,6 +19,8 @@ class HabanaPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
+    block_indices: Optional[torch.Tensor]
+    block_offsets: Optional[torch.Tensor]


 class HabanaPagedAttention:
diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py
index 4734333c99e49..40dcd8eae17bd 100644
--- a/vllm/hpu/cache_ops.py
+++ b/vllm/hpu/cache_ops.py
@@ -10,24 +10,6 @@
 import habana_frameworks.torch as htorch


-def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, dtype, is_prompt=False):
-    block_size = key_cache.size(1)
-    slot_mapping = slot_mapping.flatten()
-    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-    offsets = torch.fmod(slot_mapping, block_size)
-    key_cache.index_put_((indices, offsets), key)
-    value_cache.index_put_((indices, offsets), value)
-
-
-def prepare_to_cache(cache, slot_mapping):
-    block_size = cache.size(1)
-    slot_mapping = slot_mapping.flatten()
-    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-    offsets = torch.fmod(slot_mapping, block_size)
-
-    return indices, offsets
-
-
 def insert_or_update_cache(input, cache, block_indices, block_offsets):
     cache.index_put_((block_indices, block_offsets), input)
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 696736e96511b..8316fed55e7d3 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -200,6 +200,13 @@ def pad_list(l, k, v):
     return l + [v] * padding


+def precompute_indices_and_offsets(block_size, slot_mapping):
+    slot_mapping = slot_mapping.flatten()
+    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    offsets = torch.fmod(slot_mapping, block_size)
+    return indices, offsets
+
+
 class HpuModelAdapter():
     def __init__(self, model, block_size):
         self.model = model
@@ -595,10 +602,14 @@ def _prepare_prompt(
                                     dtype=torch.long,
                                     device=self.device)

+        block_indices, block_offsets = precompute_indices_and_offsets(self.block_size, slot_mapping)
+
         attn_metadata = self.attn_backend.make_metadata(
             block_list=None,
             block_mapping=None,
             block_usage=None,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
         )
@@ -698,10 +709,14 @@ def _prepare_decode(
                                     dtype=torch.long,
                                     device=self.device)

+        block_indices, block_offsets = precompute_indices_and_offsets(self.block_size, slot_mapping)
+
         attn_metadata = self.attn_backend.make_metadata(
             block_list=block_list,
             block_mapping=block_mapping,
             block_usage=block_usage,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=None,
         )
@@ -886,10 +901,10 @@ def _seq_len(self, attn_metadata):

     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         prefill_metadata = subtuple(metadata.prefill_metadata,
                                     "TrimmedPrefillMetadata",
-                                    ['attn_bias', 'seq_lens_tensor'])
+                                    ['attn_bias', 'seq_lens_tensor', 'block_indices', 'block_offsets'])
         decode_metadata = subtuple(metadata.decode_metadata,
                                    "TrimmedDecodeMetadata",
-                                   ['attn_bias', 'block_list', 'block_mapping', 'block_usage'])
+                                   ['attn_bias', 'block_list', 'block_mapping', 'block_usage', 'block_indices', 'block_offsets'])
         return subtuple(metadata, 'TrimmedMetadata',
                         ['slot_mapping',