Cache indices and offsets

DamianSzwichtenberg committed Jul 15, 2024
1 parent f0e4a83 commit c3e775a
Showing 4 changed files with 28 additions and 23 deletions.
12 changes: 9 additions & 3 deletions vllm/attention/backends/habana_attn.py
@@ -142,16 +142,22 @@ def forward(
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+            block_indices = prefill_meta.block_indices
+            block_offsets = prefill_meta.block_offsets
+        if decode_meta := attn_metadata.decode_metadata:
+            block_indices = decode_meta.block_indices
+            block_offsets = decode_meta.block_offsets
         if kv_cache is not None:
             key_cache, value_cache = HabanaPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            block_indices, block_offset = cache_ops.prepare_to_cache(key_cache, attn_metadata.slot_mapping)
-            key_cache = self.key_cache(key, key_cache, block_indices, block_offset)
-            value_cache = self.value_cache(value, value_cache, block_indices, block_offset)
+            key_cache = self.key_cache(key, key_cache, block_indices, block_offsets)
+            value_cache = self.value_cache(value, value_cache, block_indices, block_offsets)
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
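In the attention backend, the block indices and offsets are now read from whichever metadata object is populated for the current phase (prefill or decode) and passed straight to the cache-write modules, instead of being re-derived from slot_mapping inside the op. A minimal, self-contained sketch of that consumer side, where the toy metadata class, tensor shapes, and `write_kv` helper are illustrative stand-ins rather than the actual vLLM objects:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyMetadata:
    block_indices: Optional[torch.Tensor]
    block_offsets: Optional[torch.Tensor]


def write_kv(key, value, key_cache, value_cache, prefill_meta, decode_meta):
    # Pick the precomputed coordinates from whichever phase is active.
    meta = prefill_meta if prefill_meta is not None else decode_meta
    block_indices, block_offsets = meta.block_indices, meta.block_offsets
    # Scatter each token's K/V into its (block, offset) slot of the paged
    # cache, mirroring what insert_or_update_cache does with index_put_.
    key_cache.index_put_((block_indices, block_offsets), key)
    value_cache.index_put_((block_indices, block_offsets), value)
    return key_cache, value_cache


# Toy paged caches: [num_blocks, block_size, num_kv_heads, head_size]
key_cache = torch.zeros(4, 8, 2, 16)
value_cache = torch.zeros(4, 8, 2, 16)
meta = ToyMetadata(block_indices=torch.tensor([0, 0, 2]),
                   block_offsets=torch.tensor([0, 1, 5]))
key = torch.randn(3, 2, 16)    # 3 tokens
value = torch.randn(3, 2, 16)
write_kv(key, value, key_cache, value_cache, prefill_meta=meta, decode_meta=None)
```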
2 changes: 2 additions & 0 deletions vllm/attention/ops/habana_paged_attn.py
@@ -19,6 +19,8 @@ class HabanaPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
+    block_indices: Optional[torch.Tensor]
+    block_offsets: Optional[torch.Tensor]


class HabanaPagedAttention:
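The two new metadata fields record, per token, which cache block holds the token's KV entry and where inside that block it sits; they are just the quotient and remainder of the flat slot id by the block size. A quick worked example with a made-up block size:

```python
import torch

block_size = 4
slot_mapping = torch.tensor([0, 1, 2, 5, 9])

block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = torch.fmod(slot_mapping, block_size)

print(block_indices)  # tensor([0, 0, 0, 1, 2]) -> which cache block
print(block_offsets)  # tensor([0, 1, 2, 1, 1]) -> position within that block
```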
18 changes: 0 additions & 18 deletions vllm/hpu/cache_ops.py
@@ -10,24 +10,6 @@
 import habana_frameworks.torch as htorch
 
 
-def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, dtype, is_prompt=False):
-    block_size = key_cache.size(1)
-    slot_mapping = slot_mapping.flatten()
-    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-    offsets = torch.fmod(slot_mapping, block_size)
-    key_cache.index_put_((indices, offsets), key)
-    value_cache.index_put_((indices, offsets), value)
-
-
-def prepare_to_cache(cache, slot_mapping):
-    block_size = cache.size(1)
-    slot_mapping = slot_mapping.flatten()
-    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-    offsets = torch.fmod(slot_mapping, block_size)
-
-    return indices, offsets
-
-
 def insert_or_update_cache(input, cache, block_indices, block_offsets):
     cache.index_put_((block_indices, block_offsets), input)

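reshape_and_cache and prepare_to_cache are removed because they recomputed the block indices and offsets from slot_mapping on every call; with this commit that arithmetic is done once in the model runner and only the plain scatter (insert_or_update_cache) stays in the op layer. A small sanity-check sketch, assuming the same div/fmod convention, showing the old and new paths write the same cache cells:

```python
import torch


def old_style_write(key, key_cache, slot_mapping):
    # Roughly what prepare_to_cache + insert_or_update_cache did together:
    # re-derive the coordinates from slot_mapping inside the op.
    block_size = key_cache.size(1)
    slot_mapping = slot_mapping.flatten()
    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    offsets = torch.fmod(slot_mapping, block_size)
    key_cache.index_put_((indices, offsets), key)


def new_style_write(key, key_cache, block_indices, block_offsets):
    # Coordinates arrive precomputed from the model runner and are reused.
    key_cache.index_put_((block_indices, block_offsets), key)


slot_mapping = torch.tensor([3, 8, 13])
key = torch.randn(3, 2, 16)
cache_a = torch.zeros(4, 8, 2, 16)  # [num_blocks, block_size, heads, head_size]
cache_b = torch.zeros(4, 8, 2, 16)

old_style_write(key, cache_a, slot_mapping)
indices = torch.div(slot_mapping, cache_b.size(1), rounding_mode="floor")
offsets = torch.fmod(slot_mapping, cache_b.size(1))
new_style_write(key, cache_b, indices, offsets)

assert torch.equal(cache_a, cache_b)  # identical cells written either way
```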
19 changes: 17 additions & 2 deletions vllm/worker/habana_model_runner.py
@@ -200,6 +200,13 @@ def pad_list(l, k, v):
     return l + [v] * padding
 
 
+def precompute_indices_and_offsets(block_size, slot_mapping):
+    slot_mapping = slot_mapping.flatten()
+    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    offsets = torch.fmod(slot_mapping, block_size)
+    return indices, offsets
+
+
 class HpuModelAdapter():
     def __init__(self, model, block_size):
         self.model = model
@@ -595,10 +602,14 @@ def _prepare_prompt(
                                              dtype=torch.long,
                                              device=self.device)
 
+        block_indices, block_offsets = precompute_indices_and_offsets(self.block_size, slot_mapping)
+
         attn_metadata = self.attn_backend.make_metadata(
             block_list=None,
             block_mapping=None,
             block_usage=None,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
         )
@@ -698,10 +709,14 @@ def _prepare_decode(
                                              dtype=torch.long,
                                              device=self.device)
 
+        block_indices, block_offsets = precompute_indices_and_offsets(self.block_size, slot_mapping)
+
         attn_metadata = self.attn_backend.make_metadata(
             block_list=block_list,
             block_mapping=block_mapping,
             block_usage=block_usage,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=None,
         )
@@ -886,10 +901,10 @@ def _seq_len(self, attn_metadata):
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         prefill_metadata = subtuple(metadata.prefill_metadata,
                                     "TrimmedPrefillMetadata",
-                                    ['attn_bias', 'seq_lens_tensor'])
+                                    ['attn_bias', 'seq_lens_tensor', 'block_indices', 'block_offsets'])
         decode_metadata = subtuple(metadata.decode_metadata,
                                    "TrimmedDecodeMetadata",
-                                   ['attn_bias', 'block_list', 'block_mapping', 'block_usage'])
+                                   ['attn_bias', 'block_list', 'block_mapping', 'block_usage', 'block_indices', 'block_offsets'])
         return subtuple(metadata,
                         'TrimmedMetadata',
                         ['slot_mapping',
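trim_attn_metadata now carries block_indices and block_offsets through into the trimmed prefill and decode metadata, so the attention backend can read them after trimming. subtuple itself is not shown in this diff; a plausible sketch of such a helper, under the assumption that it simply repackages the listed attributes into a fresh namedtuple, would be:

```python
import collections


def subtuple(obj, typename, fields):
    # Hypothetical reconstruction, not the actual vLLM helper: copy only the
    # named attributes of `obj` into a newly created namedtuple type.
    if obj is None:
        return None
    cls = collections.namedtuple(typename, fields)
    return cls(**{field: getattr(obj, field) for field in fields})
```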
