Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUBLAS] Add cuBLAS as a Relay partitioning target (BYOC) #10820

Merged
merged 2 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions python/tvm/relay/op/contrib/cublas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""cuBLAS Relay integration."""
from typing import Callable, List, Tuple, Dict, Optional

import tvm
import tvm.ir
from tvm import relay
from tvm import te
from tvm.relay import transform
from tvm.contrib import cublas

from ...dataflow_pattern import is_op, wildcard
from .register import register_pattern_table


def partition_for_cublas(
mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
) -> tvm.IRModule:
"""Partition the graph to offload for cuBLAS.

Parameters
----------
mod : tvm.IRModule
The module to partition.
params : Optional[Dict[str, tvm.runtime.NDArray]]
Constant input parameters.

Returns
-------
tvm.IRModule
The partitioned module.
"""

seq = tvm.transform.Sequential(
[
transform.InferType(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("cublas"),
transform.PartitionGraph(),
transform.InferType(),
]
)
return seq(mod)


@register_pattern_table("cublas")
def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], bool]]]:
"""Get the cuBLAS pattern table."""

def matmul_pattern() -> relay.Pattern:
"""Create pattern for matrix multiply."""
return is_op("nn.matmul")(wildcard(), wildcard())

def check_matmul(matched: relay.Call) -> bool:
"""Check if matmul is supported by cuBLAS."""
# Units not supported
if matched.attrs["units"] is not None:
return False
# Input data types can't be mixed
if matched.args[0].checked_type.dtype != matched.args[1].checked_type.dtype:
return False
in_dtype = matched.args[0].checked_type.dtype
out_dtype = matched.checked_type.dtype
# Only the following data type combinations are supported
if (in_dtype, out_dtype) not in [
("float32", "float32"),
("float16", "float16"),
("float16", "float32"),
("int8", "int32"),
("float64", "float64"),
("int8", "float32"),
]:
return False
# If inputs are int8, input column strides must be a multiple of 4
if in_dtype == "int8":
if (
matched.args[0].checked_type.shape[1] % 4 != 0
or matched.args[1].checked_type.shape[1] % 4 != 0
):
return False

return True

return [
("cublas.matmul", matmul_pattern(), check_matmul),
]


_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor]
_LOWER_MAP: Dict[str, _LowerFunc] = {}


def _lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]:
"""Register a lowering function for a given composite function name."""

def _register(f: _LowerFunc) -> _LowerFunc:
_LOWER_MAP[comp_name] = f
return f

return _register


@tvm._ffi.register_func("relay.ext.cublas")
def relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
"""Compile cuBLAS Relay functions to a runtime module."""
assert isinstance(partition, relay.Function)
assert isinstance(partition.body, relay.Call)
assert isinstance(partition.body.op, relay.Function)

global_name = str(partition.attrs.global_symbol)
target = tvm.target.cuda()
comp_func = partition.body.op
comp_name = comp_func.attrs["Composite"]
assert comp_name in _LOWER_MAP
assert isinstance(comp_func.body, relay.Call)

op = comp_func.body
inputs = []
for i, param in enumerate(comp_func.params):
inputs.append(
te.placeholder(
param.checked_type.shape,
name=f"input_{i}",
dtype=param.checked_type.dtype,
)
)

output = _LOWER_MAP[comp_name](op, inputs)
prim_func = te.create_prim_func(inputs + [output])
return tvm.build(prim_func, target=target, name=global_name)


@_lower_composite("cublas.matmul")
def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a matmul using cuBLAS."""
return cublas.matmul(
inputs[0],
inputs[1],
transa=op.attrs["transpose_a"],
transb=op.attrs["transpose_b"],
dtype=op.checked_type.dtype,
)
90 changes: 87 additions & 3 deletions tests/python/contrib/test_cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

import tvm
from tvm import te
from tvm import relay
import numpy as np
from tvm.contrib import cublas
from tvm.contrib import cublaslt
from tvm.contrib import graph_executor
import tvm.testing
from tvm.relay.op.contrib import get_pattern_table
from tvm.relay.op.contrib.cublas import partition_for_cublas


def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
Expand Down Expand Up @@ -170,7 +176,85 @@ def test_batch_matmul():
verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "int8", "int32")


def _verify_cublas_relay(expr):
np.random.seed(42)

mod = tvm.IRModule.from_expr(expr)
mod = relay.transform.InferType()(mod)
func = mod["main"]
cublas_mod = partition_for_cublas(mod)
assert len(cublas_mod.get_global_vars()) == 2

input_data = []
for param in func.params:
shape = [int(x) for x in param.checked_type.shape]
input_data.append(
(param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype))
)

# Test against CPU reference
cuda_config = (tvm.target.cuda(), tvm.cuda(), cublas_mod)
cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod)
outputs = []
for target, dev, test_mod in [cuda_config, cpu_config]:
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(test_mod, target=target, target_host=cpu_config[0])
module = graph_executor.GraphModule(lib["default"](dev))
for name, data in input_data:
module.set_input(name, tvm.nd.array(data, dev))

module.run()
out_type = func.body.checked_type
outputs.append(
module.get_output(0, tvm.nd.empty(out_type.shape, dtype=out_type.dtype)).numpy()
)

tvm.testing.assert_allclose(
outputs[0],
outputs[1],
rtol=1e-2,
)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
"n,m,k,transpose_a,transpose_b",
[
(64, 128, 32, False, False),
(17, 32, 16, True, False),
(24, 17, 12, False, True),
(96, 4, 17, True, True),
],
)
@pytest.mark.parametrize(
"in_dtype,out_dtype",
[
("float32", "float32"),
("float16", "float16"),
("float16", "float32"),
("int8", "int32"),
("float64", "float64"),
("int8", "float32"),
],
)
def test_relay_cublas_matmul(n, m, k, in_dtype, out_dtype, transpose_a, transpose_b):
unsupported_configs = [
(17, 32, 16, "int8", "float32", True, False),
(96, 4, 17, "int8", "float32", True, True),
(17, 32, 16, "int8", "int32", True, False),
(96, 4, 17, "int8", "int32", True, True),
]
if (n, m, k, in_dtype, out_dtype, transpose_a, transpose_b) in unsupported_configs:
pytest.skip("Unsupported parameters.")

a_shape = (k, n) if transpose_a else (n, k)
b_shape = (m, k) if transpose_b else (k, m)
a = tvm.relay.var("A", tvm.relay.TensorType(a_shape, in_dtype))
b = tvm.relay.var("B", tvm.relay.TensorType(b_shape, in_dtype))
# Directly use matmul because nn.matmul sometimes defers to nn.dense
matmul = relay.op.nn._make.matmul(a, b, None, out_dtype, transpose_a, transpose_b)
_verify_cublas_relay(matmul)


if __name__ == "__main__":
test_matmul_add()
test_batch_matmul()
test_matmul_add_igemm()
pytest.main([__file__])
3 changes: 3 additions & 0 deletions tests/scripts/task_mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ mypy --check-untyped-defs python/tvm/tir/transform/
echo "Checking MyPy Type defs in the TIR package with unittest"
MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py

echo "Checking MyPy Type defs in tvm.relay.op.contrib.cublas"
mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cublas.py

#TODO(@mikepapadim): This is failing atm
# echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."
# mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/