Output wrong result when using w4a8 matmul #284
Comments
We should use `weight_transform` to pack the quantized weight before calling the kernel.
Thank you for your reply! I tried to rewrite the code using `weight_transform`, but the answer still seems to be wrong. I tested the w4a4 GEMM using the same pack function and the result is correct, so I don't think the pack function is what causes the error (though it is not efficient enough).

New code:

```python
import torch
import bitblas

torch.random.manual_seed(42)

with torch.no_grad():
    N = 32
    M = 32
    K = 64
    A_dtype = "int8"
    W_dtype = "int4"
    accum_dtype = "int32"
    out_dtype = "int32"
    layout = "nt"
    matmul_config = bitblas.MatmulConfig(
        N=N,  # N dimension
        K=K,  # K dimension
        A_dtype=A_dtype,  # activation A dtype
        W_dtype=W_dtype,  # weight W dtype
        accum_dtype=accum_dtype,  # accumulation dtype
        out_dtype=out_dtype,  # output dtype
        layout=layout,  # matrix layout, "nt" indicates A is non-transposed and W is transposed
    )
    matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)

    qx = torch.randint(-128, 127, (N, K), dtype=torch.int32)
    qweight = torch.randint(-8, 7, (M, K), dtype=torch.int32)
    out_ref = qx @ qweight.T
    print(out_ref)

    qx = qx.cuda().type(torch.int8)
    qweight_pack = matmul.weight_transform(qweight.type(torch.int8))
    qweight_pack = qweight_pack.cuda()
    out1 = matmul(qx, qweight_pack)
    print(out1)
```

New results:
Sorry for getting you confused. When utilizing `uint4` weights, the reference should apply the zero-point offset, for example:

```python
import torch
import bitblas

torch.random.manual_seed(42)

with torch.no_grad():
    N = 32
    M = 32
    K = 64
    A_dtype = "int8"
    W_dtype = "uint4"
    accum_dtype = "int32"
    out_dtype = "int32"
    layout = "nt"
    matmul_config = bitblas.MatmulConfig(
        M=M,  # M dimension
        N=N,  # N dimension
        K=K,  # K dimension
        A_dtype=A_dtype,  # activation A dtype
        W_dtype=W_dtype,  # weight W dtype
        accum_dtype=accum_dtype,  # accumulation dtype
        out_dtype=out_dtype,  # output dtype
        layout=layout,  # matrix layout, "nt" indicates A is non-transposed and W is transposed
    )
    matmul = bitblas.Matmul(config=matmul_config)

    qx = torch.randint(-127, 127, (M, K), dtype=torch.int32)
    qweight = torch.randint(0, 15, (N, K), dtype=torch.int32)
    zeros = 2 ** (4 - 1) - 1
    qweight_with_offset = qweight - zeros  # lowered into the signed range -7 ~ 8
    out_ref = qx @ qweight_with_offset.T
    print(out_ref)

    qx = qx.cuda().type(torch.int8)
    qweight_pack = matmul.weight_transform(qweight.type(torch.int8))
    qweight_pack = qweight_pack.cuda()
    out1 = matmul(qx, qweight_pack)
    print(out1)
```
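To make the zero-point convention above concrete, here is a small self-contained check (a sketch added for illustration, not part of the original comment; it does not use BitBLAS): computing the reference against `qweight - zeros` is equivalent to computing it on the raw uint4 values and then subtracting `zeros` times the row sums of the activations.

```python
import torch

# Illustration only: the zero-point identity assumed by the reference above.
torch.random.manual_seed(0)

num_bits = 4
zeros = 2 ** (num_bits - 1) - 1  # zero point 7 for 4-bit

qx = torch.randint(-127, 128, (4, 8), dtype=torch.int32)   # int8-range activations
qweight = torch.randint(0, 16, (4, 8), dtype=torch.int32)  # raw uint4 values

# Reference with the zero point folded into the weights
out_offset = qx @ (qweight - zeros).T

# Equivalent: matmul on raw uint4 values, corrected by zeros * row-sums of qx
out_raw = qx @ qweight.T
correction = zeros * qx.sum(dim=1, keepdim=True)
assert torch.equal(out_offset, out_raw - correction)
```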
Thank you for your reply! I realized that I was using the wrong API (`weight_transform`), and after correcting it (`weight_transform` -> `transform_weight`) I got the correct result.

```python
import torch
import bitblas

torch.random.manual_seed(42)

with torch.no_grad():
    N = 32
    M = 32
    K = 64
    A_dtype = "int8"
    W_dtype = "int4"
    accum_dtype = "int32"
    out_dtype = "int32"
    layout = "nt"
    matmul_config = bitblas.MatmulConfig(
        M=M,  # M dimension
        N=N,  # N dimension
        K=K,  # K dimension
        A_dtype=A_dtype,  # activation A dtype
        W_dtype=W_dtype,  # weight W dtype
        accum_dtype=accum_dtype,  # accumulation dtype
        out_dtype=out_dtype,  # output dtype
        layout=layout,  # matrix layout, "nt" indicates A is non-transposed and W is transposed
        propagate_a=False,
        propagate_b=False,
        fast_decoding=False,
    )
    matmul = bitblas.Matmul(config=matmul_config)

    # num_bits = 4
    # zeros = 2 ** (num_bits - 1)
    A = torch.randint(-127, 127 + 1, (M, K), dtype=torch.int8).cuda()
    B = torch.randint(-8, 7 + 1, (N, K), dtype=torch.int8).cuda()
    out_ref = A.type(torch.float32) @ B.type(torch.float32).T
    print(out_ref.type(torch.int32))

    B = matmul.transform_weight(B)
    out1 = matmul(A, B)
    print(out1)
```
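As a quick sanity check (a sketch added here, not part of the original comment; it assumes `out_ref` and `out1` from the block above are still in scope and that `out1` comes back as an `int32` CUDA tensor), the two outputs can be compared directly:

```python
# Compare the BitBLAS output against the PyTorch reference, element for element.
torch.testing.assert_close(out1.cpu(), out_ref.type(torch.int32).cpu(), rtol=0, atol=0)
print("w4a8 matmul matches the PyTorch reference")
```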
I am trying to generate a w4a8 matmul kernel based on BitBLAS, but it produces wrong results.
result:
environment: