Skip to content

Commit

Permalink
int8 test worked
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 11, 2022
1 parent e9b0287 commit ad3faf6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ namespace contrib {
using namespace backend;
using Str2StrMap = std::unordered_map<std::string, std::string>;

static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, {"float32", "float"}};
static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"},
{"float32", "float"},
{"int8", "int8_t"},
{"uint8", "uint8_t"},
{"int32", "int32_t"}};

constexpr const char* kAnyDim = "Any";

Expand Down
7 changes: 5 additions & 2 deletions tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def verify_conv2d(
use_fast_math=False,
data_dtype="float16",
weight_dtype="float16",
ref_target="cuda",
):
if not has_cutlass():
return
Expand Down Expand Up @@ -436,7 +437,7 @@ def verify_conv2d(
rt_mod_ref, dev = get_ref_rt_mod(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
params,
target="cuda",
target=ref_target,
)

ref_out = get_output(rt_mod_ref, ["data"], [np_data])
Expand Down Expand Up @@ -575,6 +576,7 @@ def test_int8():
run_benchmark=False,
data_dtype="int8",
weight_dtype="int8",
ref_target="llvm",
)


Expand Down Expand Up @@ -610,4 +612,5 @@ def test_3xtf32():


if __name__ == "__main__":
pytest.main([__file__])
# pytest.main([__file__])
test_int8()

0 comments on commit ad3faf6

Please sign in to comment.