[CUBLAS] Add support for nn.dense and nn.batch_matmul #10826

Merged · merged 2 commits on Mar 31, 2022
Changes from 1 commit
43 changes: 35 additions & 8 deletions python/tvm/relay/op/contrib/cublas.py
@@ -64,17 +64,23 @@ def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], bool]]]:
"""Get the cuBLAS pattern table."""

def matmul_pattern() -> relay.Pattern:
"""Create pattern for matrix multiply."""
"""Create pattern for matmul."""
return is_op("nn.matmul")(wildcard(), wildcard())

def check_matmul(matched: relay.Call) -> bool:
def batch_matmul_pattern() -> relay.Pattern:
"""Create pattern for batch_matmul."""
return is_op("nn.batch_matmul")(wildcard(), wildcard())

def dense_pattern() -> relay.Pattern:
"""Create pattern for dense."""
return is_op("nn.dense")(wildcard(), wildcard())

def check_matmul_like(matched: relay.Call) -> bool:
"""Check if matmul is supported by cuBLAS."""
# Units not supported
if matched.attrs["units"] is not None:
return False
# Input data types can't be mixed
if matched.args[0].checked_type.dtype != matched.args[1].checked_type.dtype:
return False

in_dtype = matched.args[0].checked_type.dtype
out_dtype = matched.checked_type.dtype
# Only the following data type combinations are supported
@@ -87,18 +93,21 @@ def check_matmul(matched: relay.Call) -> bool:
("int8", "float32"),
]:
return False

# If inputs are int8, input column strides must be a multiple of 4
if in_dtype == "int8":
if (
matched.args[0].checked_type.shape[1] % 4 != 0
or matched.args[1].checked_type.shape[1] % 4 != 0
matched.args[0].checked_type.shape[-1] % 4 != 0
or matched.args[1].checked_type.shape[-1] % 4 != 0
):
return False

return True

return [
("cublas.matmul", matmul_pattern(), check_matmul),
("cublas.matmul", matmul_pattern(), check_matmul_like),
("cublas.batch_matmul", batch_matmul_pattern(), check_matmul_like),
("cublas.dense", dense_pattern(), check_matmul_like),
]
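
For context, the composite patterns registered above are consumed by Relay's BYOC partitioning passes. Below is a minimal sketch of how they might be exercised directly; `MergeComposite` and `InferType` are standard Relay passes, and the tiny module built here is illustrative only, not part of this change:

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.cublas import pattern_table

# Build a tiny module containing nn.dense so the new "cublas.dense" pattern can match.
data = relay.var("data", shape=(8, 16), dtype="float32")
weight = relay.var("weight", shape=(32, 16), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], relay.nn.dense(data, weight)))

# The check_* predicates read checked_type, so run type inference before merging.
mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern_table())(mod)
print(mod)  # nn.dense is now wrapped in a composite function tagged "cublas.dense"
```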


@@ -156,3 +165,21 @@ def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
transb=op.attrs["transpose_b"],
dtype=op.checked_type.dtype,
)


@_lower_composite("cublas.batch_matmul")
def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a batch_matmul using cuBLAS."""
return cublas.batch_matmul(
inputs[0],
inputs[1],
transa=op.attrs["transpose_a"],
transb=op.attrs["transpose_b"],
dtype=op.checked_type.dtype,
)
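
A quick NumPy-only sketch (shapes chosen arbitrarily) of the batch_matmul semantics this lowering forwards to cuBLAS: each batch computes `op_a(A[i]) @ op_b(B[i])`, where `op_*` transposes the last two axes when the corresponding flag is set, and a batch dimension of 1 broadcasts:

```python
import numpy as np

batch, n, m, k = 4, 8, 6, 5
A = np.random.rand(batch, k, n).astype("float32")  # layout for transpose_a=True
B = np.random.rand(1, m, k).astype("float32")       # layout for transpose_b=True; batch of 1 broadcasts
C = np.matmul(A.transpose(0, 2, 1), B.transpose(0, 2, 1))
assert C.shape == (batch, n, m)
```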


@_lower_composite("cublas.dense")
def _lower_dense(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a dense using cuBLAS."""
return cublas.matmul(inputs[0], inputs[1], False, True, dtype=op.checked_type.dtype)
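
The fixed `False, True` flags follow from nn.dense's layout: data is `(batch, in_units)` and weight is stored as `(units, in_units)`, so the op computes `Y = X @ W.T`. A minimal NumPy illustration with arbitrary shapes:

```python
import numpy as np

x = np.random.rand(8, 16).astype("float32")   # data: (batch, in_units)
w = np.random.rand(32, 16).astype("float32")  # weight: (units, in_units)
y = x @ w.T                                   # dense output: (batch, units)
assert y.shape == (8, 32)
```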
16 changes: 8 additions & 8 deletions src/runtime/contrib/cublas/cublas.cc
@@ -277,23 +277,23 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl)
ICHECK_EQ(C->ndim, 3);

int batch_size = BatchCount3D(C);
ICHECK_EQ(ElementStride(A), 1);
ICHECK_EQ(ElementStride(B), 1);
ICHECK_EQ(ElementStride(C), 1);
ICHECK_EQ(ElementStride3D(A), 1);
ICHECK_EQ(ElementStride3D(B), 1);
ICHECK_EQ(ElementStride3D(C), 1);

ICHECK(TypeEqual(A->dtype, B->dtype));

// C can never be transposed.
ICHECK(!IsInPlaceTransposed(C));
ICHECK(!IsInPlaceTransposed3D(C));

// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;
transa = IsInPlaceTransposed3D(A) ? !transa : transa;
transb = IsInPlaceTransposed3D(B) ? !transb : transb;

ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type";
ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0)
ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0)
<< "leading dimension must divide 4 for int8 gemm";
ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0)
ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0)
<< "leading dimension must divide 4 for int8 gemm";
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
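
These changes swap the 2D stride helpers for 3D variants so the checks inspect the last two axes of the batched tensors. The int8 restriction mirrors the Python predicate above: the innermost dimension of each input must be a multiple of 4 (cuBLAS's leading-dimension requirement for int8 GEMM), which is why configurations with an innermost dimension of 17 appear in the tests' unsupported lists. A small Python sketch of that rule follows; the helper name is illustrative, not part of the codebase:

```python
def int8_leading_dim_ok(a_shape, b_shape):
    """Hypothetical mirror of the shape[-1] % 4 check used by check_matmul_like."""
    return a_shape[-1] % 4 == 0 and b_shape[-1] % 4 == 0

assert int8_leading_dim_ok((16, 64, 32), (16, 128, 32))   # innermost dim 32 -> offloadable
assert not int8_leading_dim_ok((53, 17, 96), (1, 4, 17))  # innermost dim 17 -> rejected for int8
```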
121 changes: 121 additions & 0 deletions tests/python/contrib/test_cublas.py
@@ -256,5 +256,126 @@ def test_relay_cublas_matmul(n, m, k, in_dtype, out_dtype, transpose_a, transpose_b):
_verify_cublas_relay(matmul)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
"n,m,k",
[
(64, 128, 32),
(17, 32, 16),
(24, 17, 12),
(96, 4, 17),
],
)
@pytest.mark.parametrize(
"in_dtype,out_dtype",
[
("float32", "float32"),
("float16", "float16"),
("float16", "float32"),
("int8", "int32"),
("float64", "float64"),
("int8", "float32"),
],
)
def test_relay_cublas_dense(n, m, k, in_dtype, out_dtype):
unsupported_configs = [
(96, 4, 17, "int8", "float32"),
(96, 4, 17, "int8", "int32"),
]
if (n, m, k, in_dtype, out_dtype) in unsupported_configs:
pytest.skip("Unsupported parameters.")

data = tvm.relay.var("data", tvm.relay.TensorType((n, k), in_dtype))
weight = tvm.relay.var("weight", tvm.relay.TensorType((m, k), in_dtype))
dense = relay.op.nn.dense(data, weight, out_dtype=out_dtype)
_verify_cublas_relay(dense)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
"n,m,k,batch_a,batch_b,transpose_a,transpose_b",
[
(64, 128, 32, 16, 16, False, False),
(17, 32, 16, 16, 1, True, False),
(24, 17, 12, 17, 17, False, True),
(96, 4, 17, 53, 1, True, True),
],
)
@pytest.mark.parametrize(
"in_dtype,out_dtype",
[
("float32", "float32"),
("float16", "float16"),
("float16", "float32"),
("int8", "int32"),
("float64", "float64"),
("int8", "float32"),
],
)
def test_relay_cublas_batch_matmul(
n, m, k, batch_a, batch_b, in_dtype, out_dtype, transpose_a, transpose_b
):
unsupported_configs = [
(17, 32, 16, 16, 1, "int8", "float32", True, False),
(96, 4, 17, 53, 1, "int8", "float32", True, True),
(17, 32, 16, 16, 1, "int8", "int32", True, False),
(96, 4, 17, 53, 1, "int8", "int32", True, True),
]
if (
n,
m,
k,
batch_a,
batch_b,
in_dtype,
out_dtype,
transpose_a,
transpose_b,
) in unsupported_configs:
pytest.skip("Unsupported parameters.")

a_shape = (batch_a, k, n) if transpose_a else (batch_a, n, k)
b_shape = (batch_b, m, k) if transpose_b else (batch_b, k, m)
a = tvm.relay.var("A", tvm.relay.TensorType(a_shape, in_dtype))
b = tvm.relay.var("B", tvm.relay.TensorType(b_shape, in_dtype))
batch_matmul = relay.op.nn.batch_matmul(a, b, out_dtype, transpose_a, transpose_b)
_verify_cublas_relay(batch_matmul)


if __name__ == "__main__":
pytest.main([__file__])