Output wrong result when using w4a8 matmul #284

Closed
hyx1999 opened this issue Jan 20, 2025 · 4 comments
hyx1999 commented Jan 20, 2025

I am trying to generate a w4a8 matmul kernel with BitBLAS, but it produces wrong results.

import torch
import bitblas

torch.random.manual_seed(42)

# Map negative values to their two's-complement encoding within `bits` bits.
def two_compl(x, bits: int):
    return torch.where(x < 0, 2**bits + x, x)

# Return the (min, max) representable values for the given bit width.
def get_minq_maxq(bits: int, sym: bool):
    if sym:
        maxq = torch.tensor(2 ** (bits - 1) - 1)
        minq = torch.tensor(-maxq - 1)
    else:
        maxq = torch.tensor(2**bits - 1)
        minq = torch.tensor(0)

    return minq, maxq

# Pack the int tensor. Each uint8 stores two int4 values.
def pack_i4(q):
    assert torch.is_signed(q), "The tensor to be packed should be signed int"
    minq, maxq = get_minq_maxq(4, True)
    assert torch.all(torch.logical_and(q >= minq, q <= maxq))

    q_i8 = two_compl(q.to(dtype=torch.int8), 4).to(torch.uint8)
    q_i4 = q_i8[:, 0::2] | (q_i8[:, 1::2] << 4)
    return q_i4


with torch.no_grad():
    N = 32
    M = 32
    K = 64
    A_dtype = "int8"
    W_dtype = "int4"
    accum_dtype = "int32"
    out_dtype = "int32"
    layout = "nt"
    matmul_config = bitblas.MatmulConfig(
        N=N,  # N dimension
        K=K,  # K dimension
        A_dtype=A_dtype,  # activation A dtype
        W_dtype=W_dtype,  # weight W dtype
        accum_dtype=accum_dtype,  # accumulation dtype
        out_dtype=out_dtype,  # output dtype
        layout=layout,  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    )
    matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)
    
    qx = torch.randint(-128, 127, (N, K), dtype=torch.int32)
    qweight = torch.randint(-8, 7, (M, K), dtype=torch.int32)
    
    qweight_pack = pack_i4(qweight.type(torch.int8))
    
    out_ref = qx @ qweight.T    
    print(out_ref)
    
    qx = qx.cuda().type(torch.int8)
    qweight_pack = qweight_pack.cuda()
    out1 = matmul(qx, qweight_pack)
    print(out1)

result:

torch result:
tensor([[-2650, -3869,  1951,  ...,   903, -3034, -2176],
        [  954,   975,  2805,  ...,  2581,   868,  2934],
        [ 1316,  4411,  4804,  ..., -1021,   299,  2011],
        ...,
        [-2285, -3522,  2887,  ...,  1284,  -629,  2081],
        [ 1525,  1015,  -243,  ...,   187,  2356, -1012],
        [ 1251,   907,  3809,  ..., -1848,  1797, -2703]], dtype=torch.int32)

bitblas result:
tensor([[  518, -3261, -2625,  ..., -2201,  1782,  3008],
        [ 3090,   359,   717,  ...,  4893,  4684, -5378],
        [-1084,   443, -1868,  ...,   835, -2245,  4411],
        ...,
        [ 1651,  4718, -1337,  ...,     4, -2357, -1263],
        [-2395,  1479, -1939,  ..., -2501,   404,   924],
        [ -557, -1109, -2847,  ...,  3736,  2549, -1071]], device='cuda:0',
       dtype=torch.int32)

environment:

torch: 2.5.1+cu121
nvidia-smi: NVIDIA-SMI 535.113.01  Driver Version: 535.113.01  CUDA Version: 12.2
GPU: NVIDIA RTX A6000
LeiWang1999 (Contributor) commented:
You should use matmul.transform_weight instead of your own pack function, as the weight transformation is complex if you want good performance: besides the weight packing, we also apply a fast-dequantize layout transform and a ladder layout transform.
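A minimal sketch of the suggested call pattern, reusing the matmul object and the signed int4 qweight from the snippet above:

# Sketch only: let BitBLAS pack and lay out the int4 weights itself
# instead of the hand-rolled pack_i4 above.
qweight_cuda = qweight.type(torch.int8).cuda()        # one int4 value per int8 element
qweight_pack = matmul.transform_weight(qweight_cuda)  # BitBLAS's own packing/transform
out = matmul(qx.cuda().type(torch.int8), qweight_pack)
print(out)

This way the packed tensor is produced in whatever layout the generated kernel expects, which a hand-packed tensor generally will not match.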

hyx1999 commented Jan 20, 2025

Thank you for your reply! I rewrote the code using weight_transform, but the result still seems to be wrong. I tested a w4a4 GEMM with the same pack function and it produces the correct result, so I don't think the pack function is what causes the error (though it is not efficient).

new code:

import torch
import bitblas

torch.random.manual_seed(42)

with torch.no_grad():
    N = 32
    M = 32
    K = 64
    A_dtype = "int8"
    W_dtype = "int4"
    accum_dtype = "int32"
    out_dtype = "int32"
    layout = "nt"
    matmul_config = bitblas.MatmulConfig(
        N=N,  # N dimension
        K=K,  # K dimension
        A_dtype=A_dtype,  # activation A dtype
        W_dtype=W_dtype,  # weight W dtype
        accum_dtype=accum_dtype,  # accumulation dtype
        out_dtype=out_dtype,  # output dtype
        layout=layout,  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    )
    matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)
    
    qx = torch.randint(-128, 127, (N, K), dtype=torch.int32)
    qweight = torch.randint(-8, 7, (M, K), dtype=torch.int32)
        
    out_ref = qx @ qweight.T    
    print(out_ref)
    
    qx = qx.cuda().type(torch.int8)
    qweight_pack = matmul.weight_transform(qweight.type(torch.int8))
    qweight_pack = qweight_pack.cuda()
    out1 = matmul(qx, qweight_pack)
    print(out1)

new results:

tensor([[-2650, -3869,  1951,  ...,   903, -3034, -2176],
        [  954,   975,  2805,  ...,  2581,   868,  2934],
        [ 1316,  4411,  4804,  ..., -1021,   299,  2011],
        ...,
        [-2285, -3522,  2887,  ...,  1284,  -629,  2081],
        [ 1525,  1015,  -243,  ...,   187,  2356, -1012],
        [ 1251,   907,  3809,  ..., -1848,  1797, -2703]], dtype=torch.int32)
tensor([[  519, -3181, -4741,  ...,  -912,  1250,  1353],
        [-1820, -2132,  -166,  ..., -2177,  1476, -6651],
        [-2149, -1128, -1339,  ...,   -18,  -972,  2873],
        ...,
        [-1890, -2242, -2827,  ..., -1354, -4520, -2746],
        [-2844,  1734, -3194,  ..., -3781,  2334,  1508],
        [-2017, -2495,  -347,  ...,   143,  1521, -2381]], device='cuda:0',
       dtype=torch.int32)

LeiWang1999 commented Jan 20, 2025

Sorry for the confusion. When using the int format, it is better to do it like this:

import torch
import bitblas

torch.random.manual_seed(42)

with torch.no_grad():
    N = 32
    M = 32
    K = 64
    A_dtype = "int8"
    W_dtype = "uint4"
    accum_dtype = "int32"
    out_dtype = "int32"
    layout = "nt"
    matmul_config = bitblas.MatmulConfig(
        M=M,  # M dimension
        N=N,  # N dimension
        K=K,  # K dimension
        A_dtype=A_dtype,  # activation A dtype
        W_dtype=W_dtype,  # weight W dtype
        accum_dtype=accum_dtype,  # accumulation dtype
        out_dtype=out_dtype,  # output dtype
        layout=layout,  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    )
    matmul = bitblas.Matmul(config=matmul_config)

    qx = torch.randint(-127, 127, (M, K), dtype=torch.int32)
    qweight = torch.randint(0, 15, (N, K), dtype=torch.int32)

    zeros = 2 ** (4 - 1) - 1
    qweight_with_offset = qweight - zeros    # shifted into the signed range -7 ~ 8

    out_ref = qx @ qweight_with_offset.T
    print(out_ref)

    qx = qx.cuda().type(torch.int8)
    qweight_pack = matmul.weight_transform(qweight.type(torch.int8))
    qweight_pack = qweight_pack.cuda()
    out1 = matmul(qx, qweight_pack)
    print(out1)

hyx1999 commented Jan 21, 2025

Thank you for your reply! I realized that I was calling the wrong API (weight_transform); after correcting the call (weight_transform -> transform_weight) I got the correct result.

import torch
import bitblas

torch.random.manual_seed(42)

with torch.no_grad():
    N = 32
    M = 32
    K = 64
    A_dtype = "int8"
    W_dtype = "int4"
    accum_dtype = "int32"
    out_dtype = "int32"
    layout = "nt"
    matmul_config = bitblas.MatmulConfig(
        M=M,  # M dimension
        N=N,  # N dimension
        K=K,  # K dimension
        A_dtype=A_dtype,  # activation A dtype
        W_dtype=W_dtype,  # weight W dtype
        accum_dtype=accum_dtype,  # accumulation dtype
        out_dtype=out_dtype,  # output dtype
        layout=layout,  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
        propagate_a=False,
        propagate_b=False,
        fast_decoding=False,
    )
    matmul = bitblas.Matmul(config=matmul_config)

    # num_bits = 4
    # zeros = 2 ** (num_bits - 1)

    A = torch.randint(-127, 127 + 1, (M, K), dtype=torch.int8).cuda()
    B = torch.randint(-8, 7 + 1, (N, K), dtype=torch.int8).cuda()
    
    out_ref = A.type(torch.float32) @ B.type(torch.float32).T
    print(out_ref.type(torch.int32))

    B = matmul.transform_weight(B)    
    out1 = matmul(A, B)
    print(out1)

hyx1999 closed this as completed Jan 21, 2025