Commit e349c2d

[Bug Fix] Fix the support check for FP8 CUTLASS (vllm-project#5352)
Bug description:
With torch 2.4.0.dev20240603+cu121, cutlass_fp8_supported() returns False, and (capability, version) evaluates to (90, 11111111112) just before the comparison.

This PR fixes the support check for FP8 CUTLASS (cutlass_fp8_supported), which was introduced in vllm-project#5183.
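
For illustration, a minimal sketch (not part of the PR) of why the old check failed, assuming an SM90 GPU with CUDA 12.1, i.e. torch.cuda.get_device_capability() returning (9, 0) and torch.version.cuda being the string "12.1":

# Assumed inputs: SM90 GPU, torch.version.cuda == "12.1" (hard-coded stand-ins
# for the real torch calls so the snippet runs anywhere).
capability = (9, 0)                               # torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]   # 90 (integer arithmetic, as intended)

version = "12.1"                                  # torch.version.cuda is a string, not a tuple
version = version[0] * 10 + version[1]            # "1" * 10 + "2" -> "11111111112" (string repeat + concat)

print(capability, version)    # 90 11111111112, the values quoted in the bug report
print(capability >= 900)      # False, so cutlass_fp8_supported() reported no support on Hopper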
cli99 authored and robertgshaw2-redhat committed Jun 11, 2024
1 parent d65c3ab commit e349c2d
Showing 1 changed file with 7 additions and 7 deletions.
vllm/model_executor/layers/quantization/fp8.py (14 changes: 7 additions & 7 deletions)

@@ -20,16 +20,16 @@
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
-    version = torch.version.cuda
-    version = version[0] * 10 + version[1]
+    major, minor = torch.version.cuda.split(".")
+    version = int(major) * 10 + int(minor)

     # CUTLASS FP8 kernels need at least
     # CUDA 12.0 on SM90 systems (Hopper)
     # CUDA 12.4 on SM89 systems (Lovelace)
     gpu_is_supported = False
-    if capability >= 900:
+    if capability >= 90:
         gpu_is_supported = version > 120
-    elif capability >= 890:
+    elif capability >= 89:
         gpu_is_supported = version > 124

     return gpu_is_supported

@@ -103,7 +103,7 @@ class Fp8LinearMethod(LinearMethodBase):
     1. Only support per-tensor quantization due to torch._scaled_mm support.
     2. Only support float8_e4m3fn data type due to the limitation of
        torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
     Args:
         quant_config: The quantization config.
     """

@@ -298,8 +298,8 @@ def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config

     def create_weights(self, layer: torch.nn.Module):
         """Create "weight" (aka kv_scale) for an attention layer.
         Args:
             layer: The layer that is using the QuantizeMethodBase factory.
         """
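
As a quick sanity check of the new parsing (a standalone sketch with the same assumed inputs as above, not code from the commit), the SM90 branch of the fixed check now passes for CUDA 12.1:

# Assumed inputs: SM90 GPU, torch.version.cuda == "12.1".
capability = (9, 0)
capability = capability[0] * 10 + capability[1]   # 90

major, minor = "12.1".split(".")                  # ["12", "1"]
version = int(major) * 10 + int(minor)            # 121

# Same branching as the committed cutlass_fp8_supported().
gpu_is_supported = False
if capability >= 90:         # SM90 (Hopper)
    gpu_is_supported = version > 120
elif capability >= 89:       # SM89 (Lovelace)
    gpu_is_supported = version > 124

print(gpu_is_supported)      # True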
