diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
index 0cfc19097fded..2d7abe6145fee 100644
--- a/csrc/cpu/quant.cpp
+++ b/csrc/cpu/quant.cpp
@@ -257,11 +257,13 @@ void int8_scaled_mm(torch::Tensor& c,        // [M, OC], row-major
 // static-per-tensor quantization.
 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                               const torch::Tensor& input,  // [..., hidden_size]
-                              const torch::Tensor& scale) {
+                              const torch::Tensor& scale,
+                              c10::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
+  TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");

   const int hidden_size = input.size(-1);
   const int num_tokens = input.numel() / hidden_size;
@@ -277,11 +279,12 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     const torch::Tensor& input,  // [..., hidden_size]
-    torch::Tensor& scale         // [..., 1]
-) {
+    torch::Tensor& scale,  // [..., 1]
+    c10::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index b45da1b386b5b..ab697e3e6aef7 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -94,13 +94,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 #ifdef __AVX512F__
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
-      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
-      "()");
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
+
   // Compute int8 quantized tensor and scaling factor
   ops.def(
-      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
-      "()");
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
            &dynamic_scaled_int8_quant);
   // W8A8 GEMM, supporting symmetric per-tensor or per-row/column
diff --git a/csrc/ops.h b/csrc/ops.h
index 5333b22c536d6..681ab4b898ca3 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -184,10 +184,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
 #endif

 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
-                              torch::Tensor const& scale);
+                              torch::Tensor const& scale,
+                              c10::optional<torch::Tensor> const& azp);

 void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
-                               torch::Tensor& scales);
+                               torch::Tensor& scales,
+                               c10::optional<torch::Tensor> const& azp);

 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
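Note on the hunks above: both quantization ops now take an optional `azp` argument ("Tensor? azp" in the schema), and the CPU backend explicitly rejects a non-None zero point. A minimal sketch of how the updated ops are exercised from Python, assuming the extension is built and registered as `torch.ops._C` (an AVX512 CPU build for the CPU path); shapes and values here are illustrative only:

```python
import torch

x = torch.randn(16, 128, dtype=torch.float32)
out = torch.empty_like(x, dtype=torch.int8)
scale = torch.tensor([0.05], dtype=torch.float32)
azp = torch.tensor([12], dtype=torch.int32)

# Symmetric call: azp is passed as None, matching the "Tensor? azp" schema.
torch.ops._C.static_scaled_int8_quant(out, x, scale, None)

# On the CPU backend, passing a zero point trips the new TORCH_CHECK:
#   "Zero point is not supported on CPU."
# torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)  # would raise
```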
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index 616fc149760e5..aec9fa002f96e 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -14,12 +14,17 @@
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
-  static const float i8_min =
+  static constexpr auto i8_min =
       static_cast<float>(std::numeric_limits<int8_t>::min());
-  static const float i8_max =
+  static constexpr auto i8_max =
       static_cast<float>(std::numeric_limits<int8_t>::max());
-  // round
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
   float dst = std::nearbyint(x);
+
   // saturate
   dst = std::clamp(dst, i8_min, i8_max);
   return static_cast<int8_t>(dst);
@@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
 #endif
 }

+static inline __device__ int32_t float_to_int32_rn(float x) {
+#ifdef USE_ROCM
+  // int32_max is not exactly representable as float.
+  // Therefore, we need to be careful and manually return int32_max on overflow.
+  // For symmetry, we also do the same for int32_min, even though it is exactly
+  // representable as float and the conversion should be exact.
+  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
+  static constexpr auto i32_min_f = static_cast<float>(i32_min);
+  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
+  static constexpr auto i32_max_f = static_cast<float>(i32_max);
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
+  float dst = std::nearbyint(x);
+
+  // saturate on the higher end.
+  if (dst >= i32_max_f) {
+    return i32_max;
+  }
+  // saturate on the lower end.
+  if (dst <= i32_min_f) {
+    return i32_min;
+  }
+
+  return static_cast<int32_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
+  return reinterpret_cast<const int32_t&>(dst);
+#endif
+}
+
+static inline __device__ int8_t int32_to_int8(int32_t x) {
+#ifdef USE_ROCM
+  static constexpr auto i8_min =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
+  static constexpr auto i8_max =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::max());
+
+  // saturate
+  int32_t dst = std::clamp(x, i8_min, i8_max);
+  return static_cast<int8_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
+  return reinterpret_cast<const int8_t&>(dst);
+#endif
+}
+
 namespace vllm {

 template <typename scalar_t, typename scale_type>
@@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel(
   }
 }

+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void static_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type const* scale_ptr, azp_type const* azp_ptr,
+    const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  scale_type const scale = *scale_ptr;
+  azp_type const azp = *azp_ptr;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
 template <typename scalar_t, typename scale_type>
 __global__ void dynamic_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
     scale_type* scale, const int hidden_size) {
@@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   }
 }

+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void dynamic_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type* scale, azp_type* azp, const int hidden_size) {
+  int const token_idx = blockIdx.x;
+
+  // Scan for the min and max value for this token
+  float max_val = std::numeric_limits<float>::min();
+  float min_val = std::numeric_limits<float>::max();
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto val = static_cast<float>(input[token_idx * hidden_size + i]);
+    max_val = std::max(max_val, val);
+    min_val = std::min(min_val, val);
+  }
+
+  // Reduce the max and min values across the block
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
+  __syncthreads();  // Make sure min doesn't mess with max shared memory
+  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
+
+  __shared__ scale_type scale_sh;
+  __shared__ azp_type azp_sh;
+
+  // Compute the scale and zero point and store them, only on the first thread
+  if (threadIdx.x == 0) {
+    float const scale_val = (max_val - min_val) / 255.0f;
+    // Use rounding to even (same as torch.round)
+    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
+    auto const azp_val = static_cast<azp_type>(azp_float);
+
+    // Store the scale and azp into shared and global
+    scale[token_idx] = scale_sh = scale_val;
+    azp[token_idx] = azp_sh = azp_val;
+  }
+
+  // Wait for the scale and azp to be computed
+  __syncthreads();
+
+  float const scale_val = scale_sh;
+  azp_type const azp_val = azp_sh;
+
+  // Quantize the values
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val =
+        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
 }  // namespace vllm

 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                               torch::Tensor const& input,  // [..., hidden_size]
-                              torch::Tensor const& scale) {
+                              torch::Tensor const& scale,
+                              c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
+  TORCH_CHECK(!azp || azp->numel() == 1);

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
@@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
-        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scale.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::static_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }

 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     torch::Tensor const& input,  // [..., hidden_size]
-    torch::Tensor& scales) {
+    torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(scales.is_contiguous());
+  TORCH_CHECK(!azp || azp->is_contiguous());

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
@@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
-        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scales.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }
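For reference, the per-token math that `dynamic_scaled_int8_azp_quant_kernel` implements can be written in a few lines of PyTorch. This is only an illustrative sketch (the helper name is made up; the kernel itself works per token in registers and shared memory), mirroring the reference computation used by the tests further below:

```python
import torch

def ref_dynamic_int8_azp_quant(x: torch.Tensor):
    # Per-token scale and zero point over the full int8 range [-128, 127].
    x = x.to(torch.float32)
    x_max = x.max(dim=-1, keepdim=True).values
    x_min = x.min(dim=-1, keepdim=True).values
    scale = (x_max - x_min) / 255.0
    # torch.round rounds half to even, matching std::nearbyint in the kernel.
    azp = torch.round(-128.0 - x_min / scale)
    q = torch.clamp(torch.round(x / scale) + azp, -128, 127).to(torch.int8)
    return q, scale, azp.to(torch.int32)
```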
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 51afeacfdc0ad..d7f7547fbef55 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -336,14 +336,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
-      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
-      "()");
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCUDA,
            &static_scaled_int8_quant);

   // Compute int8 quantized tensor and scaling factor
   ops.def(
-      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
-      "()");
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 }
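The `float_to_int32_rn` helper above clamps against `i32_max_f`/`i32_min_f` manually because `int32_max` has no exact float32 representation, and the saturating-cast test added below probes exactly that edge. A quick illustration of the representability issue (sketch only, using numpy as the tests do):

```python
import numpy as np

i32_max = 2**31 - 1
# 2**31 - 1 has no exact float32 representation; it rounds up to 2**31, so a
# naive float -> int32 conversion near the top of the range would overflow
# instead of saturating. The kernel therefore clamps before converting.
assert float(np.float32(i32_max)) == float(2**31)
```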
diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py
index a82ecb026482e..e93cb535d715a 100644
--- a/tests/kernels/test_int8_quant.py
+++ b/tests/kernels/test_int8_quant.py
@@ -13,14 +13,28 @@
 SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


-def opcheck_int8_quant(output, input, scale=None):
-    if scale is not None:
-        opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale))
+def opcheck_int8_quant_static(output, input, scale, azp=None):
+    if azp is None:
+        opcheck(torch.ops._C.static_scaled_int8_quant,
+                (output, input, scale, None))
     else:
-        scale = torch.empty((input.numel() // input.shape[-1], 1),
-                            device=input.device,
-                            dtype=torch.float32)
-        opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale))
+        opcheck(torch.ops._C.static_scaled_int8_quant,
+                (output, input, scale, azp))
+
+
+def opcheck_int8_quant_dynamic(output, input, symmetric=True):
+    scale = torch.empty((input.numel() // input.shape[-1], 1),
+                        device=input.device,
+                        dtype=torch.float32)
+    if symmetric:
+        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
+                (output, input, scale, None))
+    else:
+        azp = torch.empty((input.numel() // input.shape[-1], 1),
+                          device=input.device,
+                          dtype=torch.int32)
+        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
+                (output, input, scale, azp))


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -38,14 +52,56 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
     # reference
     ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
     # kernel
-    ops_out, ops_scales = scaled_int8_quant(x)
+    ops_out, ops_scales, _ = scaled_int8_quant(x)

     torch.testing.assert_close(ops_scales, ref_scales)
-    torch.testing.assert_close(
-        ops_out, ref_out, atol=1,
-        rtol=0.0)  # big atol to account for rounding errors
+    # big atol to account for rounding errors
+    torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0)

-    opcheck_int8_quant(ops_out, x)
+    opcheck_int8_quant_dynamic(ops_out, x)
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
+                                       dtype: torch.dtype, seed: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
+                   device="cuda") * 1000 - 300
+
+    x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
+    x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)
+
+    # calculate scale and azp, and adjust the range
+    scales = (x_token_max - x_token_min) / torch.tensor(255.0)
+    azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(
+        torch.int32)
+
+    torch_out = ((x / scales).round() + azps).clamp(
+        int8_traits.min, int8_traits.max).to(torch.int8)
+    assert torch_out.min() >= int8_traits.min and torch_out.max(
+    ) <= int8_traits.max
+
+    ops_out = torch.empty_like(x, dtype=torch.int8)
+    scales_out = torch.empty_like(scales, dtype=torch.float32)
+    azp_out = torch.empty_like(azps, dtype=torch.int32)
+    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)
+
+    if (not torch.allclose(scales_out, scales)):
+        print(torch.argmax(torch.abs(scales_out - scales)))
+    torch.testing.assert_close(scales_out, scales)
+    # big atol to account for rounding errors
+    torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0)
+    # if AZP is off by 1, after rounding-to-even, the output may be off by 2
+    torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0)
+
+    opcheck_int8_quant_dynamic(ops_out, x, False)


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -62,14 +118,76 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
     int8_traits = torch.iinfo(torch.int8)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
-    scale = torch.tensor([scale], dtype=torch.float32, device="cuda")
+    scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
+
+    out1 = (x / scale_arg).round().clamp(int8_traits.min,
+                                         int8_traits.max).to(torch.int8)
+    out2, _, _ = scaled_int8_quant(x, scale_arg)
+
+    # big atol to account for rounding errors
+    torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
+
+    opcheck_int8_quant_static(out2, x, scale_arg)

-    out1 = (x / scale).round().clamp(int8_traits.min,
-                                     int8_traits.max).to(torch.int8)
-    out2, _ = scaled_int8_quant(x, scale)

-    torch.testing.assert_close(
-        out1, out2, atol=1,
-        rtol=0.0)  # big atol to account for rounding errors
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("scale", SCALE[2:])  # Reduce test time
+@pytest.mark.parametrize("azp", [-255, 54])
+@torch.inference_mode()
+def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
+                                      dtype: torch.dtype, seed: int,
+                                      scale: float, azp: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
+                   device="cuda") * 1000 - 300
+
+    out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
+                                             int8_traits.max).to(torch.int8)
+    out2 = torch.empty_like(x, dtype=torch.int8)
+    scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
+    azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
+
+    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
+
+    # big atol to account for rounding errors
+    torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
+
+    opcheck_int8_quant_static(out2, x, scale_arg, azp_arg)
+
+
+@pytest.mark.parametrize("is_max", [True, False])
+@torch.inference_mode()
+def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
+    # Test that the saturating cast works correctly for values near i32 max/min
+
+    from numpy import inf, nextafter
+
+    int32_traits = torch.iinfo(torch.int32)
+    val = float(int32_traits.max if is_max else int32_traits.min)
+
+    x_vals = [[
+        nextafter(val, inf), val + 1, val, val - 1,
+        nextafter(val, -inf)
+    ]]
+    x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")
+
+    # The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)
+    # where cast<T> is a saturating cast to type T.
+    # Scale is set to 1.0 so that the input values are the ones that are cast.
+    # AZP is set to 0 to make sure the int8 saturating cast is tested as well.
+    scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda")
+    azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda")
+
+    int8_traits = torch.iinfo(torch.int8)
+    val_i8 = int8_traits.max if is_max else int8_traits.min
+    expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")

-    opcheck_int8_quant(out2, x, scale)
+    out = torch.empty_like(expected)
+    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
+    torch.testing.assert_close(expected, out, atol=0, rtol=0)
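The `vllm/_custom_ops.py` change below makes `scaled_int8_quant` always return a 3-tuple and adds a `symmetric` flag. A minimal usage sketch under that API (assumes a CUDA build of the extension; shapes and values are illustrative):

```python
import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 128, dtype=torch.float16, device="cuda")

# Dynamic, symmetric (default): the returned azp is None.
x_q, x_scale, _ = ops.scaled_int8_quant(x)

# Dynamic, asymmetric: per-token scales and an int32 zero point are allocated.
x_q, x_scale, x_azp = ops.scaled_int8_quant(x, symmetric=False)

# Static, asymmetric: scale and azp are provided together with symmetric=False.
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
azp = torch.tensor([-3], dtype=torch.int32, device="cuda")
x_q, _, _ = ops.scaled_int8_quant(x, scale, azp, symmetric=False)
```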
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 74b3b69606c67..d5b3d7bc6dd5a 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -684,32 +684,43 @@ def scaled_fp8_quant(

 # int8
 def scaled_int8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        azp: Optional[torch.Tensor] = None,
+        symmetric: bool = True
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """
-    Quantize the input tensor to int8 and return the quantized tensor and scale.
+    Quantize the input tensor to int8 and return the quantized tensor, scale, and optionally azp.

     Args:
         input: The input tensor to be quantized to int8.
         scale: Optional scaling factor for the int8 quantization.
             When not provided, we invoke dynamic-per-token quantization.
+        azp: Optional zero-point for the int8 quantization.
+            Must be provided for asymmetric quantization if `scale` is provided.
+        symmetric: Whether to use symmetric quantization (scale only, azp ignored).

     Returns:
-      Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales.
+      Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
     """
     output = torch.empty_like(input, dtype=torch.int8)
     if scale is not None:
         # static-per-tensor quantization.
-        torch.ops._C.static_scaled_int8_quant(output, input, scale)
-        return output, scale
+        assert symmetric == (
+            azp is
+            None), "azp must only be provided for asymmetric quantization."
+        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
+        return output, scale, None

     # dynamic-per-token quantization.
     input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                                device=input.device,
                                dtype=torch.float32)
-    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
-    return output, input_scales
+    input_azp = None if symmetric else torch.empty_like(input_scales,
+                                                        dtype=torch.int32)
+    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
+                                           input_azp)
+    return output, input_scales, input_azp


 # qqq ops
diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py
index c3434214a1cde..5bc3737520865 100644
--- a/vllm/model_executor/layers/quantization/qqq.py
+++ b/vllm/model_executor/layers/quantization/qqq.py
@@ -260,7 +260,7 @@ def apply(
         size_k = x_2d.shape[1]
         size_n = s_ch.shape[1]

-        x_int8, s_tok = ops.scaled_int8_quant(x_2d)
+        x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d)

         output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
                                         workspace, size_m, size_n, size_k)
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index a54e3cae73b14..887ee6605560c 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -188,7 +188,7 @@ def apply_int8_linear(
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
     # * static, layer.input_scale is scalar and x_scale is input_scale.
-    x_q, x_scale = ops.scaled_int8_quant(input, input_scale)
+    x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)

     return ops.cutlass_scaled_mm(x_q, weight,