Commit c7d0b7a

Address review

WoosukKwon committed Nov 6, 2024
1 parent b593aa6
Showing 2 changed files with 9 additions and 8 deletions.

12 changes: 6 additions & 6 deletions in vllm/v1/attention/backends/flash_attn.py

@@ -157,7 +157,7 @@ def forward(
 
 
 def unified_flash_attention(
-    out: torch.Tensor,
+    output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -202,7 +202,7 @@ def unified_flash_attention(
         v_scale,
     )
 
-    output = flash_attn_varlen_func(
+    attn_output = flash_attn_varlen_func(
         q=query[:num_actual_tokens],
         k=key_cache,
         v=value_cache,
@@ -217,13 +217,13 @@ def unified_flash_attention(
         block_table=attn_metadata.block_table,
         softcap=logits_soft_cap,
     )
-    output = output.view(num_actual_tokens, -1)
+    attn_output = attn_output.view(num_actual_tokens, -1)
     # TODO(woosuk): Optimize this.
-    out[:num_actual_tokens].copy_(output, non_blocking=True)
+    output[:num_actual_tokens].copy_(attn_output)
 
 
 def unified_flash_attention_fake(
-    out: torch.Tensor,
+    output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -245,6 +245,6 @@ def unified_flash_attention_fake(
 direct_register_custom_op(
     op_name="unified_flash_attention",
     op_func=unified_flash_attention,
-    mutates_args=["kv_cache", "out"],
+    mutates_args=["kv_cache", "output"],
     fake_impl=unified_flash_attention_fake,
 )
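
The rename in this file follows the usual pattern for a PyTorch custom op that writes its result into a caller-provided buffer: the op computes into a local tensor (attn_output) and copies it into the preallocated output argument, and the name of the mutated parameter has to match what is listed in mutates_args, which is why the registration call changes from "out" to "output" as well. Below is a minimal, self-contained sketch of that pattern using the stock torch.library.custom_op API (PyTorch 2.4+) rather than vLLM's direct_register_custom_op helper; the op name demo::unified_attention and the toy computation are made up for illustration.

import torch


@torch.library.custom_op("demo::unified_attention", mutates_args=("output",))
def unified_attention(output: torch.Tensor, query: torch.Tensor) -> None:
    # Stand-in for flash_attn_varlen_func(...): compute into a local tensor first.
    attn_output = query * 2.0
    attn_output = attn_output.view(output.shape)
    # Write the result into the caller-provided buffer, as the diff above does.
    output.copy_(attn_output)


@unified_attention.register_fake
def _(output: torch.Tensor, query: torch.Tensor) -> None:
    # Fake (meta) implementation used for tracing/compilation; no real work here.
    return None


query = torch.randn(4, 8)
out_buffer = torch.empty_like(query)
unified_attention(out_buffer, query)
assert torch.allclose(out_buffer, query * 2.0)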

5 changes: 3 additions & 2 deletions in vllm/v1/worker/gpu_model_runner.py

@@ -1,3 +1,4 @@
+import os
 import time
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Set
@@ -407,7 +408,7 @@ def load_model(self) -> None:
             # FIXME(woosuk): Currently, the custom ops are not supported
             # in the piecewise compilation mode. We rely on TorchInductor
             # to optimize the model.
-            envs.VLLM_CUSTOM_OPS = "none"
+            os.environ["VLLM_CUSTOM_OPS"] = "none"
             set_compilation_config(
                 CompilationConfig(
                     use_cudagraph=True,
@@ -451,7 +452,7 @@ def profile_run(self) -> None:
 
     @torch.inference_mode()
     def capture_model(self) -> None:
-        if self.use_cuda_graph:
+        if not self.use_cuda_graph:
             logger.warning(
                 "Skipping CUDA graph capture. Please set "
                 "VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
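
Two behavioral details in this file are worth noting. The override is now written to os.environ rather than assigned as an attribute of the envs module, presumably so that any code that re-reads the VLLM_CUSTOM_OPS environment variable later in the process observes the "none" value. And the capture_model guard is inverted, so the "Skipping CUDA graph capture" warning fires only when CUDA graphs are not in use. A minimal sketch of the environment-variable point (the consumer function is hypothetical, not vLLM's code):

import os

# Writing to os.environ changes what later environment reads observe;
# rebinding an attribute on a settings module would only change that attribute.
os.environ["VLLM_CUSTOM_OPS"] = "none"


def custom_ops_enabled() -> bool:
    # Hypothetical consumer that re-reads the environment at call time.
    return os.environ.get("VLLM_CUSTOM_OPS", "all") != "none"


assert custom_ops_enabled() is False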
