Skip to content

Commit

Permalink
Merge branch 'main' into minjean/welford_layernorm
Browse files Browse the repository at this point in the history
  • Loading branch information
min-jean-cho authored Feb 21, 2025
2 parents f1c78eb + 386d6c2 commit fb35f96
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 53 deletions.
14 changes: 3 additions & 11 deletions src/ATen/native/xpu/LinearInt4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,12 @@ Tensor _weight_int4pack_mm_xpu(
TORCH_CHECK(B.dim() == 2, __func__, " : expect B to 2d tensor.");

TORCH_CHECK(
qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
qGroupSize == 256,
qGroupSize == 16 || qGroupSize == 32 || qGroupSize == 64 ||
qGroupSize == 128 || qGroupSize == 256,
__func__,
": expect qGroupSize to be 32, 64, 128 or 256, got ",
": expect qGroupSize to be 16, 32, 64, 128 or 256, got ",
qGroupSize);

TORCH_CHECK(
qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(0) == N &&
qScaleAndZeros.size(2) == 2,
__func__,
": expect qScaleAndZeros to be 3d tensor with sizes [",
N,
", :, 2]");

std::optional<Device> common_device = std::nullopt;
c10::impl::check_and_update_common_device(
common_device, A, "xpu::_weight_int4pack_mm", "A");
Expand Down
96 changes: 83 additions & 13 deletions src/ATen/native/xpu/sycl/Dequant_int4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
ScaleAndZeros(ScaleAndZeros),
weight_dequant(weight_dequant) {}

void sycl_ker_config_convention(sycl::handler& cgh) {
tmpT = sycl::local_accessor<float>(TileN, cgh);
}
void sycl_ker_config_convention(sycl::handler& cgh) {}
[[intel::reqd_sub_group_size(SgSize)]] void operator()(
sycl::nd_item<1> it) const {
int constexpr GroupN = TileN;
Expand All @@ -42,18 +40,17 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
int g_n = g_idx_n * GroupN;
int g_k = g_idx_k * GroupK;

int ld_scale_zp = k / blocksize * 2;

auto sptr = ScaleAndZeros + (g_k / blocksize) * 2 + g_n * ld_scale_zp;
auto zptr = ScaleAndZeros + (g_k / blocksize) * 2 + g_n * ld_scale_zp + 1;
int ld_scale_zp = n * 2;
auto sptr = ScaleAndZeros + g_n * 2 + (g_k / blocksize) * ld_scale_zp;
auto zptr = ScaleAndZeros + g_n * 2 + (g_k / blocksize) * ld_scale_zp + 1;

auto bptr = weight_int4 + (g_k + g_n * k) / 2;
auto dbptr = weight_dequant + g_k * n + g_n;

float tmp[TileN];
bool high4 = sg_id % 2 != 0;
for (int in = 0; in < TileN; in++) {
int scale_offset = sg_id * TileK / blocksize * 2 + in * ld_scale_zp;
int scale_offset = in * 2 + sg_id * TileK / blocksize * ld_scale_zp;
int zp_offset = scale_offset;
float scale = *(sptr + scale_offset);
float zero_point = *(zptr + zp_offset);
Expand All @@ -63,6 +60,8 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
: static_cast<int8_t>((srcu8 & 0x0f) - 8) * scale + zero_point;
}

float tmpT[TileN];

for (int in = 0; in < TileN; in++) {
for (int is = 0; is < SgSize; is++) {
auto shlv = select_from_group(sg, tmp[in], is);
Expand All @@ -83,7 +82,6 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
const uint8_t* weight_int4;
const scalar_t* ScaleAndZeros;
scalar_t* weight_dequant;
sycl::local_accessor<float> tmpT;
};

void dequant_int4_kernel(
Expand Down Expand Up @@ -114,16 +112,88 @@ void dequant_int4_kernel(
std::is_same_v<scalar_t, at::Half>,
sycl::half,
sycl::ext::oneapi::bfloat16>;

DequantInt4KernelFunctor<scalar_sycl_t, 32, TileK, TileN, SgSize> kfn =
DequantInt4KernelFunctor<scalar_sycl_t, 32, TileK, TileN, SgSize>(
switch (qGroupSize) {
case 16: {
auto kfn = DequantInt4KernelFunctor<
scalar_sycl_t,
16,
TileK,
TileN,
SgSize>(
n,
k,
reinterpret_cast<const uint8_t*>(weight_int4.data_ptr()),
reinterpret_cast<const scalar_sycl_t*>(
qScaleAndZeros.data_ptr<scalar_t>()),
reinterpret_cast<scalar_sycl_t*>(weight.data_ptr<scalar_t>()));
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
case 32: {
auto kfn = DequantInt4KernelFunctor<
scalar_sycl_t,
32,
TileK,
TileN,
SgSize>(
n,
k,
reinterpret_cast<const uint8_t*>(weight_int4.data_ptr()),
reinterpret_cast<const scalar_sycl_t*>(
qScaleAndZeros.data_ptr<scalar_t>()),
reinterpret_cast<scalar_sycl_t*>(weight.data_ptr<scalar_t>()));
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
case 64: {
auto kfn = DequantInt4KernelFunctor<
scalar_sycl_t,
64,
TileK,
TileN,
SgSize>(
n,
k,
reinterpret_cast<const uint8_t*>(weight_int4.data_ptr()),
reinterpret_cast<const scalar_sycl_t*>(
qScaleAndZeros.data_ptr<scalar_t>()),
reinterpret_cast<scalar_sycl_t*>(weight.data_ptr<scalar_t>()));
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
case 128: {
auto kfn = DequantInt4KernelFunctor<
scalar_sycl_t,
128,
TileK,
TileN,
SgSize>(
n,
k,
reinterpret_cast<const uint8_t*>(weight_int4.data_ptr()),
reinterpret_cast<const scalar_sycl_t*>(
qScaleAndZeros.data_ptr<scalar_t>()),
reinterpret_cast<scalar_sycl_t*>(weight.data_ptr<scalar_t>()));
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
case 256: {
auto kfn = DequantInt4KernelFunctor<
scalar_sycl_t,
256,
TileK,
TileN,
SgSize>(
n,
k,
reinterpret_cast<const uint8_t*>(weight_int4.data_ptr()),
reinterpret_cast<const scalar_sycl_t*>(
qScaleAndZeros.data_ptr<scalar_t>()),
reinterpret_cast<scalar_sycl_t*>(weight.data_ptr<scalar_t>()));
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
}
});
}

Expand Down
119 changes: 96 additions & 23 deletions src/ATen/native/xpu/sycl/LinearInt4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
int constexpr SgSize = 16;
int constexpr blocksize = block_size;
using scalarx2_t = sycl::vec<scalar_t, 2>;

int ld_scale_zp = 2 * n;
if (k % (SgSize * 32 * Unroll) == 0) {
int constexpr TileK = 32;
int constexpr GroupK = SgSize * TileK;
Expand All @@ -54,8 +54,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = ScaleAndZeros + g_n * ldb * 2;
auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
auto sptr = ScaleAndZeros + g_n * 2;
auto zptr = ScaleAndZeros + g_n * 2 + 1;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
Expand All @@ -67,8 +67,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
int scale_offset = sg_id * (TileK / blocksize) * 2;
int zp_offset = sg_id * (TileK / blocksize) * 2;
int scale_offset = sg_id * (TileK / blocksize) * ld_scale_zp;
int zp_offset = sg_id * (TileK / blocksize) * ld_scale_zp;
scalar_t scale = *(sptr + scale_offset);
scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
Expand All @@ -80,7 +80,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
}
sptr += (GroupK / blocksize) * 2;
sptr += (GroupK / blocksize) * ld_scale_zp;
zptr += (GroupK / blocksize) * ld_scale_zp;
aptr += GroupK;
bptr += GroupK / 2;
}
Expand All @@ -94,15 +95,16 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
int constexpr TileK = 32;
int constexpr GroupK = SgSize * TileK;
int k_body = padto_le(k, GroupK * Unroll);

int constexpr TileK2 = 8;
int constexpr GroupK2 = SgSize * TileK2;
int k_body2 = padto_le(k, GroupK2 * Unroll);
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = ScaleAndZeros + g_n * ldb * 2;
auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
auto sptr = ScaleAndZeros + g_n * 2;
auto zptr = ScaleAndZeros + g_n * 2 + 1;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
Expand All @@ -115,8 +117,9 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);

int scale_offset = sg_id * (TileK / blocksize) * 2;
int zp_offset = sg_id * (TileK / blocksize) * 2;
int scale_offset = sg_id * TileK / blocksize * ld_scale_zp;
int zp_offset = sg_id * TileK / blocksize * ld_scale_zp;

scalar_t scale = *(sptr + scale_offset);
scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
Expand All @@ -128,7 +131,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
}
sptr += (GroupK / blocksize) * 2;
sptr += (GroupK / blocksize) * ld_scale_zp;
zptr += (GroupK / blocksize) * ld_scale_zp;
aptr += GroupK;
bptr += GroupK / 2;
}
Expand All @@ -141,8 +145,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
*(sycl::vec<uint8_t, TileK2 / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK2 / 2>*)(bptr + sg_id * TileK2 / 2);

int scale_offset = sg_id * (TileK2 / blocksize) * 2;
int zp_offset = sg_id * (TileK2 / blocksize) * 2;
int scale_offset = sg_id * TileK2 / blocksize * ld_scale_zp;
int zp_offset = sg_id * TileK2 / blocksize * ld_scale_zp;
scalar_t scale = *(sptr + scale_offset);
scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
Expand All @@ -154,7 +158,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
}
sptr += (GroupK2 / blocksize) * 2;
sptr += (GroupK2 / blocksize) * ld_scale_zp;
zptr += (GroupK2 / blocksize) * ld_scale_zp;
aptr += GroupK2;
bptr += GroupK2 / 2;
}
Expand All @@ -163,18 +168,21 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
if (i + SgSize * 2 <= k) {
for (; i < k; i += SgSize * 2) {
uint8_t tmps8 = *(bptr + sg_id);
scalarx2_t tmpB = {
static_cast<int8_t>((tmps8 & 0x0f) - 8),
static_cast<int8_t>((tmps8 >> 4) - 8)};

int scale_offset = sg_id * (2 / blocksize) * 2;
int zp_offset = sg_id * (2 / blocksize) * 2;
int scale_offset = sg_id * 2 / blocksize * ld_scale_zp;
int zp_offset = sg_id * 2 / blocksize * ld_scale_zp;
scalar_t scale = *(sptr + scale_offset);
scalar_t zero_point = *(zptr + zp_offset);

scalarx2_t tmpB = {
static_cast<int8_t>((tmps8 & 0x0f) - 8),
static_cast<int8_t>((tmps8 >> 4) - 8)};
scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * 2);

auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
sptr += (SgSize * 2 / blocksize) * 2;
sptr += (SgSize * 2 / blocksize) * ld_scale_zp;
zptr += (SgSize * 2 / blocksize) * ld_scale_zp;
aptr += SgSize * 2;
bptr += SgSize * 2 / 2;
}
Expand Down Expand Up @@ -229,8 +237,40 @@ void linear_int4_kernel(
reinterpret_cast<scalar_sycl_t*>(C.data_ptr<scalar_t>());
scalar_sycl_t* scale_zeros_data = reinterpret_cast<scalar_sycl_t*>(
qScaleAndZeros.data_ptr<scalar_t>());
LinearInt4KernelFunctor<scalar_sycl_t, 32> kfn =
LinearInt4KernelFunctor<scalar_sycl_t, 32>(

switch (qGroupSize) {
case 16: {
auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 16>(
input_data,
weight_data,
output_data,
scale_zeros_data,
m,
n,
k,
k,
k / qGroupSize,
n);
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
case 32: {
auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 32>(
input_data,
weight_data,
output_data,
scale_zeros_data,
m,
n,
k,
k,
k / qGroupSize,
n);
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
case 64: {
auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 64>(
input_data,
weight_data,
output_data,
Expand All @@ -241,7 +281,40 @@ void linear_int4_kernel(
k,
k / qGroupSize,
n);
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
case 128: {
auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 128>(
input_data,
weight_data,
output_data,
scale_zeros_data,
m,
n,
k,
k,
k / qGroupSize,
n);
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
case 256: {
auto kfn = LinearInt4KernelFunctor<scalar_sycl_t, 256>(
input_data,
weight_data,
output_data,
scale_zeros_data,
m,
n,
k,
k,
k / qGroupSize,
n);
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
break;
}
}
});
}

Expand Down
Loading

0 comments on commit fb35f96

Please sign in to comment.