diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 167afd2c5f78..381cfa0c9d1c 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1064,9 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { mod = transform::ToANormalForm()(mod); - IRModule lowered_mod = tec::LowerTEPass( - mod_name, - [this, workspace_byte_alignment](BaseFunc func) { + IRModule lowered_mod = + tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) { // We need to maintain the constant map for external // functions so we pass this processing function which // allows us to process each function as we lower it. @@ -1078,8 +1077,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // execute as a further pass, instead writing data to the // lowering process directly. tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); - }, - config_)(mod); + })(mod); auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 7dba23803f8c..af426e5c71cb 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -217,22 +217,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorGetAttr(attr::kCompiler).defined()) { - UpdateConstants(func, ¶ms_); - } + IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. 
+ if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } - // TODO(@areusch, @jroesch): We should refactor this to - // execute as a further pass, instead writing data to the - // lowering process directly. - tec::UpdateFunctionMetadata(func, this->function_metadata_); - }, - config_)(mod); + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); Optional main_func_info = lowered_mod->GetAttr("main_func_info"); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 9661040eab30..65a0fdc94824 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -960,8 +960,7 @@ IRModule Prepare(IRModule mod, const CompilationConfig& config) { // eta expand to support constructors in argument position. transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), - transform::InferType(), - tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)}); + transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index c78f3abd6ecc..e9491b0a8901 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -17,6 +17,76 @@ * under the License. */ +/*! + * \file relay/backend/te_compiler.cc + * \brief Manages the transition from Relay "Primitive" \p Functions to TIR \p PrimFuncs. Also + * handles invocation of external codegen. 
+ * + * \p LowerTEPass handles the following (as a monolithic blob of code): + * + * - Most importantly, any function with the "Primitive" attribute is first converted to TE by + * \p LowerToTECompute (see te_compiler_cache.cc) using each operator's 'compute' function. + * The TE is then 'scheduled' to TIR using the 'anchor' operator's 'schedule' function. Both + * of those functions come from the \p OpStrategy returned by the Python + * 'relay.backend.lower_call' function (see te_compiler.py). + * The TIR is packed as a \p PrimFunc and introduced as a new global function. Calls to the + * original "Primitive" function are then rewritten to the form: + * \code + * call_lowered(@new_global, (... original args...), attributes) + * \endcode + * + * - The above "Primitive" function can appear: + * - As a global function + * - As a let-bound function + * - As an inline function, ie the 'op' of calls. + * In all three cases it is possible for the same "Primitive" function to be called multiple + * times, and that sharing must be respected. + * + * - "Primitive" functions must have a "global_symbol" attribute matching their desired or + * existing global name. Care is taken to ensure GlobalVars with the same name are shared. + * + * - It is possible for multiple structurally equal "Primitive" functions to appear in the same + * \p IRModule. Only one implementation should be generated, and all calls should share that + * implementation. + * + * - When later converting to DPS (see memory_alloc.cc) we must handle functions who's result + * tensor shapes depend at runtime on the input tensor shapes and/or data. + * - That dependency is first described in TE form (see \p MakeShapeFunc in + * te_compiler_cache.cc), then scheduled to yield a 'dynamic shape function' \p PrimFunc. + * This relies on each operator's "FShapeFunc" and "TShapeDataDependent" attributes. + * Since shapes are rank-1 tensors everything can be reflected back down into the regular + * TE/TIR forms. 
+ * - Then the call_lowered attributes must record everything about the dynamic shape function + * later needed by memory_alloc.cc. We call this 'cross linking' the call with the shape + * function. + * + * - Two external codegen mechanisms are supported, both triggered by "Primitive" functions which + * also have a "Compiler" attribute bound to $compiler: + * - Function-at-a-time (old style): The primitive function is passed to the function + * registered as 'relay.ext.$compiler'. The function returns a runtime::Module which + * should return true for \p ImplementsFunction for the function's global name. That + * module is added to the IRModule's "external_mods" attributes. + * - IRModule-at-a-time (new style): The \p RelayToTIRTargetHook sub-pass looks for + * $compiler names which correspond to TargetKind names with a \p RelayToTIR attribute. + * The \p Pass bound to that attribute is run, and each such 'custom' pass can do what + * it likes, including replacing Functions with PrimFuncs, or adding new runtime::Modules + * to the IRModule's "external_mods" attribute. + * + * - Calls to functions added by external codegen are also rewritten to call_lowered form, and + * may also require cross-linking to dynamic shape functions. However, since the functions + * are/will be implemented by a runtime::Module all the Relay type information is no longer + * available. So the Relay definitions for these "Primitive" "Compiler" functions are retained + * in the \p IRModule, but marked with the "Extern" attribute to signal the function is now + * just for carrying metadata. + * + * - Some operators are handled specially: + * - 'reshape', since it's a no-op on the underlying tensor buffer, and this is handled by + * condition tests in many passes. + * - 'debug', since it's intercepted differently depending on runtimes. + * + * TODO(mbs): This desperately deserves a refactor to separate all these concerns. See Relax.
+ */ + #include "./te_compiler.h" #include @@ -222,7 +292,7 @@ class TECompilerImpl : public TECompilerNode { } else { // It is valid for the external codegen function to return null: // - Unit tests can use it. - // - The true compilation may have already been handled by a RelayToTIR custom hook pass + // - The true compilation may have already been handled by a RelayToTIR custom pass // on the Target's kind. The original Relay functions will be left in place so // that we can capture that their function names are now externally defined. VLOG(1) << "Note that no external runtime module was generated by external codegen '" @@ -566,100 +636,128 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { return itr->second; } } else if (const auto* function_node = expr.as()) { - if (!function_node->HasNonzeroAttr(attr::kPrimitive)) { - // Not marked as primitive by FuseOps. - return {}; - } - if (const auto* call_node = function_node->body.as()) { - if (call_node->op == debug_op_) { - // Debug 'primitives' are not lowered. - return {}; + if (function_node->HasNonzeroAttr(attr::kExtern)) { + // We have a regular call to an 'extern' function. The call itself needs to be rewritten + // to call_lowered form, and any required dynamic shape functions generated and + // cross-linked. + return GetRef(function_node); + } else if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + if (const auto* call_node = function_node->body.as()) { + if (call_node->op == debug_op_) { + // Debug 'primitives' are not lowered. + return {}; + } } + // We have a regular call to a 'primitive' function (possibly with a 'Compiler' attribute). + // We need to lower and rewrite the call. + return GetRef(function_node); + } else { + // Not marked as primitive during partitioning or TVM fusion. + return {}; } - return GetRef(function_node); } else { return {}; } } /*! - * \brief Lowers the primitive function \p func to TIR for ultimate execution - * on a device with configuration \p target. 
Returns the global var bound - * to the TIR implementation, and attributes to attach to the call to identify it as - * a TIR call. + * \brief Returns a 'call_lowered' call to \p prim_fn_var with \p args and \p span with all the + * required attributes filled in. Generally \p prim_fn_var will correspond to the lowered or + * externally codegen-ed form of \p original_function, where \p lowered_functions binds all + * the required lowered functions. + * + * The call's attributes will capture: + * - Any attributes on the original_function. + * - All the lowered functions. + * TODO(mbs): Pretty sure that's no longer needed. + * - Details needed to cross-link the call to it's dynamic shape function, if any. */ - Expr MakeLoweredCall(Function func, Array visited_args, Span span, Target target) { - CCacheKey key = CCacheKey(func, target); - CachedFunc cfunc = compiler_->Lower(key, module_name_); - ICHECK(cfunc.defined()); - - auto opt_compiler = func->GetAttr(attr::kCompiler); + Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar& prim_fn_var, + Array args, Span span, const Target& target, + const Map& lowered_functions) { + auto opt_compiler = original_function->GetAttr(attr::kCompiler); // Add some metadata on top of the *original function* and invoke the callback so it can // be captured. // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT Map prim_fns; Array all_prim_fn_vars; - for (const auto& kv : cfunc->funcs->functions) { + for (const auto& kv : lowered_functions) { if (opt_compiler) { - // We expect just the original func but with just the ExternalSymbol attribute signaling - // the function (will be) compiled externally. + // We expect the original function to have just the "Extern" attribute signaling the + // function (will be) compiled externally. 
ICHECK(kv.second.as()) << PrettyPrint(kv.first) << " must be bound to an (external) Function"; } else { - // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive - // (and the rest in support of that via tir::Calls). + // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive, + // and the rest are in support of that via tir::Calls. ICHECK(kv.second.as()) << PrettyPrint(kv.first) << " must be bound to a PrimFunc"; prim_fns.Set(kv.first, Downcast(kv.second)); all_prim_fn_vars.push_back(kv.first); } } - Function func_with_metadata = func; - func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", cfunc->prim_fn_var); - func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target); - this->process_fn_(func_with_metadata); + // Alas, WithAttr cannot work with base classes. + if (const auto* prim_func_node = original_function.as()) { + auto func_with_metadata = GetRef(prim_func_node); + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); + this->process_fn_(func_with_metadata); + } else { + const auto* function_node = original_function.as(); + ICHECK(function_node); + auto func_with_metadata = GetRef(function_node); + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); + this->process_fn_(func_with_metadata); + } + + // Now prepare the attributes of the call_lowered. CallLoweredAttrs call_lowered_attrs; - // Non-External Relay Function // TODO(mbs): "reshape" cleanup. 
- if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) { + if (!opt_compiler && original_function->HasNonzeroAttr(attr::kReshapeOnly)) { call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); } - call_lowered_attrs.metadata.Set("relay_attrs", func->attrs); + call_lowered_attrs.metadata.Set("relay_attrs", original_function->attrs); call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars); - if (IsDynamic(func->ret_type)) { - // Also lower the companion dynamic shape function. - // Shape function keys use the underlying primitive function as their 'function', - // but the generic 'cpu' target as the target since all shape functions run - // on the host cpu irrespective of where the primitive runs. - CCacheKey shape_key(func, config_->host_virtual_device->target); - CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); - - // Capture the shape function's global var and parameters 'states' in call - // annotations so calling convention can be recovered. - // TODO(mbs): Shape cleanup. - call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); - call_lowered_attrs.metadata.Set("prim_shape_fn_states", - lowered_shape_func->shape_func_param_states); - call_lowered_attrs.metadata.Set("prim_shape_fn_num_inputs", - Integer(static_cast(lowered_shape_func->inputs.size()))); - call_lowered_attrs.metadata.Set( - "prim_shape_fn_num_outputs", - Integer(static_cast(lowered_shape_func->outputs.size()))); - Array all_prim_shape_fn_vars; - for (const auto& kv : lowered_shape_func->funcs->functions) { - CHECK(kv.second.as()) << "must be a prim fn"; - all_prim_shape_fn_vars.push_back(kv.first); + if (const auto* function_node = original_function.as()) { + if (IsDynamic(function_node->ret_type)) { + // Create a dynamic shape function to calculate the expected shape of the results of + // the lowered function. 
+ // Shape function keys use the original function as their 'function', but the generic 'cpu' + // target as the target since all shape functions run on the host cpu irrespective of where + // the primitive runs. + CCacheKey shape_key(GetRef(function_node), config_->host_virtual_device->target); + CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); + + // Capture the shape function's global var and parameters 'states' in call + // annotations so calling convention can be recovered. + // TODO(mbs): Shape cleanup. + call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); + call_lowered_attrs.metadata.Set("prim_shape_fn_states", + lowered_shape_func->shape_func_param_states); + call_lowered_attrs.metadata.Set( + "prim_shape_fn_num_inputs", + Integer(static_cast(lowered_shape_func->inputs.size()))); + call_lowered_attrs.metadata.Set( + "prim_shape_fn_num_outputs", + Integer(static_cast(lowered_shape_func->outputs.size()))); + Array all_prim_shape_fn_vars; + for (const auto& kv : lowered_shape_func->funcs->functions) { + CHECK(kv.second.as()) << "must be a prim fn"; + all_prim_shape_fn_vars.push_back(kv.first); + } + call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); } - call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); } - return CallLowered(cfunc->prim_fn_var, std::move(visited_args), std::move(call_lowered_attrs), + return CallLowered(prim_fn_var, std::move(args), std::move(call_lowered_attrs), std::move(span)); } @@ -697,43 +795,51 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { - // We can see five forms of calls: - // 1. A 'normal' Relay call to a Function with the "primitive" attribute. We will need - // to lower that to a global PrimFunc and rewrite the call to: + // We can see six forms of calls: + // 1. 
A 'normal' Relay call to a Function with the "Primitive" attribute and not "Compiler" + // attribute. We will need to lower that to a global PrimFunc and rewrite the call to: // call_lowered(@new_global, (arg1, ..., argn), ) - // However there are a few special forms which are excluded from this treatment, see - // below. - // 2. A 'normal' Relay call to a Function with the "compiler" attribute. We will need - // to invoke the appropriate BYOC toolchain function to yield a runtime module and - // rewrite the call to the same form as above. - // 3. A 'normal' Relay call to a PrimFunc which has already been supplied via a global - // definition. We rewrite to use the call_lowered form, but otherwise nothing else + // If needed, the call needs to be cross-linked with any dynamic shape functions. + // (However, some primitives are special and handled separately.) + // 2. A 'normal' Relay call to a Function with the "Primitive" and "Compiler" attributes. We + // will need to invoke the "relay.ext." function to yield a runtime module, and + // rewrite the call to the same form as above. Dynamic shape function cross-linking may + // also be needed. + // 3. A 'normal' Relay call to a Function with the "Extern" attribute. This function has + // already been compiled by an external codegen and a definition for it exists in some + // runtime module. Again, we rewrite to call_lowered form, and cross-link with a dynamic + // shape function if needed. + // 4. A 'normal' Relay call to a PrimFunc which has already been supplied via a global + // definition. We rewrite those to use the call_lowered form, but otherwise nothing else // needs to be done. - // 4. A 'normal' Relay call to a Relay Function without any special attribute. These + // 5. A 'call_lowered' call from an earlier invocation of this pass or otherwise deliberately + // inserted. It has all the required attributes, and any associated dynamic shape function + // has been generated and cross-linked. 
These calls are not changed. + // 6. A 'normal' Relay call to a Relay Function without any special attribute. These // calls are not changed. - // 5. A call_lowered call from an earlier invocation of this pass. - // Note that ResolveToPrimitive will yield non-null only for cases 1-3. + // + // Note that ResolveToPrimitive will yield non-null only for cases 1-4. + + // Prepare the arguments and op. + Array new_args; + for (const auto& arg : call_node->args) { + new_args.push_back(VisitExpr(arg)); + } + Expr new_op = VisitExpr(call_node->op); // Look for (possibly indirect) calls to primitives. BaseFunc primitive_func = ResolveToPrimitive(call_node->op); if (!primitive_func.defined()) { - // Not a call to a primitive function we need to rewrite. + // Cases 5 and 6: Leave as ordinary call. if (const auto* function_node = call_node->op.as()) { process_fn_(GetRef(function_node)); } - return DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node); - } - - // Prepare the arguments. - Array new_args; - for (const auto& arg : call_node->args) { - new_args.push_back(VisitExpr(arg)); + return WithFields(GetRef(call_node), std::move(new_op), std::move(new_args)); } - // Special case: device_copies are left as calls to primitive operators - // (thus undoing FuseOps) so that each backend can handle them directly. - // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy - // alone. + // Special case for case 1: device_copies are left as calls to primitive operators + // so that each backend can handle them directly. + // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone. 
if (const auto* function_node = primitive_func.as()) { DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body); if (device_copy_props.body.defined()) { @@ -743,33 +849,23 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } } - // Special case: If already lowered by other means then so we don't need to mutate - // the call but we do need to mutate the arguments + ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; + + // Case 4: If the function has already been lowered we just need to update the call. if (const auto* prim_func_node = primitive_func.as()) { // Function should already be Target annotated by this point // but the TE Compiler metadata is still needed for the callback // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks - GlobalVar prim_func_var = Downcast(call_node->op); + Optional opt_target = primitive_func->GetAttr(tvm::attr::kTarget); + ICHECK(opt_target.defined()); + auto prim_fn_var = Downcast(call_node->op); tir::PrimFunc prim_func = GetRef(prim_func_node); - - Map prim_fns = {{prim_func_var, prim_func}}; - tir::PrimFunc func_with_metadata = WithAttrs(prim_func, { - {"prim_fn_var", prim_func_var}, - {"prim_funcs", prim_fns}, - }); - - ICHECK(!IsDynamic(call_node->checked_type())); - CallLoweredAttrs call_lowered_attrs; - call_lowered_attrs.metadata.Set("relay_attrs", primitive_func->attrs); - - process_fn_(func_with_metadata); - ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; - return CallLowered(prim_func_var, std::move(new_args), std::move(call_lowered_attrs), - call_node->span); + Map prim_fns = {{prim_fn_var, prim_func}}; + return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), call_node->span, + opt_target.value(), prim_fns); } - // Typical case: call to fused primitive Relay Function. - // Find the desired target device. + // Determine the target for lowering or external codegen. 
Target target; Optional opt_compiler = primitive_func->GetAttr(attr::kCompiler); if (opt_compiler.defined()) { @@ -791,10 +887,20 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { ICHECK(target.defined()); } - // Lower the primitive function for that target. - Function function = Downcast(primitive_func); - ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; - return MakeLoweredCall(function, std::move(new_args), call_node->span, target); + if (primitive_func->HasNonzeroAttr(attr::kExtern)) { + // Case 3: Function has already been compiled. + GlobalVar prim_fn_var = Downcast(call_node->op); + return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), call_node->span, + target, /*lowered_functions=*/{}); + } else { + // Cases 1 and 2: lower the primitive function for the desired target, possibly using external + // codegen. + CCacheKey key(Downcast(primitive_func), target); + CachedFunc cfunc = compiler_->Lower(key, module_name_); + ICHECK(cfunc.defined()); + return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args), + call_node->span, target, cfunc->funcs->functions); + } } IRModule module_; @@ -1046,6 +1152,7 @@ void UpdateFunctionMetadata(BaseFunc func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } +/*! \brief Main lowering driving. 
*/ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn, CompilationConfig config) { TECompiler compiler(module); @@ -1163,7 +1270,7 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig complilation_config) { +Pass LowerTE(String module_name, CompilationConfig complilation_config, ProcessFn process_fn) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { return LowerTE(module, module_name, process_fn, complilation_config); @@ -1174,6 +1281,12 @@ Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig com tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(), tvm::tir::transform::ExtractPrimFuncConstants()}); } + +TVM_REGISTER_GLOBAL("relay.tec.LowerTE") + .set_body_typed([](String module_name, CompilationConfig compilation_config) { + return LowerTE(std::move(module_name), std::move(compilation_config)); + }); + } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 8312a20cb862..5d16da4b8bb2 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -18,8 +18,8 @@ */ /*! - * \file relay/backend/tir_compiler.h - * * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * \file relay/backend/te_compiler.h + * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. * * * This represents the new design of the Relay compilation flow and will replace the interface @@ -173,36 +173,22 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const Compila */ Map GetPerTargetModules(IRModule mod); -/*! \brief Lower an IRModule's primitive functions to TIR. 
- * - * This is the "back half" of the Relay compiler which lowers "primitive functions" - * to TE expressions, schedules them, and then to TIR. - * - * \param module The IRModule. - * \param memory_plan The memory plan used during lowering - * \param module_name The name of this module - * \param process_fn Callback allowing one-level up code generators to process - * each function that we lower - * \return The lowered module, see above. - */ -IRModule LowerTE( - const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name, - ProcessFn process_fn = [](BaseFunc f) {}); +inline void DefaultProcessFn(BaseFunc) {} /*! * \brief Pass to lower an IRModule's primitive functions to TIR. * * This is the "back half" of the Relay compiler which lowers "primitive functions" - * to TE expressions, schedules them, and then to TIR. It annotates all functions - * with their target. + * to TE expressions, schedules them, and emits PrimFuncs. * - * \param module_name The name of this module - * \param process_fn Callback allowing one-level up code generators to process - * each function that we lower + * \param module_name The name of this module, used as a prefix for generated globals. * \param config All available targets. + * \param process_fn Callback allowing one-level up code generators to process + * each function that we lower (default is no-op). 
* \returns The pass which lowers primitive functions to TIR */ -transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig config); +transform::Pass LowerTE(String module_name, CompilationConfig config, + ProcessFn process_fn = DefaultProcessFn); } // namespace tec } // namespace relay diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index d9730b1b5a4c..9c0780e6e43f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1039,13 +1039,11 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig& // Give each "primitive" Function a hash. pass_seqs.push_back(LabelOps()); // Lower "primitive" Functions to PrimFuncs and rewrite calls. - pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod", - [this](const BaseFunc& func) { - if (func->GetAttr(attr::kCompiler).defined()) { - backend::UpdateConstants(func, ¶ms_); - } - }, - config)); + pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config, [this](const BaseFunc& func) { + if (func->GetAttr(attr::kCompiler).defined()) { + backend::UpdateConstants(func, ¶ms_); + } + })); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); @@ -1090,13 +1088,11 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { pass_seqs.push_back(transform::LabelOps()); // Lower all functions annotated as "primitive" by FuseOps. 
- pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod", - [this](const BaseFunc& func) { - if (func->GetAttr(attr::kCompiler).defined()) { - backend::UpdateConstants(func, ¶ms_); - } - }, - config_)); + pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_, [this](const BaseFunc& func) { + if (func->GetAttr(attr::kCompiler).defined()) { + backend::UpdateConstants(func, ¶ms_); + } + })); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index f22e9bd80dd0..3df07e4c57f5 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -81,42 +81,6 @@ class Outliner : public MixedModeMutator { IRModule mod_; }; -/*! - * \brief Rewrite calls to global "Compiler" functions to use the 'call_lowered' convention. - */ -class CallRewriter : public MixedModeMutator { - public: - CallRewriter(std::string compiler_filter, IRModule mod) - : compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} - - Expr Rewrite_(const CallNode* pre, const Expr& post) final { - Call new_call = Downcast(post); - if (const auto* global_var_node = new_call->op.as()) { - if (const auto* function_node = - mod_->Lookup(GetRef(global_var_node)).as()) { - Optional opt_compiler = function_node->GetAttr(attr::kCompiler); - if (opt_compiler.defined() && - (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { - Optional opt_global_symbol = - function_node->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(opt_global_symbol.defined()); - GlobalVar global_symbol = mod_->GetGlobalVar(opt_global_symbol.value()); - CallLoweredAttrs attrs; - attrs.metadata.Set("relay_attrs", new_call->attrs); - return CallLowered(global_symbol, new_call->args, attrs, new_call->span); - } - } - } - return post; - } - - private: - /*! 
\brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ - std::string compiler_filter_; - /*! \brief Module being rewritten. */ - IRModule mod_; -}; - } // namespace GlobalSymbolCache::~GlobalSymbolCache() = default; @@ -169,20 +133,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { runtime::TypedPackedFunc pass_func = [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { IRModule output_mod = mod->ShallowCopy(); - - // First pass, rewrite the calls. - // We have to do this before marking functions as 'extern' to know which calls to rewrite! - for (const auto& kv : mod->functions) { - if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { - Expr new_body = - CallRewriter(compiler_filter, output_mod).VisitExpr(function_node->body); - Function new_function = - WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); - output_mod->Update(kv.first, new_function); - } - } - - // Second pass, mark functions as 'extern'. for (const auto& kv : mod->functions) { if (const auto* function_node = kv.second.as()) { Optional opt_compiler = function_node->GetAttr(attr::kCompiler); @@ -197,7 +147,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { } } } - return output_mod; }; diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index e4b1f05211fe..9d1dcd9f21a2 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -43,11 +43,8 @@ * * - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler" * attribute with the same function with just an "Extern" attribute, signalling the function - * has been dealt with. Calls to such functions will be rewritten to use the 'call_lowered' - * calling convention. Can be used after lowering to cleanup the IRModule. 
- * - * Note that the above behaviour is hard coded within the TECompiler, but is only available to - * external codegen using the Function-at-a-time "relay.ext.toolchain" extension point. + * has been dealt with. However calls to such functions will be left unchanged. Can be used + * after lowering to cleanup the IRModule. */ #ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ @@ -118,8 +115,8 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co /*! * \brief A pass to mark all global functions which have a "Compiler" attribute matching - * compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and - * rewrite all calls to such functions to use the 'call_lowered' calling convention. + * compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute. + * Calls to such functions are not changed. * * If \p compiler_filter is non-empty only functions with that as their attribute value are * outlined. diff --git a/tests/python/relay/backend/test_pass_lower_te.py b/tests/python/relay/backend/test_pass_lower_te.py new file mode 100644 index 000000000000..310a16e269e0 --- /dev/null +++ b/tests/python/relay/backend/test_pass_lower_te.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Exercises the LowerTE pass.
+
+import tvm
+import tvm.testing
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger("test_pass_lower_te")
+logger.setLevel(logging.INFO)
+
+# Since the TE compiler needs a good refactor it has not been exposed as a 'standard' pass
+# in relay.transform. For testing, grab it directly.
+LowerTE = tvm._ffi.get_global_func("relay.tec.LowerTE")
+
+
+def transform(mod):
+    logger.info("Starting module:\n%s", mod)
+    host_target = tvm.target.Target("llvm")
+    prim_target = tvm.target.Target("llvm", host=host_target)
+    ctxt = tvm.transform.PassContext()
+    config = tvm.target.make_compilation_config(ctxt, prim_target)
+    mod = tvm.relay.transform.PlanDevices(config)(mod)
+    mod = tvm.relay.transform.InferType()(mod)
+    mod = LowerTE("test", config)(mod)
+    mod = tvm.relay.transform.InferType()(mod)
+    logger.info("After LowerTE:\n%s", mod)
+    return mod
+
+
+# All attempts to use structural equality tests against an expected IRModule parsed from
+# Relay text were thwarted by the difficulty of setting up the expected call_lowered attributes
+# with the right GlobalVar instances. So the following tests assert structural correctness the hard way.
+ + +def test_lower_primitive(): + input_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Primitive=1) -> Tensor[(5, 7), float32] { + add(%x, %y) + }; + %0(%a, %a) + } + """, + "from_string", + None, + None, + ) + + actual_mod = transform(input_mod) + + # Expected: + # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + # %0 = (%a, %a); + # call_lowered(@test_fused_add, %0, metadata={relay_attrs={Primitive=1},all_prim_fn_vars=[@test_fused_add]}) + # } + # def @test_fused_add = + + main = actual_mod["main"] + call = main.body + assert call.op.name == "call_lowered" + assert len(call.args) == 2 + assert call.args[0].name_hint == "test_fused_add" + assert len(call.args[1].fields) == 2 + assert call.args[1].fields[0].name_hint == "a" + assert call.args[1].fields[1].name_hint == "a" + assert call.attrs.metadata["relay_attrs"].Primitive == 1 + assert len(call.attrs.metadata["all_prim_fn_vars"]) == 1 + assert call.attrs.metadata["all_prim_fn_vars"][0].name_hint == "test_fused_add" + + test_fused_add = actual_mod["test_fused_add"] + assert isinstance(test_fused_add, tvm.tir.PrimFunc) + + +def test_lower_compiler(): + @tvm._ffi.register_func("relay.ext.test_pass_lower_te") + def relay_ext_test_pass_lower_te(func): + return None + + input_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Primitive=1, Compiler="test_pass_lower_te", global_symbol="test_add") -> Tensor[(5, 7), float32] { + add(%x, %y) + }; + %0(%a, %a) + } + """, + "from_string", + None, + None, + ) + + actual_mod = transform(input_mod) + + # Expected: + # def @main(%a : Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + # %0 = (%a, %a) + # call_lowered(@test_add , %0, metadata={relay_attrs={Primitive=1, 
Compiler="test_pass_lower_te", global_symbol="test_add"}}, all_prim_fn_vars=[]}) + # } + # def @test_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], Extern=1) -> Tensor[(5, 7), float32] { + # add(%x, %y) + # } + + main = actual_mod["main"] + call = main.body + assert call.op.name == "call_lowered" + assert len(call.args) == 2 + assert call.args[0].name_hint == "test_add" + assert len(call.args[1].fields) == 2 + assert call.args[1].fields[0].name_hint == "a" + assert call.args[1].fields[1].name_hint == "a" + assert call.attrs.metadata["relay_attrs"].Primitive == 1 + assert call.attrs.metadata["relay_attrs"].Compiler == "test_pass_lower_te" + assert call.attrs.metadata["relay_attrs"].global_symbol == "test_add" + assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0 + + test_add = actual_mod["test_add"] + assert isinstance(test_add, tvm.relay.Function) + assert test_add.attrs["Extern"] == 1 + + +def test_lower_extern(): + input_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + @my_add(%a, %a) + } + def @my_add(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Extern=1) -> Tensor[(5, 7), float32] { + add(%x, %y) + } + """, + "from_string", + None, + None, + ) + + actual_mod = transform(input_mod) + + # Expected: + # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + # %0 = (%a, %a); + # call_lowered(@my_add, %0, metadata={relay_attrs={Extern=1}}, all_prim_fn_vars=[]}) + # } + # def @my_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], Extern=1) -> Tensor[(5, 7), float32] { + # add(%x, %y) + # } + + main = actual_mod["main"] + call = main.body + assert call.op.name == "call_lowered" + assert len(call.args) == 2 + assert call.args[0].name_hint == "my_add" + assert len(call.args[1].fields) == 2 + assert call.args[1].fields[0].name_hint == "a" + assert call.args[1].fields[1].name_hint == "a" + assert call.attrs.metadata["relay_attrs"].Extern 
== 1 + assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0 + + test_add = actual_mod["my_add"] + assert isinstance(test_add, tvm.relay.Function) + assert test_add.attrs["Extern"] == 1 + + +def test_lower_extern_with_dynamic_shape(): + input_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] { + @my_dyn(%a, %a) + } + def @my_dyn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Extern=1) -> Tensor[(?, ?), float32] { + add(%x, %y) + } + """, + "from_string", + None, + None, + ) + + actual_mod = transform(input_mod) + + # Expected: + # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] { + # %0 = (%a, %a); + # call_lowered(@my_dyn, %0, metadata={prim_shape_fn_var='shape_func_add', relay_attrs={Extern=1}, prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2, all_prim_shape_fn_vars=['shape_func_add'], prim_shape_fn_num_outputs=1, all_prim_fn_vars=[]}) + # } + # def @my_dyn(%x: Tensor[(5, 7), float32] , %y: Tensor[(5, 7), float32] , Extern=1) -> Tensor[(?, ?), float32] { + # add(%x, %y) + # } + # def @shape_func_add = + + main = actual_mod["main"] + call = main.body + assert call.op.name == "call_lowered" + assert len(call.args) == 2 + assert call.args[0].name_hint == "my_dyn" + assert len(call.args[1].fields) == 2 + assert call.args[1].fields[0].name_hint == "a" + assert call.args[1].fields[1].name_hint == "a" + assert call.attrs.metadata["prim_shape_fn_var"].name_hint == "shape_func_add" + assert call.attrs.metadata["relay_attrs"].Extern == 1 + assert len(call.attrs.metadata["prim_shape_fn_states"]) == 2 + assert call.attrs.metadata["prim_shape_fn_states"][0] == 2 + assert call.attrs.metadata["prim_shape_fn_states"][1] == 2 + assert call.attrs.metadata["prim_shape_fn_num_inputs"] == 2 + assert len(call.attrs.metadata["all_prim_shape_fn_vars"]) == 1 + assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint == "shape_func_add" + assert 
call.attrs.metadata["prim_shape_fn_num_outputs"] == 1 + assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0 + + my_dyn = actual_mod["my_dyn"] + assert isinstance(my_dyn, tvm.relay.Function) + assert my_dyn.attrs["Extern"] == 1 + + shape_func_add = actual_mod["shape_func_add"] + assert isinstance(shape_func_add, tvm.tir.PrimFunc) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py index 13e0f98e79f1..b9eb11547595 100644 --- a/tests/python/relay/transform/test_compiler_function_utils.py +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -38,8 +38,7 @@ def make_consts(dtype, shapes): (2304,), # 1 (600, 32, 64), # 2 ], - ), - "attributes": [{"relay_attrs": None}], + ) } @@ -115,7 +114,7 @@ def expected_extern_mod(): """ #[version = "0.0.5"] def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { - %1 = call_lowered(@tvmgen_default_cutlass_main_0, (%x0, meta[relay.Constant][0], meta[relay.Constant][1]), metadata=meta[attributes][0]); + %1 = @tvmgen_default_cutlass_main_0(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16],