Skip to content

Commit

Permalink
[Unity][MSC] Avoid depending on trivial bindings in Relax intermediate (
Browse files Browse the repository at this point in the history
#16349)

* [Unity][MSC] Avoid depending on trivial bindings in Relax intermediate

The conversion from tensorflow to MSC is done by first converting from
tensorflow to relay, then converting from relay to executable python
code, executing that python code to generate relax, and finally
converting from relax to MSC.  During the relax phase of this
conversion, some relax `IRModule` are applied, including
`FuseOpsByPattern`.

The test cases in `test_msc/test_translate_tensorflow.py` rely on
`FuseOpsByPattern` preserving trivial bindings (e.g. `var_1 = var_2`)
in the relax IRModule.  If these trivial bindings are removed by
`CanonicalizeBindings`, then the test cases in this file fail.  The
presence or absence of trivial bindings `FuseOpsByPattern` should be
considered an implementation detail, and relax passes should not be
required to preserve trivial bindings.

This PR updates the relay to executable python step of the tensorflow
to MSC conversion, to remove trivial bindings and output a variable
name that matches the expected value in the test case.  While not an
ideal resolution, as other variable name changes could still
reintroduce the same test failures, it ensures that `FuseOpsByPattern`
may canonicalize bindings as an internal pre- or post-processing step
without breaking these unit tests.

* Update implementation to remove dataflow block in MSC codegen

The potential for duplicate variable names was introduced by having
the `block_builder.emit_output` call, which is only required to export
values from a dataflow block.  The dataflow block is not used in any
later MSC conversion, and its removal avoids this re-export of
variables.

If the dataflow block is required in the future, it can be generated
using `tvm.relax.transform.ConvertToDataflowBlock`.

* Make failing test cases be close to the same structural form

* Updated tests to validate output after compilation

* Lint fixes
  • Loading branch information
Lunderberg authored Jan 12, 2024
1 parent d1b890a commit b69d720
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 128 deletions.
16 changes: 15 additions & 1 deletion python/tvm/contrib/msc/core/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,18 @@ def relay_to_relax(
def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule:
return BindParams("main", weights)(mod)

return codegen.load(inputs, post_load=_bind_weights)
mod = codegen.load(inputs, post_load=_bind_weights)

mod = tvm.ir.transform.Sequential(
[
# The canonicalization of relax variable bindings is not required
# for correctness. It does, however, remove trivial `x = y`
# bindings, preventing test cases from depending on their
# presence.
tvm.relax.transform.CanonicalizeBindings(),
tvm.relax.transform.ConvertToDataflow(min_size=1),
],
name="tvm.contrib.msc.core.codegen.relay_to_relax_postproc",
)(mod)

return mod
16 changes: 15 additions & 1 deletion python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,18 @@ def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRMo
return mod

codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder)
return codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights)
mod = codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights)

mod = tvm.ir.transform.Sequential(
[
# The canonicalization of relax variable bindings is not required
# for correctness. It does, however, remove trivial `x = y`
# bindings, preventing test cases from depending on their
# presence.
tvm.relax.transform.CanonicalizeBindings(),
tvm.relax.transform.ConvertToDataflow(min_size=1),
],
name="tvm.contrib.msc.framework.tvm.codegen.to_relax_postproc",
)(mod)

return mod
11 changes: 3 additions & 8 deletions src/contrib/msc/framework/tvm/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ void RelaxCodeGen::CodeGenGraph() {
continue;
}
int scope_level = CompareScope(node);
if (scope_level == 1) {
stack_.scope_start("block_builder.dataflow()");
} else if (scope_level == -1) {
if (scope_level == -1) {
stack_.scope_end();
}
CodeGenNode(node, config()->use_tools);
Expand All @@ -83,13 +81,11 @@ void RelaxCodeGen::CodeGenGraph() {
for (size_t i = 0; i < scopes().size() - 1; i++) {
stack_.scope_end();
}
} else if (scopes().size() == 0) {
// start dataflow scope for non-scope graph
stack_.scope_start("block_builder.dataflow()");
}
// mark outputs
stack_.comment("Emit the outputs");
Array<String> idx_exits;

for (const auto& e : graph()->GetExits()) {
const auto& idx_exit = IdxNodeBase(e) + (config()->use_tools ? "_exit" : "");
if (config()->use_tools) {
Expand All @@ -104,10 +100,9 @@ void RelaxCodeGen::CodeGenGraph() {
stack_.call_arg(DocUtils::ToStr(e->name + "_exit"), "name_hint");
}
}
stack_.func_call("emit_output", idx_exit, "block_builder").call_arg(idx_exit);
idx_exits.push_back(idx_exit);
}
stack_.scope_end();

if (config()->use_tools) {
stack_.func_call("msc_tools.execute_step", "output").call_arg(DocUtils::ToStr("after_build"));
if (idx_exits.size() == 1) {
Expand Down
3 changes: 0 additions & 3 deletions src/contrib/msc/framework/tvm/relax_opcode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ const Array<Doc> RelaxOpCode::GetDocs() {
if (node()->optype == "input" || node()->optype == "constant" || node()->optype == "shape") {
emit_var = false;
}
if (node()->optype == "tuple" && node()->children.size() == 0) {
emit_var = false;
}
if (emit_var) {
const auto& name = config()->explicit_name ? node()->name : "";
BuilderEmit(IdxNode(), name);
Expand Down
Loading

0 comments on commit b69d720

Please sign in to comment.