diff --git a/apps/relax_examples/mlp.py b/apps/relax_examples/mlp.py
new file mode 100644
index 000000000000..b89fa8f723ea
--- /dev/null
+++ b/apps/relax_examples/mlp.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Example code on creating, compiling, and running an MLP model in relax
+
+
+import tvm
+from tvm import relax, tir, topi
+import numpy as np
+
+
+def build_mlp(data, weight):
+    bb = relax.BlockBuilder()
+
+    with bb.function([data, weight], "mlp"):
+        gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)
+        gv1 = bb.emit_te(topi.nn.relu, gv0)
+        bb.emit_func_output(gv1)
+
+    mod = bb.get()
+    return mod
+
+
+if __name__ == "__main__":
+    # symbolic dimensions
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    # create data and weight variables
+    data = relax.Var("data", [n, m], relax.DynTensorType(2, "float32"))
+    weight = relax.Var("weight", [m, n], relax.DynTensorType(2, "float32"))
+
+    # construct an mlp model
+    mod = build_mlp(data, weight)
+
+    # build and create vm executor
+    target = tvm.target.Target("llvm")
+    target_host = tvm.target.Target("llvm")
+    ex, lib = relax.vm.build(mod, target, target_host)
+    vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
+
+    # run the mlp model on relax vm
+    data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
+    weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
+    res = vm["mlp"](data, weight)
+    print(res)
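The example above prints an MLP result computed from random inputs but never verifies it. A minimal sanity check one could append after the print, assuming `res` is a `tvm.runtime.NDArray` (so `asnumpy()` is available) and using `np.maximum(..., 0.0)` as the NumPy stand-in for the relu:

    # reference result computed in NumPy: relu(data @ weight)
    expected = np.maximum(np.dot(data.asnumpy(), weight.asnumpy()), 0.0)
    # loose tolerances, matching the numerical checks in test_vm.py below
    np.testing.assert_allclose(res.asnumpy(), expected, rtol=1e-4, atol=1e-4)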
- """ - # TODO(hyoercubestart, ziheng) get should return IRModule with relax + TIR functions - seqe = rx.SeqExpr(self._blocks, self._func_ret) - func = rx.Function( - self._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._func_name) - ) - return func - - def context_mod(self): - """Return the context module that might contain tir functions. + def get(self) -> tvm.IRModule: + """Return the IRModule being built. Returns ------- - mod : tvm.IRModule - The context module that contains tir functions during emit. + ret : tvm.IRModule + An IRModule with Relax and TIR functions being built. """ return self._context_mod diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index aa4ca0b7c509..9d6096bfa925 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -297,7 +297,11 @@ BindingBlock BlockBuilderNode::EndBlock() { return ret; } -Optional InferShape(const Call& call, DiagnosticContext diag_ctx) { +Optional InferShape(const Call& call, DiagnosticContext diag_ctx) { + // if the call node's shape_ is filled, return the shape directly. + if (call->shape_) { + return Downcast(call->shape_.value()); + } auto op_map = Op::GetAttrMap("FInferShape"); if (call->op.as()) { Op op = Downcast(call->op); @@ -309,6 +313,10 @@ Optional InferShape(const Call& call, DiagnosticContext diag_ctx) { } Type InferType(const Call& call, DiagnosticContext diag_ctx) { + // if the call node's checked_type_ is filled, return the type directly. + if (call->checked_type_.defined()) { + return call->checked_type_; + } auto op_map = Op::GetAttrMap("FInferType"); if (call->op.as()) { Op op = Downcast(call->op); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index e27626e27f8d..6cfb0f1d54f3 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -59,7 +59,10 @@ RELAY_REGISTER_OP("relax.call_dps") Expr MakeCallDPS(Expr shape, Expr func, Tuple args) { static const Op& op = Op::Get("relax.call_dps"); - return Call(op, {shape, func, args}, {}, {}); + Call call = Call(op, {shape, func, args}, {}, {}); + call->shape_ = shape; + call->checked_type_ = args->fields[0]->checked_type_; + return call; } TVM_REGISTER_GLOBAL("relax.op.call_dps") diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index e8434408c57c..6cb17c185fcd 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -47,13 +47,13 @@ def test_post_order_visit(): x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [n], dtype1) ib = rx.BlockBuilder() - with ib.function([x, y]): + with ib.function([x, y], "func"): with ib.dataflow() as df: lv0 = ib.emit(rx.op.add(x, y)) lv1 = ib.emit(rx.op.multiply(lv0, y)) gv0 = ib.emit_output(lv1) ib.emit_func_output(gv0) - expr = ib.get() + expr = ib.get()["func"] names = [] diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 497ae31ae75e..7beca6e58468 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -20,7 +20,6 @@ from tvm import tir, te from tvm import relay from tvm import relax as rx -import numpy as np from tvm.ir.base import assert_structural_equal from tvm.relax import op @@ -61,7 +60,7 @@ def test_function_single_block(): y = rx.Var("y", [n], dtype1) ib = rx.BlockBuilder() - with ib.function([x, y]): + with ib.function([x, y], "func"): with ib.dataflow() as df: lv0 = ib.emit(rx.op.add(x, y)) assert lv0.name_hint == "lv" @@ -71,7 +70,7 @@ def test_function_single_block(): assert 
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index aa4ca0b7c509..9d6096bfa925 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -297,7 +297,11 @@ BindingBlock BlockBuilderNode::EndBlock() {
   return ret;
 }
 
 Optional<Expr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
+  // if the call node's shape_ is filled, return the shape directly.
+  if (call->shape_) {
+    return Downcast<Expr>(call->shape_.value());
+  }
   auto op_map = Op::GetAttrMap<FInferShape>("FInferShape");
   if (call->op.as<OpNode>()) {
     Op op = Downcast<Op>(call->op);
@@ -309,6 +313,10 @@ Optional<Expr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
 }
 
 Type InferType(const Call& call, DiagnosticContext diag_ctx) {
+  // if the call node's checked_type_ is filled, return the type directly.
+  if (call->checked_type_.defined()) {
+    return call->checked_type_;
+  }
   auto op_map = Op::GetAttrMap<FInferType>("FInferType");
   if (call->op.as<OpNode>()) {
     Op op = Downcast<Op>(call->op);
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index e27626e27f8d..6cfb0f1d54f3 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -59,7 +59,10 @@ RELAY_REGISTER_OP("relax.call_dps")
 
 Expr MakeCallDPS(Expr shape, Expr func, Tuple args) {
   static const Op& op = Op::Get("relax.call_dps");
-  return Call(op, {shape, func, args}, {}, {});
+  Call call = Call(op, {shape, func, args}, {}, {});
+  call->shape_ = shape;
+  call->checked_type_ = args->fields[0]->checked_type_;
+  return call;
 }
 
 TVM_REGISTER_GLOBAL("relax.op.call_dps")
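Taken together, the two C++ changes mean a `relax.call_dps` node is constructed with `shape_` and `checked_type_` already filled in from its arguments, and `InferShape`/`InferType` return those cached values instead of re-running the registered `FInferShape`/`FInferType` functions. A sketch of how this surfaces in Python via `emit_te`, reusing the assertions from the tests below (imports as in those test files; the `checked_type.rank` accessor is taken from the tests, and the exact binding behavior is an assumption):

    bb = rx.BlockBuilder()
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    x = rx.Var("x", [n, m], rx.DynTensorType(2, "float32"))

    with bb.function([x], "f"):
        out = bb.emit_te(topi.nn.relu, x)  # lowers to a relax.call_dps binding
        bb.emit_func_output(out)

    call = bb.get()["f"].body.blocks[0].bindings[0].value
    assert call.op == relay.op.get("relax.call_dps")
    # filled by MakeCallDPS at construction time, not recomputed:
    assert call.checked_type.rank == 2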
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index e8434408c57c..6cb17c185fcd 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -47,13 +47,13 @@ def test_post_order_visit():
     x = rx.Var("x", [m, n], dtype0)
     y = rx.Var("y", [n], dtype1)
     ib = rx.BlockBuilder()
 
-    with ib.function([x, y]):
+    with ib.function([x, y], "func"):
        with ib.dataflow() as df:
            lv0 = ib.emit(rx.op.add(x, y))
            lv1 = ib.emit(rx.op.multiply(lv0, y))
            gv0 = ib.emit_output(lv1)
        ib.emit_func_output(gv0)
-    expr = ib.get()
+    expr = ib.get()["func"]
 
     names = []
diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py
index 497ae31ae75e..7beca6e58468 100644
--- a/tests/python/relax/test_blockbuilder.py
+++ b/tests/python/relax/test_blockbuilder.py
@@ -20,7 +20,6 @@
 from tvm import tir, te
 from tvm import relay
 from tvm import relax as rx
-import numpy as np
 from tvm.ir.base import assert_structural_equal
 from tvm.relax import op
 
@@ -61,7 +60,7 @@ def test_function_single_block():
     y = rx.Var("y", [n], dtype1)
     ib = rx.BlockBuilder()
 
-    with ib.function([x, y]):
+    with ib.function([x, y], "func"):
         with ib.dataflow() as df:
             lv0 = ib.emit(rx.op.add(x, y))
             assert lv0.name_hint == "lv"
@@ -71,7 +70,7 @@ def test_function_single_block():
             assert gv0.name_hint == "gv"
         ib.emit_func_output(gv0)
 
-    func = ib.get()
+    func = ib.get()["func"]
     assert func.params[0] == x
     assert func.params[1] == y
     assert func.body.body == gv0
@@ -106,7 +105,7 @@ def test_function_multi_blocks():
             gv2 = ib.emit_output(gv1)
         ib.emit_func_output(gv2)
 
-    func = ib.get()
+    func = ib.get()["func"]
     assert gv2.shape[0] == m
     assert gv2.shape[1] == n
     assert gv2.checked_type.rank == 2
@@ -121,6 +120,40 @@ def test_function_multi_blocks():
     assert len(func.body.blocks[2].bindings) == 2
 
 
+def test_multi_functions():
+    m = tir.Var("m", "int32")
+    n = tir.Var("n", "int32")
+    dtype0 = rx.DynTensorType(rank=2, dtype="float16")
+    dtype1 = rx.DynTensorType(rank=1, dtype="float16")
+    x = rx.Var("x", [m, n], dtype0)
+    y = rx.Var("y", [n], dtype1)
+    ib = rx.BlockBuilder()
+
+    with ib.function([x, y], "func1"):
+        with ib.dataflow() as df:
+            lv0 = ib.emit(rx.op.add(x, y))
+            assert lv0.name_hint == "lv"
+            gv0 = ib.emit_output(lv0)
+        ib.emit_func_output(gv0)
+
+    with ib.function([x, y], "func2"):
+        with ib.dataflow() as df:
+            lv0 = ib.emit(rx.op.add(x, y))
+            assert lv0.name_hint == "lv"
+            gv0 = ib.emit_output(lv0)
+        ib.emit_func_output(gv0)
+
+    mod = ib.get()
+    func1 = mod["func1"]
+    assert func1.params[0] == x
+    assert func1.params[1] == y
+    assert func1.name.name_hint == "func1"
+    func2 = mod["func2"]
+    assert func2.params[0] == x
+    assert func2.params[1] == y
+    assert func2.name.name_hint == "func2"
+
+
 def test_binary_shape_type_deduction():
     m = tir.Var("m", "int32")
     n = tir.Var("n", "int32")
@@ -177,7 +210,7 @@ def test_emit_match_shape():
     y = rx.Var("shape_value", type_annotation=rx.ShapeType(), shape_annotation=shape_anno)
     ib = rx.BlockBuilder()
 
-    with ib.function([x, y]):
+    with ib.function([x, y], "func"):
         with ib.dataflow() as df:
             # lv0: Tensor[(m, n), "float32"] =
             #   match_shape(x: Tensor[_, "float32"], [m, n])
@@ -194,7 +227,7 @@ def test_emit_match_shape():
             gv0 = ib.emit_output(lv1)
         ib.emit_func_output(gv0)
 
-    func = ib.get()
+    func = ib.get()["func"]
     block = func.body.blocks[0]
     b0, b1 = block.bindings[:2]
     assert isinstance(b0, rx.MatchShape)
@@ -248,11 +281,8 @@ def te_func(args, args_dict, msg):
         out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello")
         bb.emit_func_output(out)
 
-    func = bb.get()
-    mod = bb.context_mod()
-
-    gvar = tvm.relay.GlobalVar("rx_func")
-    mod[gvar] = func
+    mod = bb.get()
+    rx_func = mod["rx_func"]
 
     def get_tir_func():
         A = te.placeholder((n, m), dtype="float32", name="A")
@@ -265,20 +295,20 @@ def get_tir_func():
     assert_structural_equal(mod["te_func"].body, get_tir_func().body)
 
     # check Relax function calls TIR function with call_dps call
-    assert func.params[0] == x
-    assert func.params[1] == y
-    assert func.params[2] == z
-    assert func.name.name_hint == "rx_func"
-    assert func.body.body == out
-    assert len(func.body.blocks) == 1
-    assert len(func.body.blocks[0].bindings) == 1
-    assert isinstance(func.body.blocks[0].bindings[0].value, rx.Call)
-    assert func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
-    assert len(func.body.blocks[0].bindings[0].value.args) == 3
-    assert func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
-    assert func.body.blocks[0].bindings[0].value.args[2][0] == x
-    assert func.body.blocks[0].bindings[0].value.args[2][1] == y
-    assert func.body.blocks[0].bindings[0].value.args[2][2] == z
+    assert rx_func.params[0] == x
+    assert rx_func.params[1] == y
+    assert rx_func.params[2] == z
+    assert rx_func.name.name_hint == "rx_func"
+    assert rx_func.body.body == out
+    assert len(rx_func.body.blocks) == 1
+    assert len(rx_func.body.blocks[0].bindings) == 1
+    assert isinstance(rx_func.body.blocks[0].bindings[0].value, rx.Call)
+    assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
+    assert len(rx_func.body.blocks[0].bindings[0].value.args) == 3
+    assert rx_func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
+    assert rx_func.body.blocks[0].bindings[0].value.args[2][0] == x
+    assert rx_func.body.blocks[0].bindings[0].value.args[2][1] == y
+    assert rx_func.body.blocks[0].bindings[0].value.args[2][2] == z
 
 
 def test_emit_te_multiple():
@@ -297,16 +327,45 @@ def te_func(A):
         y1 = bb.emit_te(te_func, y)
         bb.emit_func_output(y1)
 
-    func = bb.get()
+    func = bb.get()["rx_func"]
     assert func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
     assert func.body.blocks[0].bindings[1].value.args[1].name_hint == "te_func1"
 
+
+def test_emit_te_extern():
+    bb = rx.BlockBuilder()
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    type_anno = rx.DynTensorType(2, "float32")
+    x = rx.Var("x", [n, m], type_anno)
+    y = rx.Var("y", [m, n], type_anno)
+
+    with bb.function([x, y], "rx_cblas_matmul"):
+        out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False)
+        bb.emit_func_output(out)
+
+    mod = bb.get()
+    rx_func = mod["rx_cblas_matmul"]
+
+    # check Relax function calls TIR function with call_dps call
+    assert rx_func.params[0] == x
+    assert rx_func.params[1] == y
+    assert len(rx_func.body.blocks) == 1
+    assert isinstance(rx_func.body.blocks[0].bindings[0].value, rx.Call)
+    assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
+    assert len(rx_func.body.blocks[0].bindings[0].value.args) == 3
+    assert rx_func.body.blocks[0].bindings[0].value.args[1].name_hint == "matmul"
+    assert rx_func.body.blocks[0].bindings[0].value.args[2][0] == x
+    assert rx_func.body.blocks[0].bindings[0].value.args[2][1] == y
+
+
 if __name__ == "__main__":
     test_block_builder()
     test_function_single_block()
     test_function_multi_blocks()
+    test_multi_functions()
     test_binary_shape_type_deduction()
     test_emit_match_shape()
     test_normalize()
     test_emit_te()
     test_emit_te_multiple()
+    test_emit_te_extern()
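One practical note on the two cblas-backed tests (`test_emit_te_extern` above and `test_vm_emit_te_extern` below): `tvm.contrib.cblas.matmul` dispatches at runtime to a packed function of the same name, which is only registered when TVM is built against a BLAS library (e.g. `set(USE_BLAS openblas)` in the cmake config). A hedged guard sketch for machines without it:

    import tvm

    # get_global_func returns None instead of raising when allow_missing is set,
    # so this probes whether the cblas extern was compiled into the runtime
    if tvm.get_global_func("tvm.contrib.cblas.matmul", allow_missing=True) is None:
        print("cblas not available; skipping the extern tests")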
"relax.ewise_fma" assert structural_equal(v1.shape, relax.ShapeExpr([m, n])) @@ -65,7 +65,7 @@ def test_fma_rewrite(): # The var binded to the fma call is reused because the shape # and type of var are unchanged after rewriting assert gv0 == v0 - assert type(func.body.blocks[0].bindings[1].var) == relax.Var + assert type(new_func.body.blocks[0].bindings[1].var) == relax.Var def test_visit_shape(): @tvm.script.ir_module diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index 76e0f838a0a4..6599bf6ee3dd 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -19,8 +19,9 @@ import numpy as np import tvm from tvm.relay import Call -from tvm import relax +from tvm import relax, tir from tvm.runtime import container +import numpy as np import tvm.script from tvm.script import tir as T, relax as R @@ -393,7 +394,6 @@ def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: ex, lib = relax.vm.build(mod, target, target_host) vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) - import numpy as np data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) res = vm["func"](data, weight) @@ -401,6 +401,31 @@ def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: np.testing.assert_allclose(expected, res.asnumpy(), rtol=1e-4, atol=1e-4) +def test_vm_emit_te_extern(): + bb = relax.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + type_anno = relax.DynTensorType(2, "float32") + x = relax.Var("x", [n, m], type_anno) + y = relax.Var("y", [m, n], type_anno) + + with bb.function([x, y], "rx_cblas_matmul"): + out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) + bb.emit_func_output(out) + + mod = bb.get() + + target = tvm.target.Target("llvm") + target_host = tvm.target.Target("llvm") + ex, lib = relax.vm.build(mod, target, target_host) + vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + + data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + res = vm["rx_cblas_matmul"](data, weight) + expected = np.dot(data.asnumpy(), weight.asnumpy()) + np.testing.assert_allclose(expected, res.asnumpy(), rtol=1e-4, atol=1e-4) + + if __name__ == "__main__": test_vm_execute() test_vm_multiple_func() @@ -417,3 +442,4 @@ def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: test_vm_compile_stage3() test_vm_compile_e2e() test_vm_compile_e2e_func_param_with_shape() + test_vm_emit_te_extern()