From 386d6c2570c27c89ffea6187686e74322a87f105 Mon Sep 17 00:00:00 2001
From: "Swift.Sun"
Date: Fri, 21 Feb 2025 15:04:15 +0800
Subject: [PATCH] Modify int4 kernel scale&zp layout (#1379)

1. Modify the layout of scale and zero-point to row-major.
2. Add more group sizes.
---
 src/ATen/native/xpu/LinearInt4.cpp        |  14 +--
 src/ATen/native/xpu/sycl/Dequant_int4.cpp |  96 ++++++++++++++---
 src/ATen/native/xpu/sycl/LinearInt4.cpp   | 119 +++++++++++++++++-----
 test/xpu/test_linalg_xpu.py               |  11 +-
 4 files changed, 187 insertions(+), 53 deletions(-)

diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp
index d2b52ee96..992ef61aa 100644
--- a/src/ATen/native/xpu/LinearInt4.cpp
+++ b/src/ATen/native/xpu/LinearInt4.cpp
@@ -32,20 +32,12 @@ Tensor _weight_int4pack_mm_xpu(
   TORCH_CHECK(B.dim() == 2, __func__, " : expect B to 2d tensor.");
 
   TORCH_CHECK(
-      qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
-          qGroupSize == 256,
+      qGroupSize == 16 || qGroupSize == 32 || qGroupSize == 64 ||
+          qGroupSize == 128 || qGroupSize == 256,
       __func__,
-      ": expect qGroupSize to be 32, 64, 128 or 256, got ",
+      ": expect qGroupSize to be 16, 32, 64, 128 or 256, got ",
       qGroupSize);
 
-  TORCH_CHECK(
-      qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(0) == N &&
-          qScaleAndZeros.size(2) == 2,
-      __func__,
-      ": expect qScaleAndZeros to be 3d tensor with sizes [",
-      N,
-      ", :, 2]");
-
   std::optional<Device> common_device = std::nullopt;
   c10::impl::check_and_update_common_device(
       common_device, A, "xpu::_weight_int4pack_mm", "A");
diff --git a/src/ATen/native/xpu/sycl/Dequant_int4.cpp b/src/ATen/native/xpu/sycl/Dequant_int4.cpp
index 5d51a7cdc..612d0be29 100644
--- a/src/ATen/native/xpu/sycl/Dequant_int4.cpp
+++ b/src/ATen/native/xpu/sycl/Dequant_int4.cpp
@@ -21,9 +21,7 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
         ScaleAndZeros(ScaleAndZeros),
         weight_dequant(weight_dequant) {}
-  void sycl_ker_config_convention(sycl::handler& cgh) {
-    tmpT = sycl::local_accessor(TileN, cgh);
-  }
+  void sycl_ker_config_convention(sycl::handler& cgh) {}
 
   [[intel::reqd_sub_group_size(SgSize)]] void operator()(
       sycl::nd_item<1> it) const {
     int constexpr GroupN = TileN;
@@ -42,10 +40,9 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
 
     int g_n = g_idx_n * GroupN;
     int g_k = g_idx_k * GroupK;
-    int ld_scale_zp = k / blocksize * 2;
-
-    auto sptr = ScaleAndZeros + (g_k / blocksize) * 2 + g_n * ld_scale_zp;
-    auto zptr = ScaleAndZeros + (g_k / blocksize) * 2 + g_n * ld_scale_zp + 1;
+    int ld_scale_zp = n * 2;
+    auto sptr = ScaleAndZeros + g_n * 2 + (g_k / blocksize) * ld_scale_zp;
+    auto zptr = ScaleAndZeros + g_n * 2 + (g_k / blocksize) * ld_scale_zp + 1;
     auto bptr = weight_int4 + (g_k + g_n * k) / 2;
     auto dbptr = weight_dequant + g_k * n + g_n;
 
@@ -53,7 +50,7 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
     float tmp[TileN];
     bool high4 = sg_id % 2 != 0;
     for (int in = 0; in < TileN; in++) {
-      int scale_offset = sg_id * TileK / blocksize * 2 + in * ld_scale_zp;
+      int scale_offset = in * 2 + sg_id * TileK / blocksize * ld_scale_zp;
       int zp_offset = scale_offset;
       float scale = *(sptr + scale_offset);
       float zero_point = *(zptr + zp_offset);
@@ -63,6 +60,8 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
           : static_cast<float>((srcu8 & 0x0f) - 8) * scale + zero_point;
     }
 
+    float tmpT[TileN];
+
     for (int in = 0; in < TileN; in++) {
       for (int is = 0; is < SgSize; is++) {
         auto shlv = select_from_group(sg, tmp[in], is);
@@ -83,7 +82,6 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
   const uint8_t* weight_int4;
   const scalar_t* ScaleAndZeros;
   scalar_t* weight_dequant;
-  sycl::local_accessor tmpT;
 };
 
 void dequant_int4_kernel(
@@ -114,16 +112,88 @@ void dequant_int4_kernel(
             std::is_same_v<scalar_t, at::Half>,
             sycl::half,
             sycl::ext::oneapi::bfloat16>;
-
-        DequantInt4KernelFunctor kfn =
-            DequantInt4KernelFunctor(
+        switch (qGroupSize) {
+          case 16: {
+            auto kfn = DequantInt4KernelFunctor<
+                scalar_sycl_t,
+                16,
+                TileK,
+                TileN,
+                SgSize>(
                 n,
                 k,
                 reinterpret_cast<uint8_t*>(weight_int4.data_ptr()),
                 reinterpret_cast<scalar_sycl_t*>(
                     qScaleAndZeros.data_ptr()),
                 reinterpret_cast<scalar_sycl_t*>(weight.data_ptr()));
-        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+            sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+            break;
+          }
+          case 32: {
+            auto kfn = DequantInt4KernelFunctor<
+                scalar_sycl_t,
+                32,
+                TileK,
+                TileN,
+                SgSize>(
+                n,
+                k,
+                reinterpret_cast<uint8_t*>(weight_int4.data_ptr()),
+                reinterpret_cast<scalar_sycl_t*>(
+                    qScaleAndZeros.data_ptr()),
+                reinterpret_cast<scalar_sycl_t*>(weight.data_ptr()));
+            sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+            break;
+          }
+          case 64: {
+            auto kfn = DequantInt4KernelFunctor<
+                scalar_sycl_t,
+                64,
+                TileK,
+                TileN,
+                SgSize>(
+                n,
+                k,
+                reinterpret_cast<uint8_t*>(weight_int4.data_ptr()),
+                reinterpret_cast<scalar_sycl_t*>(
+                    qScaleAndZeros.data_ptr()),
+                reinterpret_cast<scalar_sycl_t*>(weight.data_ptr()));
+            sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+            break;
+          }
+          case 128: {
+            auto kfn = DequantInt4KernelFunctor<
+                scalar_sycl_t,
+                128,
+                TileK,
+                TileN,
+                SgSize>(
+                n,
+                k,
+                reinterpret_cast<uint8_t*>(weight_int4.data_ptr()),
+                reinterpret_cast<scalar_sycl_t*>(
+                    qScaleAndZeros.data_ptr()),
+                reinterpret_cast<scalar_sycl_t*>(weight.data_ptr()));
+            sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+            break;
+          }
+          case 256: {
+            auto kfn = DequantInt4KernelFunctor<
+                scalar_sycl_t,
+                256,
+                TileK,
+                TileN,
+                SgSize>(
+                n,
+                k,
+                reinterpret_cast<uint8_t*>(weight_int4.data_ptr()),
+                reinterpret_cast<scalar_sycl_t*>(
+                    qScaleAndZeros.data_ptr()),
+                reinterpret_cast<scalar_sycl_t*>(weight.data_ptr()));
+            sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+            break;
+          }
+        }
       });
 }
 
diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp
index 9378187c9..c18d363b0 100644
--- a/src/ATen/native/xpu/sycl/LinearInt4.cpp
+++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp
@@ -45,7 +45,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
     int constexpr SgSize = 16;
     int constexpr blocksize = block_size;
     using scalarx2_t = sycl::vec<scalar_t, 2>;
-
+    int ld_scale_zp = 2 * n;
     if (k % (SgSize * 32 * Unroll) == 0) {
       int constexpr TileK = 32;
       int constexpr GroupK = SgSize * TileK;
@@ -54,8 +54,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       auto sg = it.get_sub_group();
       int sg_id = sg.get_local_id()[0];
       int g_n = g_idx;
-      auto sptr = ScaleAndZeros + g_n * ldb * 2;
-      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
+      auto sptr = ScaleAndZeros + g_n * 2;
+      auto zptr = ScaleAndZeros + g_n * 2 + 1;
       auto bptr = B + g_n * k / 2;
       auto aptr = A;
       auto cptr = C + g_n;
@@ -67,8 +67,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
         uint8_t tmps8[TileK / 2];
         *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
             *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
-        int scale_offset = sg_id * (TileK / blocksize) * 2;
-        int zp_offset = sg_id * (TileK / blocksize) * 2;
+        int scale_offset = sg_id * (TileK / blocksize) * ld_scale_zp;
+        int zp_offset = sg_id * (TileK / blocksize) * ld_scale_zp;
         scalar_t scale = *(sptr + scale_offset);
         scalar_t zero_point = *(zptr + zp_offset);
 #pragma unroll
@@ -80,7 +80,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
           auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
           tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
         }
-        sptr += (GroupK / blocksize) * 2;
+        sptr += (GroupK / blocksize) * ld_scale_zp;
+        zptr += (GroupK / blocksize) * ld_scale_zp;
         aptr += GroupK;
         bptr += GroupK / 2;
       }
@@ -94,6 +95,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       int constexpr TileK = 32;
       int constexpr GroupK = SgSize * TileK;
       int k_body = padto_le(k, GroupK * Unroll);
+
       int constexpr TileK2 = 8;
       int constexpr GroupK2 = SgSize * TileK2;
       int k_body2 = padto_le(k, GroupK2 * Unroll);
@@ -101,8 +103,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       auto sg = it.get_sub_group();
       int sg_id = sg.get_local_id()[0];
      int g_n = g_idx;
-      auto sptr = ScaleAndZeros + g_n * ldb * 2;
-      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
+      auto sptr = ScaleAndZeros + g_n * 2;
+      auto zptr = ScaleAndZeros + g_n * 2 + 1;
       auto bptr = B + g_n * k / 2;
       auto aptr = A;
       auto cptr = C + g_n;
@@ -115,8 +117,9 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
 
         *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
             *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
-        int scale_offset = sg_id * (TileK / blocksize) * 2;
-        int zp_offset = sg_id * (TileK / blocksize) * 2;
+        int scale_offset = sg_id * TileK / blocksize * ld_scale_zp;
+        int zp_offset = sg_id * TileK / blocksize * ld_scale_zp;
+
         scalar_t scale = *(sptr + scale_offset);
         scalar_t zero_point = *(zptr + zp_offset);
 #pragma unroll
@@ -128,7 +131,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
           auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
           tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
         }
-        sptr += (GroupK / blocksize) * 2;
+        sptr += (GroupK / blocksize) * ld_scale_zp;
+        zptr += (GroupK / blocksize) * ld_scale_zp;
         aptr += GroupK;
         bptr += GroupK / 2;
       }
@@ -141,8 +145,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
 
         *(sycl::vec<uint8_t, TileK2 / 2>*)tmps8 =
             *(sycl::vec<uint8_t, TileK2 / 2>*)(bptr + sg_id * TileK2 / 2);
-        int scale_offset = sg_id * (TileK2 / blocksize) * 2;
-        int zp_offset = sg_id * (TileK2 / blocksize) * 2;
+        int scale_offset = sg_id * TileK2 / blocksize * ld_scale_zp;
+        int zp_offset = sg_id * TileK2 / blocksize * ld_scale_zp;
         scalar_t scale = *(sptr + scale_offset);
         scalar_t zero_point = *(zptr + zp_offset);
 #pragma unroll
@@ -154,7 +158,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
          auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
          tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
        }
-        sptr += (GroupK2 / blocksize) * 2;
+        sptr += (GroupK2 / blocksize) * ld_scale_zp;
+        zptr += (GroupK2 / blocksize) * ld_scale_zp;
        aptr += GroupK2;
        bptr += GroupK2 / 2;
       }
@@ -163,18 +168,21 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
 
       if (i + SgSize * 2 <= k) {
         for (; i < k; i += SgSize * 2) {
           uint8_t tmps8 = *(bptr + sg_id);
-          scalarx2_t tmpB = {
-              static_cast<scalar_t>((tmps8 & 0x0f) - 8),
-              static_cast<scalar_t>((tmps8 >> 4) - 8)};
-          int scale_offset = sg_id * (2 / blocksize) * 2;
-          int zp_offset = sg_id * (2 / blocksize) * 2;
+          int scale_offset = sg_id * 2 / blocksize * ld_scale_zp;
+          int zp_offset = sg_id * 2 / blocksize * ld_scale_zp;
           scalar_t scale = *(sptr + scale_offset);
           scalar_t zero_point = *(zptr + zp_offset);
+
+          scalarx2_t tmpB = {
+              static_cast<scalar_t>((tmps8 & 0x0f) - 8),
+              static_cast<scalar_t>((tmps8 >> 4) - 8)};
           scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * 2);
+
           auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
           tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
-          sptr += (SgSize * 2 / blocksize) * 2;
+          sptr += (SgSize * 2 / blocksize) * ld_scale_zp;
+          zptr += (SgSize * 2 / blocksize) * ld_scale_zp;
           aptr += SgSize * 2;
           bptr += SgSize * 2 / 2;
         }
@@ -229,8 +237,40 @@ void linear_int4_kernel(
         reinterpret_cast<scalar_sycl_t*>(C.data_ptr());
     scalar_sycl_t* scale_zeros_data = reinterpret_cast<scalar_sycl_t*>(
         qScaleAndZeros.data_ptr());
-    LinearInt4KernelFunctor kfn =
-        LinearInt4KernelFunctor(
+
+    switch (qGroupSize) {
+      case 16: {
+        auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 16>(
+            input_data,
+            weight_data,
+            output_data,
+            scale_zeros_data,
+            m,
+            n,
+            k,
+            k,
+            k / qGroupSize,
+            n);
+        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+        break;
+      }
+      case 32: {
+        auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 32>(
+            input_data,
+            weight_data,
+            output_data,
+            scale_zeros_data,
+            m,
+            n,
+            k,
+            k,
+            k / qGroupSize,
+            n);
+        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+        break;
+      }
+      case 64: {
+        auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 64>(
             input_data,
             weight_data,
             output_data,
@@ -241,7 +281,40 @@ void linear_int4_kernel(
             k,
             k / qGroupSize,
             n);
-    sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+        break;
+      }
+      case 128: {
+        auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 128>(
+            input_data,
+            weight_data,
+            output_data,
+            scale_zeros_data,
+            m,
+            n,
+            k,
+            k,
+            k / qGroupSize,
+            n);
+        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+        break;
+      }
+      case 256: {
+        auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 256>(
+            input_data,
+            weight_data,
+            output_data,
+            scale_zeros_data,
+            m,
+            n,
+            k,
+            k,
+            k / qGroupSize,
+            n);
+        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+        break;
+      }
+    }
   });
 }
 
diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py
index be94a936b..a986e42f4 100644
--- a/test/xpu/test_linalg_xpu.py
+++ b/test/xpu/test_linalg_xpu.py
@@ -255,11 +255,10 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
         assert torch.isnan(zeros).sum() == 0
 
         out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
-
         assert torch.isnan(out).sum() == 0
         out = out.to(dtype=torch.uint8).reshape(w.shape)
 
-
+        # The CPU uses big-endian nibble order while the XPU uses little-endian
         if out.device.type == "xpu":
             out = (out[::, 1::2] << 4 | out[::, ::2]).to(torch.uint8)
         elif out.device != torch.device("cpu"):
@@ -275,9 +274,7 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
             ],
            2,
        )
-
-        if out.device.type != "xpu":
-            scales_and_zeros = scales_and_zeros.transpose(0, 1).contiguous()
+        scales_and_zeros = scales_and_zeros.transpose(0, 1).contiguous()
        return out, scales_and_zeros
 
    def convert_weight_to_int4pack(b):
@@ -317,7 +314,9 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
 
    torch.manual_seed(1)
    a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
-    b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
+    b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device) * torch.rand(
+        (k, 1), dtype=torch.bfloat16, device=device
+    )
    b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)
 
    for dtype in [torch.bfloat16] + (