diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index cd9c7d68366d1..e074ddc6aa61b 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -323,6 +323,12 @@ class RelayBuildModule : public runtime::ModuleNode { /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); transform::PassContext pass_ctx = PassContext::Current(); + if (config_->optional_homogeneous_target.defined()) { + // This pass currently only supports the homogeneous case. + pass_seqs.push_back(transform::SplitArgs( + config_->optional_homogeneous_target->GetAttr("max_function_args", -1).value())); + } + // Always plan devices so the remaining passes don't need to distinguish homogeneous vs // hetrogenous execution. pass_seqs.push_back(transform::PlanDevices(config_)); diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index e2ce442e0d885..502e75f938b38 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -412,5 +412,4 @@ def test_export_byoc_c_module(): if __name__ == "__main__": import sys - # sys.exit(pytest.main([__file__] + sys.argv[1:])) - test_export_operator_model_library_format() + sys.exit(pytest.main([__file__] + sys.argv[1:]))