Skip to content

Commit

Permalink
add mixed compilers to test
Browse files Browse the repository at this point in the history
Change-Id: I3ca48738e096bb0f4dc362f0e9550317fc0d5afd
  • Loading branch information
lhutton1 committed Mar 14, 2022
1 parent d491eab commit cf90859
Showing 1 changed file with 25 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,38 @@
Test the outline compiler functions pass.
"""

import pytest

pytest.importorskip("ethosu.vela")

import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu.codegen import OutlineCompilerFunctions


def test_outline_compiler_functions():
compiler_name = "my-compiler"
wrong_compiler_name = "wrong-compiler"

def before():
inp = relay.var("input")

# Inlined functions for "my-compiler"
x = relay.var("x", shape=(1, 2, 2, 4))
x = relay.reshape(x, newshape=(1, 4, 4))
x = relay.Function(relay.analysis.free_vars(x), x)
x = x.with_attr("Compiler", compiler_name)
x = x.with_attr("global_symbol", "ext_func")

# Inlined function for "wrong-compiler"
y = relay.var("y", shape=(1, 4, 4))
y = relay.reshape(y, newshape=(1, 16))
y = relay.Function(relay.analysis.free_vars(y), y)
y = y.with_attr("Compiler", wrong_compiler_name)
y = y.with_attr("global_symbol", "ext_func_2")

out = relay.Call(x, [inp])
out = relay.Call(y, [out])
out = relay.Function([inp], out)
return tvm.ir.IRModule.from_expr(out)

Expand All @@ -52,11 +66,21 @@ def expected():
x = x.with_attr("global_symbol", "ext_func")
mod["ext_func"] = x

y = relay.var("y", shape=(1, 4, 4))
y = relay.reshape(y, newshape=(1, 16))
y = relay.Function(relay.analysis.free_vars(y), y)
y = y.with_attr("Compiler", wrong_compiler_name)
y = y.with_attr("global_symbol", "ext_func_2")

out = relay.Call(mod.get_global_var("ext_func"), [inp])
out = relay.Call(y, [out])
mod["main"] = relay.Function([inp], out)
return mod

after = OutlineCompilerFunctions(compiler_name)(before())
exp = expected()
assert after["ext_func"]

global_vars = [str(gv) for gv in after.get_global_vars()]
assert "@ext_func" in global_vars
assert "@ext_func_2" not in global_vars
assert tvm.ir.structural_equal(after["ext_func"], exp["ext_func"])

0 comments on commit cf90859

Please sign in to comment.