Reviewer notes (free-form preamble; patch tools skip text before the first "diff --git" line):
- Restored the one-record-per-line unified diff layout; indentation is reconstructed from
  syntax and should be confirmed against the target tree before applying.
- tests/python/contrib/test_cublas.py: dropped the second, identical definition of
  test_relay_cublas_dense (it only rebound the same name); hunk header is now +256,91.
- python/tvm/relay/op/contrib/cublas.py: the check_matmul_like docstring now names all three
  ops the predicate guards.
- The "index" post-image hashes are carried over from the original patch and are stale for
  the two edited Python files.

diff --git a/python/tvm/relay/op/contrib/cublas.py b/python/tvm/relay/op/contrib/cublas.py
index 09505cdaa8d11..a93169c2d84e5 100644
--- a/python/tvm/relay/op/contrib/cublas.py
+++ b/python/tvm/relay/op/contrib/cublas.py
@@ -64,17 +64,23 @@ def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], boo
     """Get the cuBLAS pattern table."""
 
     def matmul_pattern() -> relay.Pattern:
-        """Create pattern for matrix multiply."""
+        """Create pattern for matmul."""
         return is_op("nn.matmul")(wildcard(), wildcard())
 
-    def check_matmul(matched: relay.Call) -> bool:
-        """Check if matmul is supported by cuBLAS."""
-        # Units not supported
-        if matched.attrs["units"] is not None:
-            return False
+    def batch_matmul_pattern() -> relay.Pattern:
+        """Create pattern for batch_matmul."""
+        return is_op("nn.batch_matmul")(wildcard(), wildcard())
+
+    def dense_pattern() -> relay.Pattern:
+        """Create pattern for dense."""
+        return is_op("nn.dense")(wildcard(), wildcard())
+
+    def check_matmul_like(matched: relay.Call) -> bool:
+        """Check if a matmul, batch_matmul or dense op is supported by cuBLAS."""
         # Input data types can't be mixed
         if matched.args[0].checked_type.dtype != matched.args[1].checked_type.dtype:
             return False
+
         in_dtype = matched.args[0].checked_type.dtype
         out_dtype = matched.checked_type.dtype
         # Only the following data type combinations are supported
@@ -87,18 +93,21 @@ def check_matmul(matched: relay.Call) -> bool:
             ("int8", "float32"),
         ]:
             return False
+
         # If inputs are int8, input column strides must be a multiple of 4
         if in_dtype == "int8":
             if (
-                matched.args[0].checked_type.shape[1] % 4 != 0
-                or matched.args[1].checked_type.shape[1] % 4 != 0
+                matched.args[0].checked_type.shape[-1] % 4 != 0
+                or matched.args[1].checked_type.shape[-1] % 4 != 0
             ):
                 return False
 
         return True
 
     return [
-        ("cublas.matmul", matmul_pattern(), check_matmul),
+        ("cublas.matmul", matmul_pattern(), check_matmul_like),
+        ("cublas.batch_matmul", batch_matmul_pattern(), check_matmul_like),
+        ("cublas.dense", dense_pattern(), check_matmul_like),
     ]
 
 
@@ -156,3 +165,23 @@ def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
         transb=op.attrs["transpose_b"],
         dtype=op.checked_type.dtype,
     )
+
+
+@_lower_composite("cublas.batch_matmul")
+def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a batch_matmul using cuBLAS."""
+    return cublas.batch_matmul(
+        inputs[0],
+        inputs[1],
+        transa=op.attrs["transpose_a"],
+        transb=op.attrs["transpose_b"],
+        dtype=op.checked_type.dtype,
+    )
+
+
+@_lower_composite("cublas.dense")
+def _lower_dense(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a dense using cuBLAS."""
+    return cublas.matmul(
+        inputs[0], inputs[1], transa=False, transb=True, dtype=op.checked_type.dtype
+    )
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index b13f9e858d663..ee0f50e3495b6 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -277,23 +277,23 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl)
   ICHECK_EQ(C->ndim, 3);
 
   int batch_size = BatchCount3D(C);
-  ICHECK_EQ(ElementStride(A), 1);
-  ICHECK_EQ(ElementStride(B), 1);
-  ICHECK_EQ(ElementStride(C), 1);
+  ICHECK_EQ(ElementStride3D(A), 1);
+  ICHECK_EQ(ElementStride3D(B), 1);
+  ICHECK_EQ(ElementStride3D(C), 1);
 
   ICHECK(TypeEqual(A->dtype, B->dtype));
 
   // C can never be transposed.
-  ICHECK(!IsInPlaceTransposed(C));
+  ICHECK(!IsInPlaceTransposed3D(C));
 
   // Reversed strides indicates an in-place transpose operation.
-  transa = IsInPlaceTransposed(A) ? !transa : transa;
-  transb = IsInPlaceTransposed(B) ? !transb : transb;
+  transa = IsInPlaceTransposed3D(A) ? !transa : transa;
+  transb = IsInPlaceTransposed3D(B) ? !transb : transb;
 
   ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type";
-  ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0)
+  ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0)
       << "leading dimension must divide 4 for int8 gemm";
-  ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0)
+  ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0)
       << "leading dimension must divide 4 for int8 gemm";
   double alpha = args.size() > 5 ? args[5] : 1.0;
   double beta = args.size() > 6 ? args[6] : 0.0;
diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py
index 64d954e50cfeb..0ae1e8e9ad5b0 100644
--- a/tests/python/contrib/test_cublas.py
+++ b/tests/python/contrib/test_cublas.py
@@ -256,5 +256,91 @@ def test_relay_cublas_matmul(n, m, k, in_dtype, out_dtype, transpose_a, transpos
     _verify_cublas_relay(matmul)
 
 
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "n,m,k",
+    [
+        (64, 128, 32),
+        (17, 32, 16),
+        (24, 17, 12),
+        (96, 4, 17),
+    ],
+)
+@pytest.mark.parametrize(
+    "in_dtype,out_dtype",
+    [
+        ("float32", "float32"),
+        ("float16", "float16"),
+        ("float16", "float32"),
+        ("int8", "int32"),
+        ("float64", "float64"),
+        ("int8", "float32"),
+    ],
+)
+def test_relay_cublas_dense(n, m, k, in_dtype, out_dtype):
+    unsupported_configs = [
+        (96, 4, 17, "int8", "float32"),
+        (96, 4, 17, "int8", "int32"),
+    ]
+    if (n, m, k, in_dtype, out_dtype) in unsupported_configs:
+        pytest.skip("Unsupported parameters.")
+
+    data = tvm.relay.var("data", tvm.relay.TensorType((n, k), in_dtype))
+    weight = tvm.relay.var("weight", tvm.relay.TensorType((m, k), in_dtype))
+    dense = relay.op.nn.dense(data, weight, out_dtype=out_dtype)
+    _verify_cublas_relay(dense)
+
+
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "n,m,k,batch_a,batch_b,transpose_a,transpose_b",
+    [
+        (64, 128, 32, 16, 16, False, False),
+        (17, 32, 16, 16, 1, True, False),
+        (24, 17, 12, 17, 17, False, True),
+        (96, 4, 17, 53, 1, True, True),
+    ],
+)
+@pytest.mark.parametrize(
+    "in_dtype,out_dtype",
+    [
+        ("float32", "float32"),
+        ("float16", "float16"),
+        ("float16", "float32"),
+        ("int8", "int32"),
+        ("float64", "float64"),
+        ("int8", "float32"),
+    ],
+)
+def test_relay_cublas_batch_matmul(
+    n, m, k, batch_a, batch_b, in_dtype, out_dtype, transpose_a, transpose_b
+):
+    unsupported_configs = [
+        (17, 32, 16, 16, 1, "int8", "float32", True, False),
+        (96, 4, 17, 53, 1, "int8", "float32", True, True),
+        (17, 32, 16, 16, 1, "int8", "int32", True, False),
+        (96, 4, 17, 53, 1, "int8", "int32", True, True),
+    ]
+    if (
+        n,
+        m,
+        k,
+        batch_a,
+        batch_b,
+        in_dtype,
+        out_dtype,
+        transpose_a,
+        transpose_b,
+    ) in unsupported_configs:
+        pytest.skip("Unsupported parameters.")
+
+    a_shape = (batch_a, k, n) if transpose_a else (batch_a, n, k)
+    b_shape = (batch_b, m, k) if transpose_b else (batch_b, k, m)
+    a = tvm.relay.var("A", tvm.relay.TensorType(a_shape, in_dtype))
+    b = tvm.relay.var("B", tvm.relay.TensorType(b_shape, in_dtype))
+    batch_matmul = relay.op.nn.batch_matmul(a, b, out_dtype, transpose_a, transpose_b)
+    _verify_cublas_relay(batch_matmul)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])