Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 13, 2021
1 parent ab01b3a commit 77c9385
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
10 changes: 8 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,17 +695,23 @@ def LambdaLift():
return _ffi_api.LambdaLift()


def PartitionGraph(mod_name="default", bind_constants=True):
    """Partition a Relay program into regions that can be executed on different
    backends.

    Parameters
    ----------
    mod_name : str
        Name of the module the partitioned functions belong to. It is run
        through ``mangle_module_name`` before being handed to the underlying
        pass, so partitioned symbols from different modules do not collide.

    bind_constants : bool
        Whether or not to bind constants in partitioned subgraphs. For C-source based codegen,
        it is recommended to set this to False to avoid embedding large constants in
        a C source file.

    Returns
    -------
    ret: tvm.transform.Pass
        The registered pass that partitions the Relay program.
    """
    mod_name = mangle_module_name(mod_name)
    return _ffi_api.PartitionGraph(mod_name, bind_constants)


def AnnotateTarget(targets, include_non_call_ops=True):
Expand Down
50 changes: 49 additions & 1 deletion tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,54 @@ def run(dtype, shape):
run("float32", [2, 3])


def test_not_bind_constant():
    """Check that PartitionGraph(bind_constants=...) controls whether constants
    are embedded in the partitioned function or left as call arguments."""

    def get_net(prefix, data, out_channel):
        # conv2d -> batch_norm -> relu; the BN is simplified and folded away by
        # the passes below, leaving folded constants for the partitioner.
        weight = relay.var(prefix + "weight")
        bn_gamma = relay.var(prefix + "bn_gamma")
        bn_beta = relay.var(prefix + "bn_beta")
        bn_mmean = relay.var(prefix + "bn_mean")
        bn_mvar = relay.var(prefix + "bn_var")

        layer = relay.nn.conv2d(
            data=data, weight=weight, kernel_size=(3, 3), channels=out_channel, padding=(1, 1)
        )
        bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
        out = relay.nn.relu(bn_output[0])
        return relay.Function(relay.analysis.free_vars(out), out)

    def get_partitioned_mod(mod, params, pattern_table, bind_constants):
        # Bind the workload params, simplify/fold BN, then merge, annotate for
        # the "dnnl" codegen and partition with the given bind_constants flag.
        mod["main"] = bind_params_by_name(mod["main"], params)
        remove_bn_pass = tvm.transform.Sequential(
            [
                transform.InferType(),
                transform.SimplifyInference(),
                transform.FoldConstant(),
                transform.FoldScaleAxis(),
            ]
        )
        composite_partition = tvm.transform.Sequential(
            [
                remove_bn_pass,
                transform.MergeComposite(pattern_table),
                transform.AnnotateTarget("dnnl"),
                transform.PartitionGraph(bind_constants=bind_constants),
            ]
        )

        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
            return composite_partition(mod)

    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
    net = get_net("block_", data, 8)

    # bind_constants=True: constants are embedded in the partitioned function,
    # so the call in main only takes the data tensor.
    mod, params = tvm.relay.testing.create_workload(net)
    mod = get_partitioned_mod(mod, params, get_pattern_table("dnnl"), bind_constants=True)
    assert len(mod["main"].body.args) == 1

    # bind_constants=False: constants remain as explicit arguments of the
    # partitioned call (presumably data + folded weight + bias — 3 in total).
    # Partition a FRESH workload here: re-partitioning the already-partitioned
    # module above would not exercise the flag on the original graph.
    mod, params = tvm.relay.testing.create_workload(net)
    mod = get_partitioned_mod(mod, params, get_pattern_table("dnnl"), bind_constants=False)
    assert len(mod["main"].body.args) == 3


if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
Expand All @@ -1492,4 +1540,4 @@ def run(dtype, shape):
test_flatten_tuple_output()
test_tuple_output_exec()
test_extern_opt()
test_static_tensor_array_gather_partition()
test_not_bind_constant()

0 comments on commit 77c9385

Please sign in to comment.