Support FP8 FA from Quark format
BowenBao committed Jan 28, 2025
1 parent 5c3b97a commit f90b0e7
Showing 3 changed files with 15 additions and 4 deletions.
vllm/model_executor/layers/quantization/quark/quark.py (11 additions, 0 deletions)
@@ -153,6 +153,17 @@ def _check_scheme_supported(self,
         else:
             return False
 
+    def is_fp8_w8a8(self) -> bool:
+        # Returns True if all layers in the model are fp8 w8a8.
+        global_quant_config = self.quant_config.get("global_quant_config")
+        layer_quant_configs = self.quant_config.get("layer_quant_config")
+        for quant_config in (global_quant_config,
+                             *layer_quant_configs.values()):
+            if not self._is_fp8_w8a8(quant_config.get("weight"),
+                                     quant_config.get("input_tensors")):
+                return False
+        return True
+
     def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                      input_quant: Optional[Dict[str, Any]]) -> bool:
         # Confirm weights and input quantized.
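For context, here is a minimal sketch of the config shape the new check walks. The field names in the sample dict ("dtype" in particular) are assumptions inferred from the accessors above, not the exact Quark export schema; is_fp8_w8a8() accepts the model only if the global config and every per-layer override are FP8 for both weights and activations:

from typing import Any, Dict, Optional

# Hypothetical Quark-style quantization config; the "dtype" field is an
# assumption for illustration, not the exact Quark export schema.
sample_config: Dict[str, Any] = {
    "global_quant_config": {
        "weight": {"dtype": "fp8_e4m3"},
        "input_tensors": {"dtype": "fp8_e4m3"},
    },
    "layer_quant_config": {
        "model.layers.0.mlp": {
            "weight": {"dtype": "fp8_e4m3"},
            "input_tensors": {"dtype": "fp8_e4m3"},
        },
    },
}

def _is_fp8_w8a8(weight: Optional[Dict[str, Any]],
                 input_tensors: Optional[Dict[str, Any]]) -> bool:
    # Stand-in for QuarkConfig._is_fp8_w8a8: both weights and activations
    # must be present and FP8-quantized.
    return all(q is not None and str(q.get("dtype", "")).startswith("fp8")
               for q in (weight, input_tensors))

def is_fp8_w8a8(cfg: Dict[str, Any]) -> bool:
    # Mirrors the new method: the global config and all per-layer
    # overrides must be FP8 w8a8 for the whole model to qualify.
    for qc in (cfg.get("global_quant_config"),
               *cfg.get("layer_quant_config", {}).values()):
        if not _is_fp8_w8a8(qc.get("weight"), qc.get("input_tensors")):
            return False
    return True

print(is_fp8_w8a8(sample_config))  # True for the sample above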
vllm/model_executor/models/grok1.py (1 addition, 1 deletion)
@@ -200,7 +200,7 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.use_fp8 = isinstance(
             quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
-                                         and quant_config._is_fp8_w8a8)
+                                         and quant_config.is_fp8_w8a8())
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.attn = Grok1Attention(hidden_size=self.hidden_size,
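Worth noting: besides switching to the new public helper, the call parentheses matter here. The old expression referenced the bound method quant_config._is_fp8_w8a8 without calling it, and a bound method object is always truthy, so use_fp8 came out True for every QuarkConfig regardless of its quantization scheme. A minimal illustration of the pitfall:

class QuarkConfigSketch:
    # Toy stand-in for QuarkConfig, only to show the truthiness pitfall.
    def is_fp8_w8a8(self) -> bool:
        return False  # pretend this model is NOT fp8 w8a8

cfg = QuarkConfigSketch()
print(bool(cfg.is_fp8_w8a8))  # True  -- a bound method is always truthy
print(cfg.is_fp8_w8a8())      # False -- the actual check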
vllm/model_executor/models/llama.py (3 additions, 3 deletions)
@@ -87,7 +87,7 @@ def __init__(
         )
         self.use_fp8 = (isinstance(quant_config, Fp8Config) or
                         (isinstance(quant_config, QuarkConfig)
-                         and quant_config._is_fp8_w8a8)
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         if hidden_act != "silu":
@@ -201,7 +201,7 @@ def __init__(self,
         # For CUDA devices and Navi4x, attn_fp8 will be set to false.
         use_fp8 = isinstance(
             quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
-                                         and quant_config._is_fp8_w8a8)
+                                         and quant_config.is_fp8_w8a8())
         self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
             and current_platform.is_rocm() \
             and not is_navi() \
@@ -248,7 +248,7 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.use_fp8 = (isinstance(quant_config, Fp8Config) or
                         (isinstance(quant_config, QuarkConfig)
-                         and quant_config._is_fp8_w8a8)
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         rope_theta = getattr(config, "rope_theta", 10000)
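Across the three llama.py sites, the FP8 attention path stays gated on platform: the first and third hunks fold current_platform.is_rocm() and not is_navi() into the conditional expression for use_fp8, while the middle hunk applies the same ROCm/non-Navi gating to attn_fp8_out via the environment flag. A condensed, self-contained sketch of the predicate, with plain booleans standing in for the platform helpers and the isinstance checks:

def use_fp8_attention(is_fp8_config: bool, is_quark_fp8_w8a8: bool,
                      is_rocm: bool, is_navi: bool) -> bool:
    # FP8 attention is enabled only on ROCm, only off Navi, and only when
    # the quantization scheme is FP8: native Fp8Config, or a Quark config
    # that is_fp8_w8a8() accepts.
    if not (is_rocm and not is_navi):
        return False  # CUDA devices and Navi: FP8 attention disabled
    return is_fp8_config or is_quark_fp8_w8a8

# e.g. a Quark FP8-w8a8 model on a ROCm (non-Navi) GPU:
print(use_fp8_attention(False, True, is_rocm=True, is_navi=False))  # True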
