Support FP8 FA from Quark format
BowenBao committed Jan 27, 2025
1 parent 8e87b08 commit eda4d30
Showing 4 changed files with 31 additions and 24 deletions.
34 changes: 14 additions & 20 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -1,5 +1,4 @@
 import fnmatch
-import re
 from typing import Any, Dict, List, Optional, cast

 import torch
@@ -122,6 +121,12 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
             for q_config in q_configs:
                 q_config["output_tensors"] = None

+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            q_proj_q_config["output_tensors"] = None
+
         return cls(quant_config=config,
                    kv_cache_group=kv_cache_group,
                    kv_cache_config=kv_cache_config,
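
Note: a hedged illustration of what the added lines above do to a Quark layer-quantization mapping. The "*q_proj" key and the "output_tensors" field come from the diff; the inner dict shape is assumed for the example only.

    from typing import Any, Dict, cast

    # Assumed shape of a Quark layer_quant_config entry; only the keys used
    # above ("*q_proj", "output_tensors") are taken from the actual change.
    layer_quant_config: Dict[str, Any] = {
        "*q_proj": {"output_tensors": {"dtype": "fp8_e4m3"}},
        "*k_proj": {"output_tensors": {"dtype": "fp8_e4m3"}},
        "*v_proj": {"output_tensors": {"dtype": "fp8_e4m3"}},
    }

    # Same operation as the added lines: drop q_proj's output quantization
    # so the q/k/v projections stay consistent.
    q_proj_q_config = cast(Dict[str, Any], layer_quant_config.get("*q_proj"))
    q_proj_q_config["output_tensors"] = None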
@@ -286,25 +291,14 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         :param name: param name
         :return: matching param name for KV cache scale in vLLM
         """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")

         # If no matches, return None
         return None
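
Note: the rewritten get_cache_scale no longer derives projection names from kv_cache_group via re.split; it matches checkpoint parameter names directly (hence the removal of the re import above). A minimal standalone sketch of the new name mapping follows; the helper name and the example parameter name are illustrative only, the real logic is the method shown above.

    from typing import Optional

    def map_quark_output_scale(name: str) -> Optional[str]:
        # q/k/v projection output scales map onto per-layer attention scales.
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        # Softmax (attention-probability) output scale used by FP8 attention.
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        return None

    # e.g. "model.layers.0.self_attn.k_proj.output_scale"
    #   -> "model.layers.0.self_attn.attn.k_scale"
    print(map_quark_output_scale("model.layers.0.self_attn.k_proj.output_scale"))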
@@ -21,6 +21,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.out_dtype = torch.get_default_dtype()

     @classmethod
     def get_min_capability(cls) -> int:
@@ -134,6 +135,7 @@ def apply_weights(self,
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
             input_scale=layer.input_scale,
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
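
Note (a minimal sketch, assuming only PyTorch): out_dtype is captured once when the scheme is constructed, so the FP8 GEMM output is returned in whatever default dtype was set at that point rather than in FP8.

    import torch

    # Illustration only: the loader typically sets the default dtype (e.g.
    # bfloat16) before quantization schemes are built, so get_default_dtype()
    # records the model's compute dtype for later use as out_dtype.
    torch.set_default_dtype(torch.bfloat16)
    captured_out_dtype = torch.get_default_dtype()
    assert captured_out_dtype is torch.bfloat16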
5 changes: 4 additions & 1 deletion vllm/model_executor/models/grok1.py
@@ -37,6 +37,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -197,7 +198,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
+        self.use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config._is_fp8_w8a8)
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.attn = Grok1Attention(hidden_size=self.hidden_size,
14 changes: 11 additions & 3 deletions vllm/model_executor/models/llama.py
@@ -40,6 +40,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -84,7 +85,9 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
         )
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config._is_fp8_w8a8)
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         if hidden_act != "silu":
@@ -196,10 +199,13 @@ def __init__(self,
             sliding_window = None

         # For CUDA devices and Navi4x, attn_fp8 will be set to false.
+        use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config._is_fp8_w8a8)
         self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
             and current_platform.is_rocm() \
             and not is_navi() \
-            and isinstance(quant_config, Fp8Config)
+            and use_fp8

         self.attn = Attention(
             self.num_heads,
@@ -240,7 +246,9 @@
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config._is_fp8_w8a8)
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         rope_theta = getattr(config, "rope_theta", 10000)
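
Note: the FP8 check added to grok1.py and llama.py above is the same expression in each constructor; a possible consolidation is sketched below. The helper name is hypothetical and not part of this commit; it assumes vLLM (with this change applied) is importable and reuses the private _is_fp8_w8a8 attribute exactly as the diff does.

    from typing import Optional

    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig)
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config
    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

    def is_fp8_w8a8(quant_config: Optional[QuantizationConfig]) -> bool:
        # Mirrors the repeated check in this commit: native FP8 checkpoints and
        # Quark checkpoints that resolve to FP8 W8A8 both enable the FP8 paths.
        return isinstance(quant_config, Fp8Config) or (
            isinstance(quant_config, QuarkConfig)
            and quant_config._is_fp8_w8a8)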
