Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 12, 2021
1 parent 81bf9e6 commit c08bb38
Showing 1 changed file with 53 additions and 61 deletions.
114 changes: 53 additions & 61 deletions tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,29 +114,33 @@ def get_conv2d_nchw(d_shape, w_shape, padding, out_dtype="float16"):
data = relay.var("data", shape=d_shape, dtype="float16")
weight = relay.var("weight", shape=w_shape, dtype="float16")
out_channel = w_shape[0]
return tvm.IRModule.from_expr(
relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
out_dtype=out_dtype,
)
return relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
out_dtype=out_dtype,
)


def get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype="float16"):
conv2d = get_conv2d_nchw(d_shape, w_shape, padding, out_dtype=out_dtype)
bias = relay.var("bias", shape=(w_shape[0],), dtype=out_dtype)
return relay.nn.bias_add(conv2d, bias)


def get_conv2d_nchw_bias_relu(d_shape, w_shape, padding, out_dtype="float16"):
return relay.nn.relu(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))


def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(
mod, sm, profile_all=True, use_multiprocessing=True, tmp_dir=tmp_dir
mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir
)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target="cuda", params=params)
lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path)
dev = tvm.device("cuda", 0)
rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
return rt_mod, dev, num_cutlass_partition
print(mod)
return


def profile_and_build_vm(
Expand Down Expand Up @@ -314,8 +318,8 @@ def convert_conv2d_layout(mod, desired_layouts):


def verify_conv2d(
mod_nchw, # can be dynamic batch
mod_ref, # always static batch
expr_nchw, # can be dynamic batch
expr_ref, # always static batch
d_shape,
w_shape,
sm=80,
Expand All @@ -324,59 +328,38 @@ def verify_conv2d(
use_cudnn_ref=False,
run_benchmark=False,
):
if not has_cutlass():
return
# if not has_cutlass():
# return

mod_nchw = tvm.IRModule.from_expr(expr_nchw)
mod_ref = tvm.IRModule.from_expr(expr_ref)

typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
out_dtype = typ.dtype

np_data = np.random.uniform(-1, 1, d_shape).astype("float16")
np_weight = np.random.uniform(-1, 1, w_shape).astype("float16")
np_bias = np.random.uniform(-1, 1, (w_shape[0],)).astype(out_dtype)

params = {"weight": np_weight}
params = {"weight": np_weight, "bias": np_bias}

typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)

mod_weight_ohwi = convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]})

if use_vm:
rt_mod, _, num_cutlass_partition = profile_and_build_vm(mod_weight_ohwi, params, sm)
out = get_output_vm(rt_mod, ["data"], [np_data])
else:
rt_mod, _, num_cutlass_partition = profile_and_build(
mod_weight_ohwi,
params,
sm,
)
out = get_output(rt_mod, ["data"], [np_data])

assert num_cutlass_partition > 0

if use_cudnn_ref:
rt_mod_ref, dev = get_ref_rt_mod(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "OHWI"]}),
params,
target="cuda -libs=cudnn",
)
else:
rt_mod_ref, dev = get_ref_rt_mod(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
params,
target="cuda",
)

ref_out = get_output(rt_mod_ref, ["data"], [np_data])

if run_benchmark:
print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))

np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
profile_and_build(
mod_weight_ohwi,
params,
sm,
)


def test_conv2d():
padding = (1, 1)
for IC in [3, 16]:
d_shape = (16, IC, 32, 32)
w_shape = (32, IC, 3, 3)
mod_nchw = get_conv2d_nchw(d_shape, w_shape)
mod_nchw = get_conv2d_nchw(d_shape, w_shape, padding)

verify_conv2d(
mod_nchw,
Expand All @@ -390,19 +373,28 @@ def test_conv2d():
run_benchmark=False,
)

return
d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
dyn_batch_shape = (relay.Any(),) + d_shape[1:]

mod_nchw = get_conv2d_nchw(d_shape, w_shape)
mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape)
mod_nchw = get_conv2d_nchw(d_shape, w_shape, padding)
mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape, padding)

verify_conv2d(
mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)


def test_conv2d_bias():
d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
padding = (1, 1)
mod_nchw = get_conv2d_nchw_bias(d_shape, w_shape, padding)

verify_conv2d(
mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)


if __name__ == "__main__":
# pytest.main([__file__])
test_conv2d()
test_conv2d_bias()

0 comments on commit c08bb38

Please sign in to comment.