Skip to content

Commit

Permalink
[CUBLAS] Add support for nn.dense and nn.batch_matmul (apache#10826)
Browse files Browse the repository at this point in the history
* [CUBLAS] Add support for nn.dense and nn.batch_matmul

This commit includes a fix for cublas.batch_matmul
when mixed precision is being used.

* Specify args in dense
  • Loading branch information
mbaret authored and pfk-beta committed Apr 11, 2022
1 parent f6dc7ff commit bb43e96
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 16 deletions.
45 changes: 37 additions & 8 deletions python/tvm/relay/op/contrib/cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,23 @@ def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], boo
"""Get the cuBLAS pattern table."""

def matmul_pattern() -> relay.Pattern:
"""Create pattern for matrix multiply."""
"""Create pattern for matmul."""
return is_op("nn.matmul")(wildcard(), wildcard())

def check_matmul(matched: relay.Call) -> bool:
def batch_matmul_pattern() -> relay.Pattern:
"""Create pattern for batch_matmul."""
return is_op("nn.batch_matmul")(wildcard(), wildcard())

def dense_pattern() -> relay.Pattern:
"""Create pattern for dense."""
return is_op("nn.dense")(wildcard(), wildcard())

def check_matmul_like(matched: relay.Call) -> bool:
"""Check if matmul is supported by cuBLAS."""
# Units not supported
if matched.attrs["units"] is not None:
return False
# Input data types can't be mixed
if matched.args[0].checked_type.dtype != matched.args[1].checked_type.dtype:
return False

in_dtype = matched.args[0].checked_type.dtype
out_dtype = matched.checked_type.dtype
# Only the following data type combinations are supported
Expand All @@ -87,18 +93,21 @@ def check_matmul(matched: relay.Call) -> bool:
("int8", "float32"),
]:
return False

# If inputs are int8, input column strides must be a multiple of 4
if in_dtype == "int8":
if (
matched.args[0].checked_type.shape[1] % 4 != 0
or matched.args[1].checked_type.shape[1] % 4 != 0
matched.args[0].checked_type.shape[-1] % 4 != 0
or matched.args[1].checked_type.shape[-1] % 4 != 0
):
return False

return True

return [
("cublas.matmul", matmul_pattern(), check_matmul),
("cublas.matmul", matmul_pattern(), check_matmul_like),
("cublas.batch_matmul", batch_matmul_pattern(), check_matmul_like),
("cublas.dense", dense_pattern(), check_matmul_like),
]


Expand Down Expand Up @@ -156,3 +165,23 @@ def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
transb=op.attrs["transpose_b"],
dtype=op.checked_type.dtype,
)


@_lower_composite("cublas.batch_matmul")
def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
    """Lower a composite batch_matmul call onto ``cublas.batch_matmul``.

    Parameters
    ----------
    op : relay.Call
        The call being lowered. Its ``transpose_a`` / ``transpose_b``
        attributes and the dtype of its checked type are forwarded.
    inputs : List[te.Tensor]
        Operand tensors; the first two entries are used as the left and
        right operands respectively.

    Returns
    -------
    te.Tensor
        Whatever ``cublas.batch_matmul`` returns for these operands.
    """
    lhs, rhs = inputs[0], inputs[1]
    attrs = op.attrs
    # The output dtype is taken from the call's checked type, not from the
    # operands.
    result_dtype = op.checked_type.dtype
    return cublas.batch_matmul(
        lhs,
        rhs,
        transa=attrs["transpose_a"],
        transb=attrs["transpose_b"],
        dtype=result_dtype,
    )


@_lower_composite("cublas.dense")
def _lower_dense(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
    """Lower a composite dense call onto ``cublas.matmul``.

    Parameters
    ----------
    op : relay.Call
        The call being lowered; only the dtype of its checked type is read.
    inputs : List[te.Tensor]
        Operand tensors; the first two entries are used as the left and
        right operands respectively.

    Returns
    -------
    te.Tensor
        Whatever ``cublas.matmul`` returns for these operands.
    """
    lhs, rhs = inputs[0], inputs[1]
    result_dtype = op.checked_type.dtype
    # The transposition flags are fixed here rather than read from op.attrs:
    # the left operand is used as-is and the right operand is transposed.
    return cublas.matmul(
        lhs,
        rhs,
        transa=False,
        transb=True,
        dtype=result_dtype,
    )
16 changes: 8 additions & 8 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,23 +277,23 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl)
ICHECK_EQ(C->ndim, 3);

int batch_size = BatchCount3D(C);
ICHECK_EQ(ElementStride(A), 1);
ICHECK_EQ(ElementStride(B), 1);
ICHECK_EQ(ElementStride(C), 1);
ICHECK_EQ(ElementStride3D(A), 1);
ICHECK_EQ(ElementStride3D(B), 1);
ICHECK_EQ(ElementStride3D(C), 1);

ICHECK(TypeEqual(A->dtype, B->dtype));

// C can never be transposed.
ICHECK(!IsInPlaceTransposed(C));
ICHECK(!IsInPlaceTransposed3D(C));

// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;
transa = IsInPlaceTransposed3D(A) ? !transa : transa;
transb = IsInPlaceTransposed3D(B) ? !transb : transb;

ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type";
ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0)
ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0)
<< "leading dimension must divide 4 for int8 gemm";
ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0)
ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0)
<< "leading dimension must divide 4 for int8 gemm";
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
Expand Down
121 changes: 121 additions & 0 deletions tests/python/contrib/test_cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,5 +256,126 @@ def test_relay_cublas_matmul(n, m, k, in_dtype, out_dtype, transpose_a, transpos
_verify_cublas_relay(matmul)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
    "n,m,k",
    [
        (64, 128, 32),
        (17, 32, 16),
        (24, 17, 12),
        (96, 4, 17),
    ],
)
@pytest.mark.parametrize(
    "in_dtype,out_dtype",
    [
        ("float32", "float32"),
        ("float16", "float16"),
        ("float16", "float32"),
        ("int8", "int32"),
        ("float64", "float64"),
        ("int8", "float32"),
    ],
)
def test_relay_cublas_dense(n, m, k, in_dtype, out_dtype):
    """Build an (n, k) x (m, k) nn.dense and verify it through cuBLAS."""
    # NOTE(review): a function with this exact name and body is defined again
    # later in this module; the later definition rebinds the name, so only
    # one of the two is collected. Consider removing one copy.

    # The (96, 4, 17) shape is skipped for int8 inputs with either output
    # dtype. Presumably this is because k == 17 is not a multiple of 4 --
    # TODO confirm against the int8 check in the cuBLAS pattern table.
    is_unsupported = (
        (n, m, k) == (96, 4, 17)
        and in_dtype == "int8"
        and out_dtype in ("float32", "int32")
    )
    if is_unsupported:
        pytest.skip("Unsupported parameters.")

    data_type = tvm.relay.TensorType((n, k), in_dtype)
    weight_type = tvm.relay.TensorType((m, k), in_dtype)
    data = tvm.relay.var("data", data_type)
    weight = tvm.relay.var("weight", weight_type)
    _verify_cublas_relay(relay.op.nn.dense(data, weight, out_dtype=out_dtype))


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
    "n,m,k,batch_a,batch_b,transpose_a,transpose_b",
    [
        (64, 128, 32, 16, 16, False, False),
        (17, 32, 16, 16, 1, True, False),
        (24, 17, 12, 17, 17, False, True),
        (96, 4, 17, 53, 1, True, True),
    ],
)
@pytest.mark.parametrize(
    "in_dtype,out_dtype",
    [
        ("float32", "float32"),
        ("float16", "float16"),
        ("float16", "float32"),
        ("int8", "int32"),
        ("float64", "float64"),
        ("int8", "float32"),
    ],
)
def test_relay_cublas_batch_matmul(
    n, m, k, batch_a, batch_b, in_dtype, out_dtype, transpose_a, transpose_b
):
    """Build an nn.batch_matmul with optional transposes and verify it."""
    # Two shape/layout combinations are skipped for int8 inputs, with either
    # output dtype. Presumably each leaves an operand whose innermost
    # dimension (17) is not a multiple of 4 -- TODO confirm against the int8
    # check in the cuBLAS pattern table.
    int8_unsupported_layouts = (
        (17, 32, 16, 16, 1, True, False),
        (96, 4, 17, 53, 1, True, True),
    )
    layout = (n, m, k, batch_a, batch_b, transpose_a, transpose_b)
    if (
        in_dtype == "int8"
        and out_dtype in ("float32", "int32")
        and layout in int8_unsupported_layouts
    ):
        pytest.skip("Unsupported parameters.")

    # A transposed operand is declared with its two trailing dimensions
    # swapped.
    if transpose_a:
        a_shape = (batch_a, k, n)
    else:
        a_shape = (batch_a, n, k)
    if transpose_b:
        b_shape = (batch_b, m, k)
    else:
        b_shape = (batch_b, k, m)

    a = tvm.relay.var("A", tvm.relay.TensorType(a_shape, in_dtype))
    b = tvm.relay.var("B", tvm.relay.TensorType(b_shape, in_dtype))
    _verify_cublas_relay(
        relay.op.nn.batch_matmul(a, b, out_dtype, transpose_a, transpose_b)
    )


# NOTE(review): this is a second definition of test_relay_cublas_dense with
# the same parametrization and body as the one earlier in this module. Being
# the later binding, it replaces the earlier one, so the test is collected
# once. One of the two copies should be removed.
@tvm.testing.requires_cuda
@pytest.mark.parametrize(
    "n,m,k",
    [
        (64, 128, 32),
        (17, 32, 16),
        (24, 17, 12),
        (96, 4, 17),
    ],
)
@pytest.mark.parametrize(
    "in_dtype,out_dtype",
    [
        ("float32", "float32"),
        ("float16", "float16"),
        ("float16", "float32"),
        ("int8", "int32"),
        ("float64", "float64"),
        ("int8", "float32"),
    ],
)
def test_relay_cublas_dense(n, m, k, in_dtype, out_dtype):
    """Build an (n, k) x (m, k) nn.dense and verify it through cuBLAS."""
    # The (96, 4, 17) shape is skipped for int8 inputs with either output
    # dtype. Presumably this is because k == 17 is not a multiple of 4 --
    # TODO confirm against the int8 check in the cuBLAS pattern table.
    is_unsupported = (
        (n, m, k) == (96, 4, 17)
        and in_dtype == "int8"
        and out_dtype in ("float32", "int32")
    )
    if is_unsupported:
        pytest.skip("Unsupported parameters.")

    data_type = tvm.relay.TensorType((n, k), in_dtype)
    weight_type = tvm.relay.TensorType((m, k), in_dtype)
    data = tvm.relay.var("data", data_type)
    weight = tvm.relay.var("weight", weight_type)
    _verify_cublas_relay(relay.op.nn.dense(data, weight, out_dtype=out_dtype))


if __name__ == "__main__":
    # Running this module directly hands this file to pytest.
    pytest.main([__file__])

0 comments on commit bb43e96

Please sign in to comment.