diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 3a7bdf43fd810..ef55c74dc3a5e 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -518,7 +518,7 @@ def verify_conv2d_common( ): if not has_cutlass(): return - if sm < 80 and data_dtype == "float32": + if sm < 80 and inputs[0].dtype == "float32": return mod_nchw = tvm.IRModule.from_expr(expr_nchw)