Support FP8 FA from Quark format #388

Merged 4 commits on Jan 28, 2025
47 changes: 27 additions & 20 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -1,5 +1,4 @@
import fnmatch
import re
from typing import Any, Dict, List, Optional, cast

import torch
@@ -122,6 +121,12 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
for q_config in q_configs:
q_config["output_tensors"] = None

# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any],
layer_quant_config.get("*q_proj"))
q_proj_q_config["output_tensors"] = None

return cls(quant_config=config,
kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config,
@@ -148,6 +153,19 @@ def _check_scheme_supported(self,
else:
return False

def is_fp8_w8a8(self) -> bool:
# Returns True if all quantized layers in model are fp8 w8a8
global_quant_config = cast(
Dict[str, Any], self.quant_config.get("global_quant_config"))
layer_quant_configs = cast(Dict[str, Any],
self.quant_config.get("layer_quant_config"))
for config in (global_quant_config, *layer_quant_configs.values()):
weight_config = cast(Dict[str, Any], config.get("weight"))
input_config = cast(Dict[str, Any], config.get("input_tensors"))
if not self._is_fp8_w8a8(weight_config, input_config):
return False
return True

def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
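For context, a schematic of the config layout that is_fp8_w8a8() walks over is sketched below. The nesting mirrors the keys read in this file (global_quant_config, layer_quant_config, kv_cache_group); the per-entry field names and values are assumptions for illustration, not an exact Quark export.

# Illustrative Quark-style quantization config (assumed field values).
# is_fp8_w8a8() returns True only if the global config and every
# layer_quant_config entry pass _is_fp8_w8a8().
example_quant_config = {
    "global_quant_config": {
        "weight": {"dtype": "fp8_e4m3", "qscheme": "per_tensor"},
        "input_tensors": {"dtype": "fp8_e4m3", "qscheme": "per_tensor"},
    },
    "layer_quant_config": {
        "*q_proj": {
            "weight": {"dtype": "fp8_e4m3", "qscheme": "per_tensor"},
            "input_tensors": {"dtype": "fp8_e4m3", "qscheme": "per_tensor"},
        },
    },
    "kv_cache_group": ["*k_proj", "*v_proj"],
}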
@@ -286,25 +304,14 @@ def get_cache_scale(self, name: str) -> Optional[str]:
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
return None

kv_proj_names = [
re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
]
if name.endswith(".output_scale"):
if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
return name.replace(kv_output_scale_name, ".attn.k_scale")

elif len(kv_proj_names) == 2:
for kv_proj_name in kv_proj_names:
if kv_proj_name in name and kv_proj_name == "k_proj":
return name.replace(".k_proj.output_scale",
".attn.k_scale")
elif kv_proj_name in name and kv_proj_name == "v_proj":
return name.replace(".v_proj.output_scale",
".attn.v_scale")
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")

# If no matches, return None
return None
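
To make the new remapping concrete, the helper below mirrors the simplified logic as a standalone function; the helper name and the example checkpoint name are illustrative, not part of vLLM.

def quark_scale_name(name: str):
    # Standalone sketch of the remapping above: translate Quark
    # checkpoint output-scale names into the attention-scale names
    # vLLM expects.
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None

# quark_scale_name("model.layers.0.self_attn.k_proj.output_scale")
#   -> "model.layers.0.self_attn.attn.k_scale"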
@@ -21,6 +21,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.out_dtype = torch.get_default_dtype()

@classmethod
def get_min_capability(cls) -> int:
@@ -134,6 +135,7 @@ def apply_weights(self,
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
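The out_dtype plumbing above captures the model dtype once at scheme construction and forwards it to the fp8 GEMM, so its output comes back in the model dtype. A minimal sketch of that pattern, assuming (as vLLM does when building layers) that the default dtype has been set to the model dtype:

import torch

class FakeFp8Scheme:
    # Hypothetical stand-in for the quantization scheme; only the
    # dtype-capture pattern is shown.
    def __init__(self):
        self.out_dtype = torch.get_default_dtype()

torch.set_default_dtype(torch.bfloat16)
assert FakeFp8Scheme().out_dtype == torch.bfloat16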
5 changes: 4 additions & 1 deletion vllm/model_executor/models/grok1.py
@@ -37,6 +37,7 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -197,7 +198,9 @@ def __init__(
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.use_fp8 = isinstance(quant_config, Fp8Config)
self.use_fp8 = isinstance(
quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
and quant_config.is_fp8_w8a8())
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.attn = Grok1Attention(hidden_size=self.hidden_size,
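The fp8 check added here, and repeated in llama.py below, reduces to a single predicate. A standalone sketch of it follows; the helper name is illustrative, while the two imports are exactly the ones used in this PR:

from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

def is_fp8_checkpoint(quant_config) -> bool:
    # A checkpoint counts as fp8 if it uses vLLM's native Fp8Config, or
    # a Quark config whose quantized layers are all fp8 w8a8.
    if isinstance(quant_config, Fp8Config):
        return True
    return (isinstance(quant_config, QuarkConfig)
            and quant_config.is_fp8_w8a8())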
14 changes: 11 additions & 3 deletions vllm/model_executor/models/llama.py
@@ -40,6 +40,7 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -84,7 +85,9 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.use_fp8 = (isinstance(quant_config, Fp8Config)
self.use_fp8 = (isinstance(quant_config, Fp8Config) or
(isinstance(quant_config, QuarkConfig)
and quant_config.is_fp8_w8a8())
if current_platform.is_rocm() and not is_navi() else
False)
if hidden_act != "silu":
@@ -196,10 +199,13 @@ def __init__(self,
sliding_window = None

# For CUDA devices and Navi4x, attn_fp8 will be set to false.
use_fp8 = isinstance(
quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
and quant_config.is_fp8_w8a8())
self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
and current_platform.is_rocm() \
and not is_navi() \
and isinstance(quant_config, Fp8Config)
and use_fp8

self.attn = Attention(
self.num_heads,
@@ -240,7 +246,9 @@ def __init__(
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.use_fp8 = (isinstance(quant_config, Fp8Config)
self.use_fp8 = (isinstance(quant_config, Fp8Config) or
(isinstance(quant_config, QuarkConfig)
and quant_config.is_fp8_w8a8())
if current_platform.is_rocm() and not is_navi() else
False)
rope_theta = getattr(config, "rope_theta", 10000)
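Taken together, the llama.py changes gate fp8 attention output on the env flag, the platform, and either fp8 config type. A simplified standalone sketch of that gate (not the vLLM code itself; the env parsing here is an assumption, vLLM reads the flag via vllm.envs):

import os

def want_rocm_attn_fp8_out(is_rocm: bool, is_navi: bool,
                           checkpoint_is_fp8: bool) -> bool:
    # Custom paged-attention fp8 output is used only on ROCm, off
    # Navi, with an fp8 checkpoint (native Fp8Config or Quark fp8
    # w8a8), and only when the flag is enabled.
    flag = os.environ.get("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT",
                          "0").lower() in ("1", "true")
    return flag and is_rocm and not is_navi and checkpoint_is_fp8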