From daddc14f7dcc801859ece7b9ef3933168755d1c2 Mon Sep 17 00:00:00 2001 From: Will Lin Date: Fri, 13 Sep 2024 11:44:14 -0700 Subject: [PATCH 1/2] [bugfix] add multi-step advance_step to ROCmFlashAttentionMetadata --- vllm/attention/backends/rocm_flash_attn.py | 59 +++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b0f4d0530b7f0..24d93e77d187b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,10 +1,11 @@ """Attention layer ROCm GPUs.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import (CommonAttentionState, @@ -13,6 +14,9 @@ PagedAttentionMetadata) from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + logger = init_logger(__name__) @@ -175,6 +179,59 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: ) return self._cached_decode_metadata + def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + class ROCmFlashAttentionMetadataBuilder( CommonMetadataBuilder[ROCmFlashAttentionMetadata]): From 306f21f975b09a68687a6cde3009a76f2d91da52 Mon Sep 17 00:00:00 2001 From: Will Lin Date: Fri, 13 Sep 2024 14:27:18 -0700 Subject: [PATCH 2/2] add rocm to MULTI_STEP_ATTENTION_BACKENDS --- vllm/worker/multi_step_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index b900eb5a610ff..2c76775cd3231 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -29,7 +29,7 @@ logger = init_logger(__name__) -MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"] +MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"] def seq_output_builder():