[Dev] Improve General Matmul With Splitk (apache#50)
* improve e4m3 decoding.

* append fp16xint1

* Update submodule commit reference

* chore: Update shared memory scope for float32 output dtype

* BUGFIX: UINT8/INT8 Decoding

* feat: Add rasterization options for roller module

* Refactor tensorcore_legalization method to optimize tensor core usage

* feat: Add function to collect variables from expression, improve for splitk

* chore: Update typing import in __init__.py

* chore: Refactor CPU execution of operators

* Refactor matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* chore: Update version to 0.0.1.dev8

* chore: Enable debug output in bitblas.set_debug_level()

* Refactor Linear module matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

---------

Co-authored-by: LeiWang199 <leiwang199>
LeiWang1999 authored Jun 5, 2024
1 parent 4cac65a commit 0b2e48f
Showing 5 changed files with 75 additions and 26 deletions.
9 changes: 9 additions & 0 deletions docs/QuickStart.md
@@ -12,6 +12,9 @@ Here is an example for a $W_{INT4}A_{FP16}$ mixed-precision matrix multiplication
import bitblas
import torch

# enabling debug output

bitblas.set_debug_level("Debug")
matmul_config = bitblas.MatmulConfig(
M=1, # M dimension
N=1024, # N dimension
@@ -125,6 +128,9 @@ Here is an example to define a ```bitblas.Linear``` of $W_{INT4}A_{FP16}$:
import bitblas
import torch

# enabling debug output
bitblas.set_debug_level("Debug")

model = bitblas.Linear(
in_features=1024,
out_features=1024,
@@ -178,6 +184,9 @@ from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import (
QuantLinear as CudaOldQuantLinear,
)

# enabling debug output
bitblas.set_debug_level("Debug")

in_features = 1024
out_features = 1024
group_size = 128
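Taken together, the QuickStart additions amount to calling `bitblas.set_debug_level("Debug")` before constructing any operator. A minimal sketch of that usage in the QuickStart style (the K value and dtypes below are illustrative, not taken from the diff):

```python
import bitblas

# Enable verbose debug output before building any BitBLAS operator.
bitblas.set_debug_level("Debug")

# Illustrative W_INT4 x A_FP16 config, mirroring the QuickStart example.
matmul_config = bitblas.MatmulConfig(
    M=1,                # M dimension
    N=1024,             # N dimension
    K=1024,             # K dimension (illustrative)
    A_dtype="float16",  # activation dtype
    W_dtype="int4",     # weight dtype
)
matmul = bitblas.Matmul(config=matmul_config)
```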
7 changes: 5 additions & 2 deletions python/bitblas/module/__init__.py
@@ -232,15 +232,18 @@ def forward(self, A, output=None):
A = A.half()
# can be lifted to post init.
self.init_params()

if output is None:
output = torch.empty(
A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1))
A = self.bitblas_matmul.transform_input(A)
stream = torch.cuda.current_stream()

A_void = ctypes.c_void_p(A.data_ptr())
stream_handle = ctypes.c_void_p(stream.cuda_stream)
# m is the product of the first n - 1 dimensions of A (all but the last)
self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m)
self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, stream_handle)

return output

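The `forward` change above threads the current CUDA stream through to the prebuilt kernel alongside the data pointers and the flattened M dimension. A small sketch of that ctypes argument packing, assuming a CUDA device is available (the helper name is hypothetical):

```python
import ctypes
import operator
from functools import reduce

import torch


def pack_kernel_args(A: torch.Tensor, output: torch.Tensor):
    """Hypothetical helper: raw data pointers, the flattened M dimension
    (product of all but the last dim of A), and the current CUDA stream."""
    m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1))
    stream_handle = ctypes.c_void_p(torch.cuda.current_stream().cuda_stream)
    return (ctypes.c_void_p(A.data_ptr()),
            ctypes.c_void_p(output.data_ptr()),
            m,
            stream_handle)
```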
11 changes: 8 additions & 3 deletions python/bitblas/ops/general_matmul_splitk.py
@@ -160,7 +160,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:

if output is None:
output = torch.empty(
(self.k_split,) + A.shape[:-1] + (self.N,),
A.shape[:-1] + (self.N,),
dtype=self.torch_output_dtype,
device=A.device)
if scale is not None:
@@ -169,7 +169,12 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args.append(zeros)
if bias is not None:
args.append(bias)
args.append(output)

sk_output = torch.empty((self.k_split,) +
A.shape[:-1] + (self.N,),
dtype=self.torch_output_dtype,
device=A.device)
args.append(sk_output)

if self.dynamic_range is not None:
m = reduce(operator.mul, A.shape[:-1], 1)
@@ -180,7 +185,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
if self.lib is None:
self._forward_from_torch_func(*args)
self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)
output = torch.sum(output, dim=0)
torch.sum(sk_output, dim=0, out=output)
return output

def __call__(self, *args: Any, **kwds: Any) -> Any:
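For context, the refactor keeps the caller-visible `output` at its usual shape and collects per-split partial results in a separate `(k_split, ..., N)` buffer that is reduced with `torch.sum(..., dim=0, out=output)`. A plain-PyTorch sketch of that split-K reduction pattern (float32 on the default device for portability; shapes are illustrative):

```python
import torch

k_split, M, N, K = 4, 16, 256, 1024
A = torch.rand(M, K) - 0.5
W = torch.rand(N, K) - 0.5  # "nt" layout: weights stored as (N, K)

# One partial matmul per K split, written into a (k_split, M, N) buffer.
partials = torch.empty(k_split, M, N)
for i in range(k_split):
    ks, ke = i * (K // k_split), (i + 1) * (K // k_split)
    partials[i] = A[:, ks:ke] @ W[:, ks:ke].t()

# Reduce the split dimension into the caller-provided output buffer,
# as torch.sum(sk_output, dim=0, out=output) does in the diff above.
output = torch.empty(M, N)
torch.sum(partials, dim=0, out=output)
torch.testing.assert_close(output, A @ W.t(), rtol=1e-4, atol=1e-4)
```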
4 changes: 3 additions & 1 deletion testing/python/operators/test_general_matmul_fp8.py
@@ -171,4 +171,6 @@ def map_torch_type(intype):

# fmt: on
if __name__ == "__main__":
bitblas.testing.main()
# bitblas.testing.main()
test_matmul_torch_forward_weight_dequantize(1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None,
None, None)
70 changes: 50 additions & 20 deletions testing/python/operators/test_general_matmul_splitk_ops.py
@@ -41,20 +41,22 @@ def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype
matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False)
assert get_codegen_result(matmul)


@pytest.mark.parametrize(
"M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode",
"SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode",
[
(1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False,
None),
(16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False,
None),
(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
False, None),
(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
False, None),
],
)
def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,
group_size, with_scaling, with_zeros, zeros_mode):

def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype,
layout, with_bias, group_size, with_scaling, with_zeros,
zeros_mode):
import torch
torch.random.manual_seed(0)
matmul_config = MatmulConfigWithSplitK(
k_split=SplitK,
M=M,
N=N,
K=K,
@@ -70,20 +72,27 @@ def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout
zeros_mode=zeros_mode,
)
matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False)
matmul.hardware_aware_finetune(topk=10)
assert get_codegen_result(matmul)

input_shape = (M, K)
weight_shape = (N, K) if layout == "nt" else (K, N)
inputs = []
inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5)

output_bitblas = matmul.forward(*inputs)
output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1])
torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1)

@pytest.mark.parametrize(
"SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode",
[
(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
(1, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False,
False, None),
(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
(4, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False,
False, None),
],
)
def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype,
def test_matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype,
layout, with_bias, group_size, with_scaling, with_zeros,
zeros_mode):
import torch
@@ -103,18 +112,39 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype
with_scaling=with_scaling,
with_zeros=with_zeros,
zeros_mode=zeros_mode,
propagate_a=False,
propagate_b=False,
)
matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False)

input_shape = (M, K)
weight_shape = (N, K) if layout == "nt" else (K, N)
inputs = []
inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5)
def map_torch_type(intype):

output_bitblas = matmul.forward(*inputs)
output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1])
torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1)
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)

numpytype_a = map_torch_type(A_dtype)
numpytype_b = map_torch_type(W_dtype)

torch_a = torch.rand(M * K).uniform_(-1, 1).reshape(input_shape).type(numpytype_a).cuda()
torch_b = torch.rand(N * K).uniform_(-1, 1).reshape(weight_shape).type(numpytype_b).cuda()
ref_out = torch.matmul(torch_a.to(torch.float32),
torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul(
torch_a.to(torch.float32), torch_b.to(torch.float32))
ref_out = ref_out.to(torch.float16)
bitblas_out = torch.empty_like(ref_out)
matmul.forward(torch_a, torch_b, output=bitblas_out)
print("torch_ref_out", ref_out)
print("bitblas_out", bitblas_out)

torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1)


# fmt: on
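Condensing the consistency test above into a minimal split-K usage sketch; the import path follows the module changed in this commit, the keyword names mirror the test, and a CUDA device plus illustrative shapes and tolerances are assumed:

```python
import torch
from bitblas.ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK

config = MatmulConfigWithSplitK(
    k_split=4,              # number of K partitions reduced into the output
    M=16, N=4096, K=12800,
    A_dtype="float16", W_dtype="float16",
    accum_dtype="float16", out_dtype="float16",
    layout="nt",
)
matmul = MatmulWithSplitK(config=config, enable_tuning=False)

a = torch.rand(16, 12800, dtype=torch.float16).cuda() - 0.5
w = torch.rand(4096, 12800, dtype=torch.float16).cuda() - 0.5
out = matmul.forward(a, w)
ref = torch.matmul(a, w.t())
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-1)
```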
