
int8 and 3xtf32 gemm works
masahi committed Jan 11, 2022
1 parent 7764408 commit 2aaed84
Showing 2 changed files with 30 additions and 19 deletions.
python/tvm/relay/op/contrib/cutlass.py (1 addition & 4 deletions)

@@ -119,10 +119,7 @@ def get_root_call(call, root_op_name):
 
 def check_gemm(call):
     """Check if the given dense workload can be offloaded to CUTLASS."""
-    dense = get_root_call(call, "nn.dense")
-    lhs = dense.args[0].checked_type
-    rhs = dense.args[1].checked_type
-    return check_dtype(lhs, rhs)
+    return True
 
 
 def check_batch_matmul(call):
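With the dtype check gone, check_gemm accepts every dense workload, since int8 and 3xtf32 GEMM kernels are now supported. A minimal sketch (not part of this commit; shapes are made up) of how the relaxed check plays out when partitioning an int8 dense workload:

import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

# int8 x int8 -> int32 dense, accepted now that check_gemm returns True
data = relay.var("data", shape=(16, 64), dtype="int8")
weight = relay.var("weight", shape=(32, 64), dtype="int8")
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight, out_dtype="int32"))
mod = partition_for_cutlass(mod)  # the dense op lands in a CUTLASS partition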
tests/python/contrib/test_cutlass.py (29 additions & 15 deletions)

@@ -68,14 +68,14 @@ def get_output_vm(vm, names, inputs):
     return vm.invoke("main", **params).numpy()
 
 
-def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16"):
-    data = relay.var("data", shape=data_shape, dtype="float16")
-    weight = relay.var("weight", shape=weight_shape, dtype="float16")
+def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16", data_dtype="float16", weight_dtype="float16"):
+    data = relay.var("data", shape=data_shape, dtype=data_dtype)
+    weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype)
     return relay.nn.dense(data, weight, out_dtype=out_dtype)
 
 
-def get_dense(M, N, K, out_dtype="float16"):
-    return get_dense_with_shape((M, K), (N, K), out_dtype)
+def get_dense(M, N, K, out_dtype="float16", data_dtype="float16", weight_dtype="float16"):
+    return get_dense_with_shape((M, K), (N, K), out_dtype, data_dtype, weight_dtype)
 
 
 def get_dense_bias(M, N, K, out_dtype="float16"):
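A quick illustration (shapes arbitrary) of the widened helpers; the int8 call mirrors the one added to test_dense further down:

# fp16 x fp16 -> fp16, the unchanged default behavior
f16_dense = get_dense(1024, 768, 768)
# int8 x int8 -> int32, exercising the new dtype parameters
i8_dense = get_dense(1024, 768, 768, "int32", "int8", "int8")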
@@ -178,6 +178,7 @@ def get_conv2d_nchw_bias_residual(d_shape, w_shape, padding, out_dtype="float16"
 
 def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False):
     mod = partition_for_cutlass(mod)
+    print(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(
         mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir
     )
@@ -210,17 +211,27 @@ def profile_and_build_vm(
 
 
 def verify_dense(
-    func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    func,
+    M,
+    N,
+    K,
+    ref_target="cuda",
+    sm=80,
+    atol=1e-5,
+    rtol=1e-5,
+    run_benchmark=False,
+    data_dtype="float16",
+    weight_dtype="float16",
 ):
     if not has_cutlass():
         return
     mod = tvm.IRModule.from_expr(func)
     typ = relay.transform.InferType()(mod)["main"].body.checked_type
     out_dtype = typ.dtype
     use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)
-    np_data = np.random.uniform(-1, 1, (M, K)).astype("float16")
-    np_weight = np.random.uniform(-1, 1, (N, K)).astype("float16")
-    np_bias = np.random.uniform(-1, 1, (N,)).astype(out_dtype)
+    np_data = get_random_ndarray((M, K), data_dtype)
+    np_weight = get_random_ndarray((N, K), weight_dtype)
+    np_bias = get_random_ndarray((N,), out_dtype)
 
     params = {"weight": np_weight, "bias": np_bias}
 
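get_random_ndarray is not shown in this diff; a plausible sketch, assuming it generalizes the replaced np.random.uniform calls by dispatching on dtype:

import numpy as np

def get_random_ndarray(shape, dtype):
    # Hypothetical helper (not in this diff): integer dtypes need randint,
    # float dtypes keep the previous uniform(-1, 1) behavior.
    if np.issubdtype(np.dtype(dtype), np.integer):
        return np.random.randint(-5, 6, shape).astype(dtype)
    return np.random.uniform(-1, 1, shape).astype(dtype)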
@@ -292,7 +303,7 @@ def verify_batch_matmul(
     print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))
 
 
-M = 1820
+M = 1024
 N = 768
 K = 768
 
@@ -302,6 +313,9 @@ def test_dense():
     verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K)
     # Test align1 case
     verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K)
+    verify_dense(get_dense(M, N, K, "int32", "int8", "int8"), M, N, K, data_dtype="int8", weight_dtype="int8")
+    # Test 3xtf32 kernels
+    verify_dense(get_dense(M, N, K, "float32", "float32", "float32"), M, N, K, data_dtype="float32", weight_dtype="float32")
 
 
 def test_dense_bias():
@@ -548,7 +562,7 @@ def test_conv2d_residual_block():
     verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False)
 
 
-def test_int8():
+def test_conv2d_int8():
     d_shape = (16, 16, 32, 32)
     w_shape = (32, 16, 3, 3)
     padding = (1, 1)
@@ -581,7 +595,7 @@ def test_int8():
     )
 
 
-def test_3xtf32():
+def test_conv2d_3xtf32():
     d_shape = (16, 16, 32, 32)
     w_shape = (32, 16, 3, 3)
     padding = (1, 1)
@@ -609,11 +623,11 @@
         run_benchmark=False,
         data_dtype="float32",
         weight_dtype="float32",
-        ref_target="llvm"
+        ref_target="llvm",
     )
 
 
 if __name__ == "__main__":
     # pytest.main([__file__])
-    # test_int8()
-    test_3xtf32()
+    test_dense()
+    # test_3xtf32()

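To run just the tests touched by this commit, one could adapt the commented-out pytest.main line above, e.g. (the -k expression here is illustrative, not from the commit):

import pytest

pytest.main([__file__, "-k", "test_dense or test_conv2d_int8 or test_conv2d_3xtf32"])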