From 78d92c8cad87814aecb83a5cbc5634c39b3f7120 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Tue, 7 Jun 2022 16:12:34 -0700 Subject: [PATCH] [BYOC] RelayToTIR custom codegen passes can still depend on dynamic shape functions In #11474 I got ready to switch CUTLASS from function-at-a-time to IRModule-at-a-time compilation. However my approach didn't handle dynamic shape functions, so I adjust it here. The idea is still that such passes will leave behind calls to 'extern' functions. However, converting those calls to 'call_lowered' form in MarkCompilerFunctionsAsExtern is too soon since only the TECompiler knows how to capture all the attributes necessary to support dynamic shape functions. So stop doing that in MarkCompilerFunctionsAsExtern and instead support this case properly in the TECompiler. While there try to chip away at the chronic lack of structure in te_compiler.cc. Every little bit helps. Add a basic unit test. --- src/relay/backend/aot_executor_codegen.cc | 8 +- src/relay/backend/graph_executor_codegen.cc | 27 +- src/relay/backend/interpreter.cc | 3 +- src/relay/backend/te_compiler.cc | 329 ++++++++++++------ src/relay/backend/te_compiler.h | 32 +- src/relay/backend/vm/compiler.cc | 24 +- .../transforms/compiler_function_utils.cc | 51 --- .../transforms/compiler_function_utils.h | 11 +- .../relay/backend/test_pass_lower_te.py | 241 +++++++++++++ .../transform/test_compiler_function_utils.py | 5 +- 10 files changed, 503 insertions(+), 228 deletions(-) create mode 100644 tests/python/relay/backend/test_pass_lower_te.py diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 167afd2c5f78..381cfa0c9d1c 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1064,9 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { mod = transform::ToANormalForm()(mod); - IRModule lowered_mod = tec::LowerTEPass( - mod_name, - [this, workspace_byte_alignment](BaseFunc func) { + IRModule lowered_mod = + tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) { // We need to maintain the constant map for external // functions so we pass this processing function which // allows us to process each function as we lower it. @@ -1078,8 +1077,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // execute as a further pass, instead writing data to the // lowering process directly. tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); - }, - config_)(mod); + })(mod); auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 7dba23803f8c..af426e5c71cb 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -217,22 +217,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorGetAttr(attr::kCompiler).defined()) { - UpdateConstants(func, ¶ms_); - } + IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } - // TODO(@areusch, @jroesch): We should refactor this to - // execute as a further pass, instead writing data to the - // lowering process directly. - tec::UpdateFunctionMetadata(func, this->function_metadata_); - }, - config_)(mod); + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); Optional main_func_info = lowered_mod->GetAttr("main_func_info"); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 9661040eab30..65a0fdc94824 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -960,8 +960,7 @@ IRModule Prepare(IRModule mod, const CompilationConfig& config) { // eta expand to support constructors in argument position. transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), - transform::InferType(), - tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)}); + transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index c78f3abd6ecc..e9491b0a8901 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -17,6 +17,76 @@ * under the License. */ +/*! + * \file relay/backend/te_compiler.cc + * \brief Manages the transition from Relay "Primitive" \p Functions to TIR \p PrimFuncs. Also + * handles invocation of external codegen. + * + * \p LowerTEPass handles the following (as a monolithic blob of code): + * + * - Most importantly, any function with the "Primitive" attribute is first converted to TE by + * \p LowerToTECompute (see te_compiler_cache.cc) using each operator's 'compute' function. + * The TE is then 'scheduled' to TIR using the 'anchor' operator's 'schedule' function. Both + * of those functions come from the \p OpStrategy returned by the Python + * 'relay.backend.lower_call' function (see te_compiler.py). + * The TIR is packed as a \p PrimFunc and introduced as a new global function. Calls to the + * original "Primitive" function are then rewritten to the form: + * \code + * call_lowered(@new_global, (... original args...), attributes) + * \endcode + * + * - The above "Primitive" function can appear: + * - As a global function + * - As a let-bound function + * - As an inline function, ie the 'op' of calls. + * In all three cases it is possible for the same "Primitive" function to be called multiple + * times, and that sharing must be respected. + * + * - "Primitive" functions must have a "global_symbol" attribute matching their desired or + * existing global name. Care is taken to ensure GlobalVars with the same name are shared. + * + * - It is possible for multiple structurally equal "Primitive" functions to appear in the same + * \p IRModule. Only one implementation should be generated, and all calls should share that + * implementation. + * + * - When later converting to DPS (see memory_alloc.cc) we must handle functions who's result + * tensor shapes depend at runtime on the input tensor shapes and/or data. + * - That dependency is first described in TE form (see \p MakeShapeFunc in + * te_compiler_cache.cc), then scheduled to yield a 'dynamic shape function' \p PrimFunc. + * This relies on each operator's "FShapeFunc" and "TShapeDataDependent" attributes. + * Since shapes are rank-1 tensors everything can be reflected back down into the regular + * TE/TIR forms. + * - Then the call_lowered attributes must record everything about the dynamic shape function + * later needed by memory_alloc.cc. We call this 'cross linking' the call with the shape + * function. + * + * - Two external codegen mechanisms are supported, both triggered by "Primitive" functions which + * also have a "Compiler" attribute bound to $compiler: + * - Function-at-a-time (old style): The primitive function is passed to the function + * registered as 'relay.ext.$compiler'. The function returns a runtime::Module which + * should return true for \p ImplementsFunction for the function's global name. That + * module is added to the IRModule's "external_mods" attributes. + * - IRModule-at-a-item (new style): The \p RelayToTIRTargetHook sub-pass looks for + * $compiler names which correspond to TargetKind names with a \p RelayToTIR attribute. + * The \p Pass bound to that attribute is run, and each such 'custom' pass can do what + * it likes, including replacing Functions with PrimFuncs, or adding new runtime::Modules + * to the IRModule's "external_mods" attribute. + * + * - Calls to functions added by external codegen are also rewritten to call_lowered form, and + * may also require cross-linking to dynamic shape functions. However, since the functions + * are/will be implemented by a runtime::Module all the Relay type information is no longer + * available. So the Relay definitions for these "Primitive" "Compiler" functions are retained + * in the \p IRModule, but marked with the "Extern" attribute to signal the function is now + * just for carrying metadata. + * + * - Some operators are handled specially: + * - 'reshape', since it's a no-op on the underlying tensor buffer, and this is handled by + * condition tests in many passes. + * - 'debug', since it's intercepted differently depending on runtimes. + * + * TODO(mbs): This desperately deserves a refactor to separate all these concerns. See Relax. + */ + #include "./te_compiler.h" #include @@ -222,7 +292,7 @@ class TECompilerImpl : public TECompilerNode { } else { // It is valid for the external codegen function to return null: // - Unit tests can use it. - // - The true compilation may have already been handled by a RelayToTIR custom hook pass + // - The true compilation may have already been handled by a RelayToTIR custom pass // on the Target's kind. The original Relay functions will be left in place so // that we can capture that their function names are now externally defined. VLOG(1) << "Note that no external runtime module was generated by external codegen '" @@ -566,100 +636,128 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { return itr->second; } } else if (const auto* function_node = expr.as()) { - if (!function_node->HasNonzeroAttr(attr::kPrimitive)) { - // Not marked as primitive by FuseOps. - return {}; - } - if (const auto* call_node = function_node->body.as()) { - if (call_node->op == debug_op_) { - // Debug 'primitives' are not lowered. - return {}; + if (function_node->HasNonzeroAttr(attr::kExtern)) { + // We have a regular call to an 'extern' function. The call itself needs to be rewritten + // to call_lowered form, and any required dynamic shape functions generated and + // cross-linked. + return GetRef(function_node); + } else if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + if (const auto* call_node = function_node->body.as()) { + if (call_node->op == debug_op_) { + // Debug 'primitives' are not lowered. + return {}; + } } + // We have a regular call to a 'primitive' function (possibly with a 'Compiler' attribute). + // We need to lower and rewrite the call. + return GetRef(function_node); + } else { + // Not marked as primitive during partitioning or TVM fusion. + return {}; } - return GetRef(function_node); } else { return {}; } } /*! - * \brief Lowers the primitive function \p func to TIR for ultimate execution - * on a device with configuration \p target. Returns the global var bound - * to the TIR implementation, and attributes to attach to the call to identify it as - * a TIR call. + * \brief Returns a 'call_lowered' call to \p prim_fn_var with \p args and \p span with all the + * required attributes filled in. Generally \p prim_fn_var will correspond to the lowered or + * externally codegen-ed form of \p original_function, where \p lowered_functions binds all + * the required lowered functions. + * + * The call's attributes will capture: + * - Any attributes on the original_function. + * - All the lowered functions. + * TODO(mbs): Pretty sure that's no longer needed. + * - Details needed to cross-link the call to it's dynamic shape function, if any. */ - Expr MakeLoweredCall(Function func, Array visited_args, Span span, Target target) { - CCacheKey key = CCacheKey(func, target); - CachedFunc cfunc = compiler_->Lower(key, module_name_); - ICHECK(cfunc.defined()); - - auto opt_compiler = func->GetAttr(attr::kCompiler); + Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar& prim_fn_var, + Array args, Span span, const Target& target, + const Map& lowered_functions) { + auto opt_compiler = original_function->GetAttr(attr::kCompiler); // Add some metadata on top of the *original function* and invoke the callback so it can // be captured. // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT Map prim_fns; Array all_prim_fn_vars; - for (const auto& kv : cfunc->funcs->functions) { + for (const auto& kv : lowered_functions) { if (opt_compiler) { - // We expect just the original func but with just the ExternalSymbol attribute signaling - // the function (will be) compiled externally. + // We expect the original function to have just the "Extern" attribute signaling the + // function (will be) compiled externally. ICHECK(kv.second.as()) << PrettyPrint(kv.first) << " must be bound to an (external) Function"; } else { - // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive - // (and the rest in support of that via tir::Calls). + // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive, + // and the rest are in support of that via tir::Calls. ICHECK(kv.second.as()) << PrettyPrint(kv.first) << " must be bound to a PrimFunc"; prim_fns.Set(kv.first, Downcast(kv.second)); all_prim_fn_vars.push_back(kv.first); } } - Function func_with_metadata = func; - func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", cfunc->prim_fn_var); - func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target); - this->process_fn_(func_with_metadata); + // Alas, WithAttr cannot work with base classes. + if (const auto* prim_func_node = original_function.as()) { + auto func_with_metadata = GetRef(prim_func_node); + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); + this->process_fn_(func_with_metadata); + } else { + const auto* function_node = original_function.as(); + ICHECK(function_node); + auto func_with_metadata = GetRef(function_node); + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); + this->process_fn_(func_with_metadata); + } + + // Now prepare the attributes of the call_lowered. CallLoweredAttrs call_lowered_attrs; - // Non-External Relay Function // TODO(mbs): "reshape" cleanup. - if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) { + if (!opt_compiler && original_function->HasNonzeroAttr(attr::kReshapeOnly)) { call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); } - call_lowered_attrs.metadata.Set("relay_attrs", func->attrs); + call_lowered_attrs.metadata.Set("relay_attrs", original_function->attrs); call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars); - if (IsDynamic(func->ret_type)) { - // Also lower the companion dynamic shape function. - // Shape function keys use the underlying primitive function as their 'function', - // but the generic 'cpu' target as the target since all shape functions run - // on the host cpu irrespective of where the primitive runs. - CCacheKey shape_key(func, config_->host_virtual_device->target); - CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); - - // Capture the shape function's global var and parameters 'states' in call - // annotations so calling convention can be recovered. - // TODO(mbs): Shape cleanup. - call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); - call_lowered_attrs.metadata.Set("prim_shape_fn_states", - lowered_shape_func->shape_func_param_states); - call_lowered_attrs.metadata.Set("prim_shape_fn_num_inputs", - Integer(static_cast(lowered_shape_func->inputs.size()))); - call_lowered_attrs.metadata.Set( - "prim_shape_fn_num_outputs", - Integer(static_cast(lowered_shape_func->outputs.size()))); - Array all_prim_shape_fn_vars; - for (const auto& kv : lowered_shape_func->funcs->functions) { - CHECK(kv.second.as()) << "must be a prim fn"; - all_prim_shape_fn_vars.push_back(kv.first); + if (const auto* function_node = original_function.as()) { + if (IsDynamic(function_node->ret_type)) { + // Create a dynamic shape function to calculate the expected shape of the results of + // the lowered function. + // Shape function keys use the original function as their 'function', but the generic 'cpu' + // target as the target since all shape functions run on the host cpu irrespective of where + // the primitive runs. + CCacheKey shape_key(GetRef(function_node), config_->host_virtual_device->target); + CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); + + // Capture the shape function's global var and parameters 'states' in call + // annotations so calling convention can be recovered. + // TODO(mbs): Shape cleanup. + call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); + call_lowered_attrs.metadata.Set("prim_shape_fn_states", + lowered_shape_func->shape_func_param_states); + call_lowered_attrs.metadata.Set( + "prim_shape_fn_num_inputs", + Integer(static_cast(lowered_shape_func->inputs.size()))); + call_lowered_attrs.metadata.Set( + "prim_shape_fn_num_outputs", + Integer(static_cast(lowered_shape_func->outputs.size()))); + Array all_prim_shape_fn_vars; + for (const auto& kv : lowered_shape_func->funcs->functions) { + CHECK(kv.second.as()) << "must be a prim fn"; + all_prim_shape_fn_vars.push_back(kv.first); + } + call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); } - call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); } - return CallLowered(cfunc->prim_fn_var, std::move(visited_args), std::move(call_lowered_attrs), + return CallLowered(prim_fn_var, std::move(args), std::move(call_lowered_attrs), std::move(span)); } @@ -697,43 +795,51 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { - // We can see five forms of calls: - // 1. A 'normal' Relay call to a Function with the "primitive" attribute. We will need - // to lower that to a global PrimFunc and rewrite the call to: + // We can see six forms of calls: + // 1. A 'normal' Relay call to a Function with the "Primitive" attribute and not "Compiler" + // attribute. We will need to lower that to a global PrimFunc and rewrite the call to: // call_lowered(@new_global, (arg1, ..., argn), ) - // However there are a few special forms which are excluded from this treatment, see - // below. - // 2. A 'normal' Relay call to a Function with the "compiler" attribute. We will need - // to invoke the appropriate BYOC toolchain function to yield a runtime module and - // rewrite the call to the same form as above. - // 3. A 'normal' Relay call to a PrimFunc which has already been supplied via a global - // definition. We rewrite to use the call_lowered form, but otherwise nothing else + // If needed, the call needs to be cross-linked with any dynamic shape functions. + // (However, some primitives are special and handled separately.) + // 2. A 'normal' Relay call to a Function with the "Primitive" and "Compiler" attributes. We + // will need to invoke the "relay.ext." function to yield a runtime module, and + // rewrite the call to the same form as above. Dynamic shape function cross-linking may + // also be needed. + // 3. A 'normal' Relay call to a Function with the "Extern" attribute. This function has + // already been compiled by an external codegen and a definition for it exists in some + // runtime module. Again, we rewrite to call_lowered form, and cross-link with a dynamic + // shape function if needed. + // 4. A 'normal' Relay call to a PrimFunc which has already been supplied via a global + // definition. We rewrite those to use the call_lowered form, but otherwise nothing else // needs to be done. - // 4. A 'normal' Relay call to a Relay Function without any special attribute. These + // 5. A 'call_lowered' call from an earlier invocation of this pass or otherwise deliberately + // inserted. It has all the required attributes, and any associated dynamic shape function + // has been generated and cross-linked. These calls are not changed. + // 6. A 'normal' Relay call to a Relay Function without any special attribute. These // calls are not changed. - // 5. A call_lowered call from an earlier invocation of this pass. - // Note that ResolveToPrimitive will yield non-null only for cases 1-3. + // + // Note that ResolveToPrimitive will yield non-null only for cases 1-4. + + // Prepare the arguments and op. + Array new_args; + for (const auto& arg : call_node->args) { + new_args.push_back(VisitExpr(arg)); + } + Expr new_op = VisitExpr(call_node->op); // Look for (possibly indirect) calls to primitives. BaseFunc primitive_func = ResolveToPrimitive(call_node->op); if (!primitive_func.defined()) { - // Not a call to a primitive function we need to rewrite. + // Cases 5 and 6: Leave as ordinary call. if (const auto* function_node = call_node->op.as()) { process_fn_(GetRef(function_node)); } - return DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node); - } - - // Prepare the arguments. - Array new_args; - for (const auto& arg : call_node->args) { - new_args.push_back(VisitExpr(arg)); + return WithFields(GetRef(call_node), std::move(new_op), std::move(new_args)); } - // Special case: device_copies are left as calls to primitive operators - // (thus undoing FuseOps) so that each backend can handle them directly. - // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy - // alone. + // Special case for case 1: device_copies are left as calls to primitive operators + // so that each backend can handle them directly. + // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone. if (const auto* function_node = primitive_func.as()) { DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body); if (device_copy_props.body.defined()) { @@ -743,33 +849,23 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } } - // Special case: If already lowered by other means then so we don't need to mutate - // the call but we do need to mutate the arguments + ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; + + // Case 4: If the function has already been lowered we just need to update the call. if (const auto* prim_func_node = primitive_func.as()) { // Function should already be Target annotated by this point // but the TE Compiler metadata is still needed for the callback // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks - GlobalVar prim_func_var = Downcast(call_node->op); + Optional opt_target = primitive_func->GetAttr(tvm::attr::kTarget); + ICHECK(opt_target.defined()); + auto prim_fn_var = Downcast(call_node->op); tir::PrimFunc prim_func = GetRef(prim_func_node); - - Map prim_fns = {{prim_func_var, prim_func}}; - tir::PrimFunc func_with_metadata = WithAttrs(prim_func, { - {"prim_fn_var", prim_func_var}, - {"prim_funcs", prim_fns}, - }); - - ICHECK(!IsDynamic(call_node->checked_type())); - CallLoweredAttrs call_lowered_attrs; - call_lowered_attrs.metadata.Set("relay_attrs", primitive_func->attrs); - - process_fn_(func_with_metadata); - ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; - return CallLowered(prim_func_var, std::move(new_args), std::move(call_lowered_attrs), - call_node->span); + Map prim_fns = {{prim_fn_var, prim_func}}; + return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), call_node->span, + opt_target.value(), prim_fns); } - // Typical case: call to fused primitive Relay Function. - // Find the desired target device. + // Determine the target for lowering or external codegen. Target target; Optional opt_compiler = primitive_func->GetAttr(attr::kCompiler); if (opt_compiler.defined()) { @@ -791,10 +887,20 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { ICHECK(target.defined()); } - // Lower the primitive function for that target. - Function function = Downcast(primitive_func); - ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; - return MakeLoweredCall(function, std::move(new_args), call_node->span, target); + if (primitive_func->HasNonzeroAttr(attr::kExtern)) { + // Case 3: Function has already been compiled. + GlobalVar prim_fn_var = Downcast(call_node->op); + return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), call_node->span, + target, /*lowered_functions=*/{}); + } else { + // Cases 1 and 2: lower the primitive function for the desired target, possibly using external + // codegen. + CCacheKey key(Downcast(primitive_func), target); + CachedFunc cfunc = compiler_->Lower(key, module_name_); + ICHECK(cfunc.defined()); + return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args), + call_node->span, target, cfunc->funcs->functions); + } } IRModule module_; @@ -1046,6 +1152,7 @@ void UpdateFunctionMetadata(BaseFunc func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } +/*! \brief Main lowering driving. */ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn, CompilationConfig config) { TECompiler compiler(module); @@ -1163,7 +1270,7 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig complilation_config) { +Pass LowerTE(String module_name, CompilationConfig complilation_config, ProcessFn process_fn) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { return LowerTE(module, module_name, process_fn, complilation_config); @@ -1174,6 +1281,12 @@ Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig com tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(), tvm::tir::transform::ExtractPrimFuncConstants()}); } + +TVM_REGISTER_GLOBAL("relay.tec.LowerTE") + .set_body_typed([](String module_name, CompilationConfig compilation_config) { + return LowerTE(std::move(module_name), std::move(compilation_config)); + }); + } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 8312a20cb862..5d16da4b8bb2 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -18,8 +18,8 @@ */ /*! - * \file relay/backend/tir_compiler.h - * * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * \file relay/backend/te_compiler.h + * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. * * * This represents the new design of the Relay compilation flow and will replace the interface @@ -173,36 +173,22 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const Compila */ Map GetPerTargetModules(IRModule mod); -/*! \brief Lower an IRModule's primitive functions to TIR. - * - * This is the "back half" of the Relay compiler which lowers "primitive functions" - * to TE expressions, schedules them, and then to TIR. - * - * \param module The IRModule. - * \param memory_plan The memory plan used during lowering - * \param module_name The name of this module - * \param process_fn Callback allowing one-level up code generators to process - * each function that we lower - * \return The lowered module, see above. - */ -IRModule LowerTE( - const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name, - ProcessFn process_fn = [](BaseFunc f) {}); +inline void DefaultProcessFn(BaseFunc) {} /*! * \brief Pass to lower an IRModule's primitive functions to TIR. * * This is the "back half" of the Relay compiler which lowers "primitive functions" - * to TE expressions, schedules them, and then to TIR. It annotates all functions - * with their target. + * to TE expressions, schedules them, and emits PrimFuncs. * - * \param module_name The name of this module - * \param process_fn Callback allowing one-level up code generators to process - * each function that we lower + * \param module_name The name of this module, used as a prefix for generated globals. * \param config All available targets. + * \param process_fn Callback allowing one-level up code generators to process + * each function that we lower (default is no-op). * \returns The pass which lowers primitive functions to TIR */ -transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig config); +transform::Pass LowerTE(String module_name, CompilationConfig config, + ProcessFn process_fn = DefaultProcessFn); } // namespace tec } // namespace relay diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index d9730b1b5a4c..9c0780e6e43f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1039,13 +1039,11 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig& // Give each "primitive" Function a hash. pass_seqs.push_back(LabelOps()); // Lower "primitive" Functions to PrimFuncs and rewrite calls. - pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod", - [this](const BaseFunc& func) { - if (func->GetAttr(attr::kCompiler).defined()) { - backend::UpdateConstants(func, ¶ms_); - } - }, - config)); + pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config, [this](const BaseFunc& func) { + if (func->GetAttr(attr::kCompiler).defined()) { + backend::UpdateConstants(func, ¶ms_); + } + })); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); @@ -1090,13 +1088,11 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { pass_seqs.push_back(transform::LabelOps()); // Lower all functions annotated as "primitive" by FuseOps. - pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod", - [this](const BaseFunc& func) { - if (func->GetAttr(attr::kCompiler).defined()) { - backend::UpdateConstants(func, ¶ms_); - } - }, - config_)); + pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_, [this](const BaseFunc& func) { + if (func->GetAttr(attr::kCompiler).defined()) { + backend::UpdateConstants(func, ¶ms_); + } + })); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index f22e9bd80dd0..3df07e4c57f5 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -81,42 +81,6 @@ class Outliner : public MixedModeMutator { IRModule mod_; }; -/*! - * \brief Rewrite calls to global "Compiler" functions to use the 'call_lowered' convention. - */ -class CallRewriter : public MixedModeMutator { - public: - CallRewriter(std::string compiler_filter, IRModule mod) - : compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} - - Expr Rewrite_(const CallNode* pre, const Expr& post) final { - Call new_call = Downcast(post); - if (const auto* global_var_node = new_call->op.as()) { - if (const auto* function_node = - mod_->Lookup(GetRef(global_var_node)).as()) { - Optional opt_compiler = function_node->GetAttr(attr::kCompiler); - if (opt_compiler.defined() && - (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { - Optional opt_global_symbol = - function_node->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(opt_global_symbol.defined()); - GlobalVar global_symbol = mod_->GetGlobalVar(opt_global_symbol.value()); - CallLoweredAttrs attrs; - attrs.metadata.Set("relay_attrs", new_call->attrs); - return CallLowered(global_symbol, new_call->args, attrs, new_call->span); - } - } - } - return post; - } - - private: - /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ - std::string compiler_filter_; - /*! \brief Module being rewritten. */ - IRModule mod_; -}; - } // namespace GlobalSymbolCache::~GlobalSymbolCache() = default; @@ -169,20 +133,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { runtime::TypedPackedFunc pass_func = [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { IRModule output_mod = mod->ShallowCopy(); - - // First pass, rewrite the calls. - // We have to do this before marking functions as 'extern' to know which calls to rewrite! - for (const auto& kv : mod->functions) { - if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { - Expr new_body = - CallRewriter(compiler_filter, output_mod).VisitExpr(function_node->body); - Function new_function = - WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); - output_mod->Update(kv.first, new_function); - } - } - - // Second pass, mark functions as 'extern'. for (const auto& kv : mod->functions) { if (const auto* function_node = kv.second.as()) { Optional opt_compiler = function_node->GetAttr(attr::kCompiler); @@ -197,7 +147,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { } } } - return output_mod; }; diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index e4b1f05211fe..9d1dcd9f21a2 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -43,11 +43,8 @@ * * - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler" * attribute with the same function with just an "Extern" attribute, signalling the function - * has been dealt with. Calls to such functions will be rewritten to use the 'call_lowered' - * calling convention. Can be used after lowering to cleanup the IRModule. - * - * Note that the above behaviour is hard coded within the TECompiler, but is only available to - * external codegen using the Function-at-a-time "relay.ext.toolchain" extension point. + * has been dealt with. However calls to such functions will be left unchanged. Can be used + * after lowering to cleanup the IRModule. */ #ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ @@ -118,8 +115,8 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co /*! * \brief A pass to mark all global functions which have a "Compiler" attribute matching - * compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and - * rewrite all calls to such functions to use the 'call_lowered' calling convention. + * compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute. + * Calls to such functions are not changed. * * If \p compiler_filter is non-empty only functions with that as their attribute value are * outlined. diff --git a/tests/python/relay/backend/test_pass_lower_te.py b/tests/python/relay/backend/test_pass_lower_te.py new file mode 100644 index 000000000000..310a16e269e0 --- /dev/null +++ b/tests/python/relay/backend/test_pass_lower_te.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Exercises the LowerTE pass. + +import tvm +import tvm.testing +import logging + +logging.basicConfig() +logger = logging.getLogger("test_pass_lower_te") +logger.setLevel(logging.INFO) + +# Since the TE compiler needs a good refactor it has not been exposed as a 'standard' pass +# in relay.transform. For testing grab it directly. +LowerTE = tvm._ffi.get_global_func("relay.tec.LowerTE") + + +def transform(mod): + logger.info("Starting module:\n%s", mod) + host_target = tvm.target.Target("llvm") + prim_target = tvm.target.Target("llvm", host=host_target) + ctxt = tvm.transform.PassContext() + config = tvm.target.make_compilation_config(ctxt, prim_target) + mod = tvm.relay.transform.PlanDevices(config)(mod) + mod = tvm.relay.transform.InferType()(mod) + mod = LowerTE("test", config)(mod) + mod = tvm.relay.transform.InferType()(mod) + logger.info("After LowerTE:\n%s", mod) + return mod + + +# All attempts to use structural equalty tests against an expected IRModule parsed from +# Relay text were thwarted by the difficulty of setting up the expected call_lower attributes +# with the right GlobalVar instances. So the following assert structural correctness the hard way. + + +def test_lower_primitive(): + input_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Primitive=1) -> Tensor[(5, 7), float32] { + add(%x, %y) + }; + %0(%a, %a) + } + """, + "from_string", + None, + None, + ) + + actual_mod = transform(input_mod) + + # Expected: + # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + # %0 = (%a, %a); + # call_lowered(@test_fused_add, %0, metadata={relay_attrs={Primitive=1},all_prim_fn_vars=[@test_fused_add]}) + # } + # def @test_fused_add = + + main = actual_mod["main"] + call = main.body + assert call.op.name == "call_lowered" + assert len(call.args) == 2 + assert call.args[0].name_hint == "test_fused_add" + assert len(call.args[1].fields) == 2 + assert call.args[1].fields[0].name_hint == "a" + assert call.args[1].fields[1].name_hint == "a" + assert call.attrs.metadata["relay_attrs"].Primitive == 1 + assert len(call.attrs.metadata["all_prim_fn_vars"]) == 1 + assert call.attrs.metadata["all_prim_fn_vars"][0].name_hint == "test_fused_add" + + test_fused_add = actual_mod["test_fused_add"] + assert isinstance(test_fused_add, tvm.tir.PrimFunc) + + +def test_lower_compiler(): + @tvm._ffi.register_func("relay.ext.test_pass_lower_te") + def relay_ext_test_pass_lower_te(func): + return None + + input_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Primitive=1, Compiler="test_pass_lower_te", global_symbol="test_add") -> Tensor[(5, 7), float32] { + add(%x, %y) + }; + %0(%a, %a) + } + """, + "from_string", + None, + None, + ) + + actual_mod = transform(input_mod) + + # Expected: + # def @main(%a : Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + # %0 = (%a, %a) + # call_lowered(@test_add , %0, metadata={relay_attrs={Primitive=1, Compiler="test_pass_lower_te", global_symbol="test_add"}}, all_prim_fn_vars=[]}) + # } + # def @test_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], Extern=1) -> Tensor[(5, 7), float32] { + # add(%x, %y) + # } + + main = actual_mod["main"] + call = main.body + assert call.op.name == "call_lowered" + assert len(call.args) == 2 + assert call.args[0].name_hint == "test_add" + assert len(call.args[1].fields) == 2 + assert call.args[1].fields[0].name_hint == "a" + assert call.args[1].fields[1].name_hint == "a" + assert call.attrs.metadata["relay_attrs"].Primitive == 1 + assert call.attrs.metadata["relay_attrs"].Compiler == "test_pass_lower_te" + assert call.attrs.metadata["relay_attrs"].global_symbol == "test_add" + assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0 + + test_add = actual_mod["test_add"] + assert isinstance(test_add, tvm.relay.Function) + assert test_add.attrs["Extern"] == 1 + + +def test_lower_extern(): + input_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + @my_add(%a, %a) + } + def @my_add(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Extern=1) -> Tensor[(5, 7), float32] { + add(%x, %y) + } + """, + "from_string", + None, + None, + ) + + actual_mod = transform(input_mod) + + # Expected: + # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + # %0 = (%a, %a); + # call_lowered(@my_add, %0, metadata={relay_attrs={Extern=1}}, all_prim_fn_vars=[]}) + # } + # def @my_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], Extern=1) -> Tensor[(5, 7), float32] { + # add(%x, %y) + # } + + main = actual_mod["main"] + call = main.body + assert call.op.name == "call_lowered" + assert len(call.args) == 2 + assert call.args[0].name_hint == "my_add" + assert len(call.args[1].fields) == 2 + assert call.args[1].fields[0].name_hint == "a" + assert call.args[1].fields[1].name_hint == "a" + assert call.attrs.metadata["relay_attrs"].Extern == 1 + assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0 + + test_add = actual_mod["my_add"] + assert isinstance(test_add, tvm.relay.Function) + assert test_add.attrs["Extern"] == 1 + + +def test_lower_extern_with_dynamic_shape(): + input_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] { + @my_dyn(%a, %a) + } + def @my_dyn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Extern=1) -> Tensor[(?, ?), float32] { + add(%x, %y) + } + """, + "from_string", + None, + None, + ) + + actual_mod = transform(input_mod) + + # Expected: + # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] { + # %0 = (%a, %a); + # call_lowered(@my_dyn, %0, metadata={prim_shape_fn_var='shape_func_add', relay_attrs={Extern=1}, prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2, all_prim_shape_fn_vars=['shape_func_add'], prim_shape_fn_num_outputs=1, all_prim_fn_vars=[]}) + # } + # def @my_dyn(%x: Tensor[(5, 7), float32] , %y: Tensor[(5, 7), float32] , Extern=1) -> Tensor[(?, ?), float32] { + # add(%x, %y) + # } + # def @shape_func_add = + + main = actual_mod["main"] + call = main.body + assert call.op.name == "call_lowered" + assert len(call.args) == 2 + assert call.args[0].name_hint == "my_dyn" + assert len(call.args[1].fields) == 2 + assert call.args[1].fields[0].name_hint == "a" + assert call.args[1].fields[1].name_hint == "a" + assert call.attrs.metadata["prim_shape_fn_var"].name_hint == "shape_func_add" + assert call.attrs.metadata["relay_attrs"].Extern == 1 + assert len(call.attrs.metadata["prim_shape_fn_states"]) == 2 + assert call.attrs.metadata["prim_shape_fn_states"][0] == 2 + assert call.attrs.metadata["prim_shape_fn_states"][1] == 2 + assert call.attrs.metadata["prim_shape_fn_num_inputs"] == 2 + assert len(call.attrs.metadata["all_prim_shape_fn_vars"]) == 1 + assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint == "shape_func_add" + assert call.attrs.metadata["prim_shape_fn_num_outputs"] == 1 + assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0 + + my_dyn = actual_mod["my_dyn"] + assert isinstance(my_dyn, tvm.relay.Function) + assert my_dyn.attrs["Extern"] == 1 + + shape_func_add = actual_mod["shape_func_add"] + assert isinstance(shape_func_add, tvm.tir.PrimFunc) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py index 13e0f98e79f1..b9eb11547595 100644 --- a/tests/python/relay/transform/test_compiler_function_utils.py +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -38,8 +38,7 @@ def make_consts(dtype, shapes): (2304,), # 1 (600, 32, 64), # 2 ], - ), - "attributes": [{"relay_attrs": None}], + ) } @@ -115,7 +114,7 @@ def expected_extern_mod(): """ #[version = "0.0.5"] def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { - %1 = call_lowered(@tvmgen_default_cutlass_main_0, (%x0, meta[relay.Constant][0], meta[relay.Constant][1]), metadata=meta[attributes][0]); + %1 = @tvmgen_default_cutlass_main_0(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16],