Skip to content

Commit

Permalink
Call topi and external library through emit_te and add MLP example (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
YuchenJin committed Jan 26, 2022
1 parent a061c1b commit d168475
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 61 deletions.
59 changes: 59 additions & 0 deletions apps/relax_examples/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Example code on creating, compiling, and running an MLP model in relax


import tvm
from tvm.relay import Call
from tvm import relax, tir, topi
import numpy as np


def build_mlp(data, weight):
bb = relax.BlockBuilder()

with bb.function([data, weight], "mlp"):
gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)
gv1 = bb.emit_te(topi.nn.relu, gv0)
bb.emit_func_output(gv1)

mod = bb.get()
return mod


if __name__ == "__main__":
# symbolic dimensions
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
# create data and weight variables
data = relax.Var("data", [n, m], relax.DynTensorType(2, "float32"))
weight = relax.Var("weight", [m, n], relax.DynTensorType(2, "float32"))

# construct a mlp model
mod = build_mlp(data, weight)

# build and create vm executor
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

# run the mlp model on relax vm
data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
res = vm["mlp"](data, weight)
print(res)
32 changes: 12 additions & 20 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def __exit__(self, ptype, value, trace):
block = _ffi_api.BlockBuilderEndBlock(self._ib)
if len(block.bindings) > 0:
self._ib._blocks.append(block)
seqe = rx.SeqExpr(self._ib._blocks, self._ib._func_ret)
func = rx.Function(
self._ib._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._ib._func_name)
)
gvar = rx.GlobalVar(self._ib._func_name)
self._ib._context_mod[gvar] = func
return func


class DataflowScope(object):
Expand Down Expand Up @@ -82,7 +89,7 @@ class BlockBuilder(Object):
lv1 = ib.emit(rx.multiply(lv0, y))
gv0 = ib.emit_output(lv1)
ib.emit_func_output(gv0)
func = ib.get()
mod = ib.get()
"""

def __init__(self):
Expand Down Expand Up @@ -356,27 +363,12 @@ def normalize(self, expr: Expr) -> Expr:
"""
return _ffi_api.BlockBuilderNormalize(self, expr)

def get(self) -> Function:
"""Return the function being built.
Returns
-------
ret : tvm.relax.Function
A Relax function node being built.
"""
# TODO(hyoercubestart, ziheng) get should return IRModule with relax + TIR functions
seqe = rx.SeqExpr(self._blocks, self._func_ret)
func = rx.Function(
self._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._func_name)
)
return func

def context_mod(self):
"""Return the context module that might contain tir functions.
def get(self) -> tvm.IRModule:
"""Return the IRModule being built.
Returns
-------
mod : tvm.IRModule
The context module that contains tir functions during emit.
ret : tvm.IRModule
An IRModule with Relax and TIR functions being built.
"""
return self._context_mod
10 changes: 9 additions & 1 deletion src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,11 @@ BindingBlock BlockBuilderNode::EndBlock() {
return ret;
}

Optional<RelayExpr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
Optional<Expr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
// if the call node's shape_ is filled, return the shape directly.
if (call->shape_) {
return Downcast<Expr>(call->shape_.value());
}
auto op_map = Op::GetAttrMap<FInferShape>("FInferShape");
if (call->op.as<OpNode>()) {
Op op = Downcast<Op>(call->op);
Expand All @@ -309,6 +313,10 @@ Optional<RelayExpr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
}

Type InferType(const Call& call, DiagnosticContext diag_ctx) {
// if the call node's checked_type_ is filled, return the type directly.
if (call->checked_type_.defined()) {
return call->checked_type_;
}
auto op_map = Op::GetAttrMap<FInferType>("FInferType");
if (call->op.as<OpNode>()) {
Op op = Downcast<Op>(call->op);
Expand Down
5 changes: 4 additions & 1 deletion src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ RELAY_REGISTER_OP("relax.call_dps")

Expr MakeCallDPS(Expr shape, Expr func, Tuple args) {
static const Op& op = Op::Get("relax.call_dps");
return Call(op, {shape, func, args}, {}, {});
Call call = Call(op, {shape, func, args}, {}, {});
call->shape_ = shape;
call->checked_type_ = args->fields[0]->checked_type_;
return call;
}

TVM_REGISTER_GLOBAL("relax.op.call_dps")
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relax/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def test_post_order_visit():
x = rx.Var("x", [m, n], dtype0)
y = rx.Var("y", [n], dtype1)
ib = rx.BlockBuilder()
with ib.function([x, y]):
with ib.function([x, y], "func"):
with ib.dataflow() as df:
lv0 = ib.emit(rx.op.add(x, y))
lv1 = ib.emit(rx.op.multiply(lv0, y))
gv0 = ib.emit_output(lv1)
ib.emit_func_output(gv0)
expr = ib.get()
expr = ib.get()["func"]

names = []

Expand Down
111 changes: 85 additions & 26 deletions tests/python/relax/test_blockbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from tvm import tir, te
from tvm import relay
from tvm import relax as rx
import numpy as np

from tvm.ir.base import assert_structural_equal
from tvm.relax import op
Expand Down Expand Up @@ -61,7 +60,7 @@ def test_function_single_block():
y = rx.Var("y", [n], dtype1)
ib = rx.BlockBuilder()

with ib.function([x, y]):
with ib.function([x, y], "func"):
with ib.dataflow() as df:
lv0 = ib.emit(rx.op.add(x, y))
assert lv0.name_hint == "lv"
Expand All @@ -71,7 +70,7 @@ def test_function_single_block():
assert gv0.name_hint == "gv"
ib.emit_func_output(gv0)

func = ib.get()
func = ib.get()["func"]
assert func.params[0] == x
assert func.params[1] == y
assert func.body.body == gv0
Expand Down Expand Up @@ -106,7 +105,7 @@ def test_function_multi_blocks():
gv2 = ib.emit_output(gv1)
ib.emit_func_output(gv2)

func = ib.get()
func = ib.get()["func"]
assert gv2.shape[0] == m
assert gv2.shape[1] == n
assert gv2.checked_type.rank == 2
Expand All @@ -121,6 +120,40 @@ def test_function_multi_blocks():
assert len(func.body.blocks[2].bindings) == 2


def test_multi_functions():
m = tir.Var("m", "int32")
n = tir.Var("n", "int32")
dtype0 = rx.DynTensorType(rank=2, dtype="float16")
dtype1 = rx.DynTensorType(rank=1, dtype="float16")
x = rx.Var("x", [m, n], dtype0)
y = rx.Var("y", [n], dtype1)
ib = rx.BlockBuilder()

with ib.function([x, y], "func1"):
with ib.dataflow() as df:
lv0 = ib.emit(rx.op.add(x, y))
assert lv0.name_hint == "lv"
gv0 = ib.emit_output(lv0)
ib.emit_func_output(gv0)

with ib.function([x, y], "func2"):
with ib.dataflow() as df:
lv0 = ib.emit(rx.op.add(x, y))
assert lv0.name_hint == "lv"
gv0 = ib.emit_output(lv0)
ib.emit_func_output(gv0)

mod = ib.get()
func1 = mod["func1"]
assert func1.params[0] == x
assert func1.params[1] == y
assert func1.name.name_hint == "func1"
func2 = mod["func2"]
assert func2.params[0] == x
assert func2.params[1] == y
assert func2.name.name_hint == "func2"


def test_binary_shape_type_deduction():
m = tir.Var("m", "int32")
n = tir.Var("n", "int32")
Expand Down Expand Up @@ -177,7 +210,7 @@ def test_emit_match_shape():
y = rx.Var("shape_value", type_annotation=rx.ShapeType(), shape_annotation=shape_anno)
ib = rx.BlockBuilder()

with ib.function([x, y]):
with ib.function([x, y], "func"):
with ib.dataflow() as df:
# lv0: Tensor[(m, n), "float32"] =
# match_shape(x: Tensor[_, "float32"], [m, n])
Expand All @@ -194,7 +227,7 @@ def test_emit_match_shape():
gv0 = ib.emit_output(lv1)

ib.emit_func_output(gv0)
func = ib.get()
func = ib.get()["func"]
block = func.body.blocks[0]
b0, b1 = block.bindings[:2]
assert isinstance(b0, rx.MatchShape)
Expand Down Expand Up @@ -248,11 +281,8 @@ def te_func(args, args_dict, msg):
out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello")
bb.emit_func_output(out)

func = bb.get()
mod = bb.context_mod()

gvar = tvm.relay.GlobalVar("rx_func")
mod[gvar] = func
mod = bb.get()
rx_func = mod["rx_func"]

def get_tir_func():
A = te.placeholder((n, m), dtype="float32", name="A")
Expand All @@ -265,20 +295,20 @@ def get_tir_func():
assert_structural_equal(mod["te_func"].body, get_tir_func().body)

# check Relax function calls TIR function with call_dps call
assert func.params[0] == x
assert func.params[1] == y
assert func.params[2] == z
assert func.name.name_hint == "rx_func"
assert func.body.body == out
assert len(func.body.blocks) == 1
assert len(func.body.blocks[0].bindings) == 1
assert isinstance(func.body.blocks[0].bindings[0].value, rx.Call)
assert func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
assert len(func.body.blocks[0].bindings[0].value.args) == 3
assert func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
assert func.body.blocks[0].bindings[0].value.args[2][0] == x
assert func.body.blocks[0].bindings[0].value.args[2][1] == y
assert func.body.blocks[0].bindings[0].value.args[2][2] == z
assert rx_func.params[0] == x
assert rx_func.params[1] == y
assert rx_func.params[2] == z
assert rx_func.name.name_hint == "rx_func"
assert rx_func.body.body == out
assert len(rx_func.body.blocks) == 1
assert len(rx_func.body.blocks[0].bindings) == 1
assert isinstance(rx_func.body.blocks[0].bindings[0].value, rx.Call)
assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
assert len(rx_func.body.blocks[0].bindings[0].value.args) == 3
assert rx_func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
assert rx_func.body.blocks[0].bindings[0].value.args[2][0] == x
assert rx_func.body.blocks[0].bindings[0].value.args[2][1] == y
assert rx_func.body.blocks[0].bindings[0].value.args[2][2] == z


def test_emit_te_multiple():
Expand All @@ -297,16 +327,45 @@ def te_func(A):
y1 = bb.emit_te(te_func, y)
bb.emit_func_output(y1)

func = bb.get()
func = bb.get()["rx_func"]
assert func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
assert func.body.blocks[0].bindings[1].value.args[1].name_hint == "te_func1"


def test_emit_te_extern():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
type_anno = rx.DynTensorType(2, "float32")
x = rx.Var("x", [n, m], type_anno)
y = rx.Var("y", [m, n], type_anno)

with bb.function([x, y], "rx_cblas_matmul"):
out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False)
bb.emit_func_output(out)

mod = bb.get()
rx_func = mod["rx_cblas_matmul"]

# check Relax function calls TIR function with call_dps call
assert rx_func.params[0] == x
assert rx_func.params[1] == y
assert len(rx_func.body.blocks) == 1
assert isinstance(rx_func.body.blocks[0].bindings[0].value, rx.Call)
assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
assert len(rx_func.body.blocks[0].bindings[0].value.args) == 3
assert rx_func.body.blocks[0].bindings[0].value.args[1].name_hint == "matmul"
assert rx_func.body.blocks[0].bindings[0].value.args[2][0] == x
assert rx_func.body.blocks[0].bindings[0].value.args[2][1] == y


if __name__ == "__main__":
test_block_builder()
test_function_single_block()
test_function_multi_blocks()
test_multi_functions()
test_binary_shape_type_deduction()
test_emit_match_shape()
test_normalize()
test_emit_te()
test_emit_te_multiple()
test_emit_te_extern()
Loading

0 comments on commit d168475

Please sign in to comment.