From 77c9385501595f804bd33b436aed0cc192059a10 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Mon, 13 Dec 2021 11:58:18 +0900
Subject: [PATCH] add test

---
 python/tvm/relay/transform/transform.py       | 15 +++++-
 .../python/relay/test_pass_partition_graph.py | 51 +++++++++++++++++++
 2 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 3c178ad6eefb..427e4cab699e 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -695,9 +695,20 @@ def LambdaLift():
     return _ffi_api.LambdaLift()
 
 
-def PartitionGraph(mod_name="default"):
+def PartitionGraph(mod_name="default", bind_constants=True):
     """Partition a Relay program into regions that can be executed on different
     backends.
 
+    Parameters
+    ----------
+    mod_name : str
+        Name of the module. It is passed through ``mangle_module_name`` before
+        being handed to the underlying pass.
+
+    bind_constants : bool
+        Whether or not to bind constants in partitioned subgraphs. For C-source based codegen,
+        it is recommended to set this to False to avoid embedding large constants in
+        a C source file.
+
     Returns
     -------
@@ -705,7 +716,7 @@ def PartitionGraph(mod_name="default"):
         The registered pass that partitions the Relay program.
     """
     mod_name = mangle_module_name(mod_name)
-    return _ffi_api.PartitionGraph(mod_name)
+    return _ffi_api.PartitionGraph(mod_name, bind_constants)
 
 
 def AnnotateTarget(targets, include_non_call_ops=True):
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 7d79698acf12..80fb2e03af2c 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -1471,6 +1471,56 @@ def run(dtype, shape):
     run("float32", [2, 3])
 
 
+def test_not_bind_constant():
+    """Check the argument count of main's body for both bind_constants settings."""
+
+    def get_net(prefix, data, out_channel):
+        weight = relay.var(prefix + "weight")
+        bn_gamma = relay.var(prefix + "bn_gamma")
+        bn_beta = relay.var(prefix + "bn_beta")
+        bn_mmean = relay.var(prefix + "bn_mean")
+        bn_mvar = relay.var(prefix + "bn_var")
+
+        layer = relay.nn.conv2d(
+            data=data, weight=weight, kernel_size=(3, 3), channels=out_channel, padding=(1, 1)
+        )
+        bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        out = relay.nn.relu(bn_output[0])
+        return relay.Function(relay.analysis.free_vars(out), out)
+
+    def get_partitioned_mod(mod, params, pattern_table, bind_constants):
+        mod["main"] = bind_params_by_name(mod["main"], params)
+        remove_bn_pass = tvm.transform.Sequential(
+            [
+                transform.InferType(),
+                transform.SimplifyInference(),
+                transform.FoldConstant(),
+                transform.FoldScaleAxis(),
+            ]
+        )
+        composite_partition = tvm.transform.Sequential(
+            [
+                remove_bn_pass,
+                transform.MergeComposite(pattern_table),
+                transform.AnnotateTarget("dnnl"),
+                transform.PartitionGraph(bind_constants=bind_constants),
+            ]
+        )
+
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            return composite_partition(mod)
+
+    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+    net = get_net("block_", data, 8)
+
+    # A fresh workload is built for each setting so the second case does not
+    # start from the already-partitioned module produced by the first.
+    for bind_constants, expected_num_args in [(True, 1), (False, 3)]:
+        mod, params = tvm.relay.testing.create_workload(net)
+        partitioned = get_partitioned_mod(mod, params, get_pattern_table("dnnl"), bind_constants)
+        assert len(partitioned["main"].body.args) == expected_num_args
+
+
 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -1492,4 +1542,5 @@ def run(dtype, shape):
     test_flatten_tuple_output()
     test_tuple_output_exec()
     test_extern_opt()
     test_static_tensor_array_gather_partition()
+    test_not_bind_constant()