From 864a113a497add70426b2d9b1bbb824d81fc06da Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Tue, 14 Dec 2021 10:16:26 -0800 Subject: [PATCH] [Relay] Re-run PlanDevices after LowerTE to flow new memory scope constraints. (#9613) * [Relay] Re-run PlanDevices after LowerTE to flow new memory scope constraints. This PR: 1) Makes PlanDevices consider lowered calls when solving device domain constraints. 2) Connects the storage scopes on PrimFunc parameters (encoded in their Buffer data Var type annotation PointerTypes storage_scope fields) to the memory_scope fields of the SEScopes which PlanDevices unifies over. 3) Allows new device_copies to be inserted on the arguments and results of lowered calls so as to account for any memory scope mismatches which are now apparent. [device_planner.cc has main changes, rest is secondary.] In the short term we'd like to use this machinery to flow memory scope choices made during lowering back out into the overall Relay program. In the longer term we'd also like to be able to use memory scopes to influence the lowering of yet-to-be-lowered functions (or lowered functions which have yet to be scheduled, a distinction now possible with TensorIR). - Memory scope constraints can flow both out of and in to PrimFuncs introduced by LowerTE. In TIR memory scopes are represented by 'storage scopes' on the PointerType type annotations on TIR Buffer data variables. - It is straightforward to extract memory scopes from PrimFuncs by looking at the PrimFunc's buffer_map. We do this in 'phase 1' of PlanDevices, which collects all the device constraints implied by the program. - However, pushing memory constraints in to PrimFuncs is more challenging due to buffer aliasing. This aspect is still experimental. 
- Allow device_copies to be inserted for both arguments and results of PrimFunc calls, on the assumption PlanDevices has already established a consistent device assignment prior to lowering and any new mismatch is required to match up memory scopes. We use the new 'free' on_device annotations to implement this. Coming along for the ride: - To make unit tests of mixed Relay/TIR functions possible needed to be able to supply a checked_type to GlobalVar since that's currently the only way to give a Relay type to PrimFuncs. - Use GenSym to get unique var names in ANF & partial eval so easier to diff debug output between passes and connect program fragments back into the overall program. Relying on pretty-printing to automagically unique-ify var names is certainly cute but until we have better span support is very hard to work with. - Realized both dead_code.cc and fold_constant.cc would happily move values into a different lexical virtual device context since device_planner.cc was being 'clever' and eliding on_devices for let-bound values when there's no change. Fixed so that every let-bound value has an on_device. Will be much better after https://github.com/apache/tvm-rfcs/pull/45 is implemented. - Make build -Werror clean for clang-12 (mostly move fixups). - Address post-submit comments from #9693. 
* [checkpoint] thread safe GenSym --- include/tvm/ir/expr.h | 2 +- include/tvm/relay/attrs/on_device.h | 2 +- include/tvm/relay/expr.h | 9 ++ include/tvm/target/se_scope.h | 8 +- python/tvm/ir/expr.py | 4 +- src/ir/expr.cc | 7 +- .../contrib/cmsisnn/extract_constants.cc | 4 +- .../contrib/cmsisnn/generate_constants.cc | 2 +- .../backend/contrib/cmsisnn/tir_to_runtime.cc | 4 +- src/relay/backend/graph_executor_codegen.cc | 3 +- src/relay/backend/vm/compiler.cc | 4 + src/relay/ir/adt.cc | 4 +- src/relay/ir/expr.cc | 49 ++++--- src/relay/ir/function.cc | 2 +- src/relay/op/memory/on_device.cc | 4 +- src/relay/op/memory/on_device.h | 9 +- src/relay/transforms/device_domains.cc | 12 +- src/relay/transforms/device_planner.cc | 123 +++++++++++++++++- src/relay/transforms/let_list.h | 2 +- src/relay/transforms/partial_eval.cc | 4 +- src/relay/transforms/to_a_normal_form.cc | 2 +- .../contrib/verilator/verilator_runtime.h | 2 +- .../lower_cross_thread_reduction.cc | 4 +- tests/python/relay/test_pass_plan_devices.py | 105 +++++++++++++++ 24 files changed, 316 insertions(+), 55 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index d33606676944..a6e5c8de73a7 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -245,7 +245,7 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(String name_hint); + TVM_DLL explicit GlobalVar(String name_hint, Type type = {}); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); }; diff --git a/include/tvm/relay/attrs/on_device.h b/include/tvm/relay/attrs/on_device.h index b1f1e6a6dc45..0931865fa88e 100644 --- a/include/tvm/relay/attrs/on_device.h +++ b/include/tvm/relay/attrs/on_device.h @@ -65,7 +65,7 @@ struct OnDeviceAttrs : public tvm::AttrsNode { SEScope se_scope = SEScope::FullyUnconstrained(); /*! 
- * \brief If fales (the default), the result of the "on_device" call is not constrained to be + * \brief If false (the default), the result of the "on_device" call is not constrained to be * \p se_scope. */ bool constrain_result = false; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 03200d3a3dfb..8bec72490ab1 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -234,6 +234,15 @@ class Var : public Expr { */ TVM_DLL Var(Id vid, Type type_annotation, Span span = Span()); + /*! + * \brief Return a globally fresh name. Helps with debugging to follow the same + * variable between passes and sub-expressions. + * + * TODO(mbs): Replace with name creation w.r.t. scopes once available as part of + * name gen overhaul. + */ + static Var GenSym(Type type_annotation = {}, Span span = {}); + TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); }; diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h index ec5da3a80cae..314bf054d7ea 100644 --- a/include/tvm/target/se_scope.h +++ b/include/tvm/target/se_scope.h @@ -159,19 +159,21 @@ using MemoryScope = String; * */ class SEScopeNode : public AttrsNode { - public: + private: /*! - * \brief The \p DLDeviceType (represtented as an int) of the virtual device. If \p target is + * \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is * known then this will be equal to \p target->kind->device_type. If \p target is null then the * target is to be determined later. * * This is needed to support the legacy "on_device" and "device_copy" calls which only allow * a \p DLDeviceTypes (as an integer) to be given. * - * kInvalidDeviceType denotes unconstrained. + * kInvalidDeviceType denotes unconstrained. An int since the DLDeviceType enum representation + * is not fixed. Private to discourage further int vs DLDeviceType confusion. 
*/ int /* actually DLDeviceType */ device_type_int; + public: DLDeviceType device_type() const { return static_cast(device_type_int); } /*! diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 139e2b8d97fa..43cba1a83530 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -64,8 +64,8 @@ class GlobalVar(RelayExpr): The name of the variable. """ - def __init__(self, name_hint): - self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint) + def __init__(self, name_hint, type_annot=None): + self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot) def __call__(self, *args): """Call the global variable. diff --git a/src/ir/expr.cc b/src/ir/expr.cc index caddf0efcc77..399873492f04 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -141,15 +141,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -GlobalVar::GlobalVar(String name_hint) { +GlobalVar::GlobalVar(String name_hint, Type type) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); + n->checked_type_ = std::move(type); data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); }); +TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { + return GlobalVar(name, type); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index 5ed23ad1ad6a..ca003d80c1d9 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -46,6 +46,8 @@ class ExtractConstantsMutator : public MixedModeMutator { private: String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); } + using 
MixedModeMutator::VisitExpr_; + Expr VisitExpr_(const FunctionNode* function) final { Function func = GetRef(function); function_to_constants_.Set(func, Array{}); @@ -56,7 +58,7 @@ class ExtractConstantsMutator : public MixedModeMutator { func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), func->attrs); } - return func; + return std::move(func); } Expr Rewrite_(const CallNode* call, const Expr& post) final { diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index 056784b6675d..472f93a0a1f0 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -179,7 +179,7 @@ class GenerateConstantsMutator : public MixedModeMutator { if (clip_call) { ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {}); } - return ret_call; + return std::move(ret_call); } Expr Rewrite_(const CallNode* call, const Expr& post) final { diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index 2a7d0ae21769..e0e5aa962239 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -101,8 +101,10 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { int clip_max; }; + using codegen::CodeGenCHost::VisitStmt_; + /*! 
* \brief Emits CMSIS-NN APIs for every call_extern */ - void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final { if (!op->op.same_as(builtin::call_extern())) { CodeGenCHost::VisitExpr_(op, os); return; diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 3d889cdf6561..16b1ddb3c82f 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -694,8 +694,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { *rv = this->output_.external_mods; }); } else if (name == "get_devices") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array(); }); + return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = Array(); }); } else if (name == "get_metadata") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.metadata; }); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 93b2bcb8d7ef..23aee452ba09 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1088,6 +1088,10 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { // let-bound functions. pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); + // Now that we have PrimFuncs, flow and solve SEScope constraints again to account for + // any memory scopes which lowering has settled on. + pass_seqs.push_back(transform::PlanDevices(config_)); + // Inline the functions that are lifted to the module scope. We perform this // pass after all other optimization passes but before the memory allocation // pass. 
This is because memory allocation pass will insert `invoke_tvm_op` diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index c2b8fd641d03..0389547a78f9 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -115,7 +115,7 @@ Clause WithFields(Clause clause, Optional opt_lhs, Optional opt_r cow_clause_node->lhs = lhs; cow_clause_node->rhs = rhs; } - return std::move(clause); + return clause; } TVM_REGISTER_NODE_TYPE(ClauseNode); @@ -168,7 +168,7 @@ Match WithFields(Match match, Optional opt_data, Optional> o cow_match_node->complete = complete; cow_match_node->span = span; } - return std::move(match); + return match; } TVM_REGISTER_NODE_TYPE(MatchNode); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 6b4b2f16ce1e..18e83f998e24 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -107,7 +107,7 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, cow_tuple_node->virtual_device_ = virtual_device; cow_tuple_node->span = span; } - return std::move(tuple); + return tuple; } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -124,6 +124,13 @@ Var::Var(Id vid, Type type_annotation, Span span) { data_ = std::move(n); } +/* static */ Var Var::GenSym(Type type_annotation, Span span) { + static size_t next_id = std::atomic(0); + std::ostringstream os; + os << "x_" << next_id++; + return Var(os.str(), std::move(type_annotation), std::move(span)); +} + Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation, Optional opt_virtual_device, Optional opt_span) { Id vid = opt_vid.value_or(var->vid); @@ -141,7 +148,7 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation cow_var_node->virtual_device_ = virtual_device; cow_var_node->span = span; } - return std::move(var); + return var; } TVM_REGISTER_NODE_TYPE(VarNode); @@ -219,7 +226,7 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args cow_call_node->virtual_device_ = virtual_device; cow_call_node->span = span; } - return std::move(call); + return call; } 
TVM_REGISTER_NODE_TYPE(CallNode); @@ -264,7 +271,7 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona cow_let_node->virtual_device_ = virtual_device; cow_let_node->span = span; } - return std::move(let); + return let; } TVM_REGISTER_NODE_TYPE(LetNode); @@ -308,7 +315,7 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc cow_if_node->virtual_device_ = virtual_device; cow_if_node->span = span; } - return std::move(if_expr); + return if_expr; } TVM_REGISTER_NODE_TYPE(IfNode); @@ -350,7 +357,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, cow_tuple_get_item_node->span = span; cow_tuple_get_item_node->virtual_device_ = virtual_device; } - return std::move(tuple_get_item); + return tuple_get_item; } TVM_REGISTER_NODE_TYPE(TupleGetItemNode); @@ -385,7 +392,7 @@ RefCreate WithFields(RefCreate ref_create, Optional opt_value, cow_ref_create_node->virtual_device_ = virtual_device; cow_ref_create_node->span = span; } - return std::move(ref_create); + return ref_create; } TVM_REGISTER_NODE_TYPE(RefCreateNode); @@ -420,7 +427,7 @@ RefRead WithFields(RefRead ref_read, Optional opt_ref, Optional o cow_ref_read_node->virtual_device_ = virtual_device; cow_ref_read_node->span = span; } - return std::move(ref_read); + return ref_read; } TVM_REGISTER_NODE_TYPE(RefReadNode); @@ -457,7 +464,7 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o cow_ref_write_node->virtual_device_ = virtual_device; cow_ref_write_node->span = span; } - return std::move(ref_write); + return ref_write; } TVM_REGISTER_NODE_TYPE(RefWriteNode); @@ -510,29 +517,29 @@ inline void Dismantle(const Expr& expr) { stack.top().second = true; // special handling - if (const CallNode* op = node.as()) { + if (const auto* call_node = node.as()) { // do not process args if used elsewhere - if (op->args.use_count() < 2) { - for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { + if (call_node->args.use_count() < 2) { + for 
(auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) { fpush_to_stack(*it); } } - } else if (const TupleNode* op = node.as()) { + } else if (const auto* tuple_node = node.as()) { // do not process fields if used elsewhere - if (op->fields.use_count() < 2) { - for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { + if (tuple_node->fields.use_count() < 2) { + for (auto it = tuple_node->fields.rbegin(); it != tuple_node->fields.rend(); ++it) { fpush_to_stack(*it); } } - } else if (const TupleGetItemNode* op = node.as()) { + } else if (const auto* tuple_get_item_node = node.as()) { // do not process tuple if used elsewhere - if (op->tuple.use_count() < 2) { - fpush_to_stack(op->tuple); + if (tuple_get_item_node->tuple.use_count() < 2) { + fpush_to_stack(tuple_get_item_node->tuple); } - } else if (const LetNode* op = node.as()) { + } else if (const auto* let_node = node.as()) { // do not process let if used elsewhere - if (op->body.use_count() < 2) { - fpush_to_stack(op->body); + if (let_node->body.use_count() < 2) { + fpush_to_stack(let_node->body); } } } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index f2cb02194009..4c5b867e49da 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -91,7 +91,7 @@ Function WithFields(Function function, Optional> opt_params, Optional cow_function_node->virtual_device_ = virtual_device; cow_function_node->span = span; } - return std::move(function); + return function; } FuncType FunctionNode::func_type_annotation() const { diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index ae5ef33da6d0..0fd86d3de67c 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -99,7 +99,9 @@ Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool cons ICHECK(inner == outer) << "Cannot constrain intermediate result of nested on_device calls to different SEScopes"; } - // We can now ignore the 
intermediate constraints, if any. + // We can now ignore the middle constraint. + // If the outer on_device has any constraint then use se_scope given for it. + // Otherwise we can use the existing inner se_scope. return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner, constrain_outer, constrain_inner); } else { diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h index bac6695ac35b..2ebaf034c760 100644 --- a/src/relay/op/memory/on_device.h +++ b/src/relay/op/memory/on_device.h @@ -66,15 +66,18 @@ struct OnDeviceProps { }; /*! - * \brief As for OnDevice, but taking all fields other than \p body from \p props. + * \brief Wraps \p body in an "on_device" CallNode, taking all fields other than \p body from \p + * props. */ inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) { return OnDevice(std::move(body), props.se_scope, props.constrain_result, props.constrain_body); } /*! - * \brief As for OnDevice, but don't constrain the body or result to any particular virtual device. - * This allows a "device_copy" when required. + * \brief Wraps \p body in an "on_device" CallNode, but don't constrain the body or result to + * any particular virtual device. This allows a "device_copy" to be inserted by PlanDevices + * where required, while at the same time not introducing unnecessary freedom in the device + * choices. 
*/ inline Call OnDeviceCopyOk(Expr body) { return OnDevice(std::move(body), SEScope::FullyUnconstrained(), diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index 76697d8437f4..fd46a6dc0563 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -199,9 +199,15 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get()); CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get()); - // TODO(mbs): Support call_lowered to PrimFuncs. - ICHECK(!call_lowered_props.lowered_func.defined()); - if (on_device_props.body.defined()) { + if (call_lowered_props.lowered_func.defined()) { + // Presumably we've already seen the call to the "primitive" Function from which this lowered + // function was derived in an earlier PlanDevices pass. Thus we've already established that + // all the argument and result devices domains must be equal, ignoring memory scopes. + // So at this point we'll let all the arguments and result be free so that memory scopes can + // differ. + // TODO(mbs): As per header comments, need to revisit when can setup sub-SEScope constraints. + return DomainFor(call_lowered_props.lowered_func); + } else if (on_device_props.body.defined()) { // By default: // on_device(expr, se_scope=) // on_device : fn():?x? diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index a85233de17e5..bad8363f4783 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -57,6 +57,8 @@ * idempotent. * - Some special operators require their arguments or results to be on the 'host' (typcially * a CPU) \p SEScope, see below. + * - Any \p PrimFuncs in the \p IRModule (if \p LowerTEPass has already run) may constrain their + * argument buffers to have a specific memory scope, which is part of \p SEScope. 
* - Annotations left over from a previous run of this pass, such as 'param_se_scopes' and * 'result_se_scope' function attributes we introduce below. This is so the pass is idempotent * and can be re-run to flow additional memory scope constraints. @@ -71,10 +73,14 @@ * - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written * \code on_device(expr.0, device_type_d) \endcode. I.e. we prefer to copy the projection from * the tuple rather than project from a copy of the tuple. We'll do this by rewriting. + * - We are prepared to insert device_copies on the arguments and result of calls to PrimFuncs, + * on the assumption a) we already ran PlanDevices before lowering so we are not allowing + * any new cross-device copies, but b) after lowering we may have new memory scope constraits + * to deal with. * * Phase 1 * ------- - * We flow constraints from the "on_device" and "device_copy" calls, + * We flow constraints from the "on_device" and "device_copy" calls, PrimFunc buffer memory scopes, * and some special ops, to all other Relay sub-expressions. * * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the @@ -102,6 +108,10 @@ * different from each other. Every call to the function must use the same choice of parameter * and result devices -- there is no 'device polymorphism' for Relay functions. * + * Currently \p PrimFuncs and external functions do not carry over their parameter and result + * devices from their original Relay Function representations. However we know all calls to those + * functions are device-consistent, thus no information is lost. + * * Phase 2 * ------- * After flowing constraints we apply some defaulting heuristics (using a global default \p SEScope) @@ -138,6 +148,7 @@ * around a var or global var. These uses of "on_device" imply both the argument and result are * on the same device. 
We signal this by setting the 'is_fixed' OnDeviceAttrs field to true, * which helps make this pass idempotent. + * - The buffer maps for called PrimFuncs are updated to capture memory scopes. * * Helper visitors (in device_aware_visitors.h) can be used by downstream transforms to recover * the device for any expression for their own use, e.g. during memory planning. All downstream @@ -268,6 +279,7 @@ #include +#include "../../tir/analysis/device_constraint_utils.h" #include "../op/annotation/annotation.h" #include "../op/memory/device_copy.h" #include "../op/memory/on_device.h" @@ -301,6 +313,17 @@ namespace { * on_device(e).0 * ==> on_device(e.0) * \endcode + * + * - Be prepared to copy arguments and results on primitive call boundaries in case memory + * scopes don't line up. We'll use the 'fully unconstrained' version of on_device so that + * we can allow for a device_copy without knowing the specific device for the arguments. + * \code + * call_lowered(@prim, (a, b)) + * ==> copy_ok(call_lowered(@prim, (copy_ok(a), copy_ok(b)))) + * where + * copy_ok(x) = on_device(x, se_scope=SEScope::FullyUnconstrained, + * constrain_body=False, constrain_result=False) + * \endcode */ class RewriteOnDevices : public ExprMutator { public: @@ -358,6 +381,26 @@ class RewriteOnDevices : public ExprMutator { return WithFields(GetRef(function_node), function_node->params, std::move(body)); } + Expr VisitExpr_(const CallNode* call_node) final { + CallLoweredProps props = GetCallLoweredProps(call_node); + if (props.lowered_func.defined()) { + BaseFunc base_func = mod_->Lookup(props.lowered_func); + if (base_func.as()) { + VLOG(2) << "allowing device_copy on PrimFunc arguments and result"; + Array new_args; + new_args.reserve(props.arguments.size()); + for (const auto& arg : props.arguments) { + Expr new_arg = VisitExpr(arg); + new_args.push_back(OnDeviceCopyOk(std::move(new_arg))); + } + Call new_call = CallLowered(std::move(props.lowered_func), std::move(new_args), props.attrs, + 
call_node->span); + return OnDeviceCopyOk(std::move(new_call)); + } + } + return ExprMutator::VisitExpr_(call_node); + } + /*! \brief Module we are rewriting, so we can lookup global definitions. */ IRModule mod_; }; @@ -398,6 +441,10 @@ class DeviceAnalyzer : public ExprVisitor { VLOG(2) << "collecting constraints from Relay Function '" << kv.first->name_hint << "'"; domains_->UnifyExprExact(kv.first, kv.second); VisitExpr(GetRef(function_node)); + } else if (const auto* prim_func_node = kv.second.as()) { + VLOG(2) << "collecting constraints from TIR PrimFunc '" << kv.first->name_hint << "'"; + domains_->UnifyExprExact( + kv.first, DomainForPrimFunc(kv.first, GetRef(prim_func_node))); } else { VLOG(2) << "skipping '" << kv.first->name_hint << "'"; } @@ -406,6 +453,40 @@ class DeviceAnalyzer : public ExprVisitor { } private: + /*! + * \brief Return the domain representing \p prim_func which, before lowering, had + * the Relay \p type. + */ + DeviceDomainPtr DomainForPrimFunc(const GlobalVar& global_var, const tir::PrimFunc& prim_func) { + // CAUTION: The prim_func->checked_type() is currently w.r.t. the flattened and DPS form + // of the prim func, however here we wish to remain within the Relay view of all functions. + // Thus we'll use the global var who's checked_type is in Relay form. + auto func_domain = domains_->DomainFor(global_var); // higher-order + + // TODO(mbs): We don't visit the body of the function -- there's currently nothing to be done. + const auto* func_type_node = global_var->checked_type().as(); + ICHECK(func_type_node); + ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size()); + + Array se_scopes = + tir::GetPrimFuncArgAndResultConstraints(prim_func, GetRef(func_type_node)); + + // Build the implied domain (in terms of the function's Relay type) implied by any memory scope + // constrains in the function's buffers, for both arguments and results. 
+ std::vector args_and_result_domains; + args_and_result_domains.reserve(se_scopes.size()); + for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) { + const SEScope& param_se_scope = se_scopes[i]; + VLOG(2) << "param_se_scope[" << i << "] = " << param_se_scope; + args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(param_se_scope)); + } + const SEScope& ret_se_scope = se_scopes.back(); + VLOG(2) << "ret_se_scope = " << ret_se_scope; + args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(ret_se_scope)); + + return domains_->MakeHigherOrderDomain(std::move(args_and_result_domains)); + } + void VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); @@ -849,6 +930,15 @@ class DeviceCapturer : public ExprMutator { if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { VLOG(2) << "capturing devices for Relay Function '" << kv.first->name_hint << "'"; result->Add(kv.first, Downcast(Mutate(GetRef(function_node)))); + } else if (const auto* prim_func_node = kv.second.as()) { + VLOG(2) << "capturing devices for TIR PrimFunc '" << kv.first->name_hint << "'"; + auto prim_func = GetRef(prim_func_node); + tir::PrimFunc new_prim_func = UpdatePrimFunc(kv.first, prim_func); + VLOG(2) << "Rewritten prim func:" << std::endl + << PrettyPrint(prim_func) << std::endl + << "to:" << std::endl + << PrettyPrint(new_prim_func); + result->Add(kv.first, std::move(new_prim_func)); } else { VLOG(2) << "skipping '" << kv.first->name_hint << "'"; result->Add(kv.first, kv.second); @@ -858,6 +948,34 @@ class DeviceCapturer : public ExprMutator { } private: + /*! + * \brief Returns \p prim_func updated to capture any memory scope's implied by its device + * domain. + */ + tir::PrimFunc UpdatePrimFunc(const GlobalVar& global_var, const tir::PrimFunc& prim_func) { + // CAUTION: Same caution as for DeviceAnalyzer::DomainForPrimFunc. 
+ auto func_domain = domains_->DomainFor(global_var); + ICHECK(func_domain->is_higher_order()); + + const auto* func_type_node = global_var->checked_type().as(); + ICHECK(func_type_node); + ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size()); + + std::vector arg_and_result_se_scopes; + arg_and_result_se_scopes.reserve(func_type_node->arg_types.size() + 1); + for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) { + SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); + VLOG(2) << "param_se_scope[" << i << "] = " << param_se_scope; + arg_and_result_se_scopes.push_back(param_se_scope); + } + SEScope ret_se_scope = domains_->ResultSEScope(func_domain->function_result()); + VLOG(2) << "ret_se_scope = " << ret_se_scope; + arg_and_result_se_scopes.push_back(ret_se_scope); + + return tir::ApplyPrimFuncArgAndResultConstraints(prim_func, GetRef(func_type_node), + arg_and_result_se_scopes); + } + // Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode Expr VisitExpr_(const TupleNode* tuple_node) final { @@ -932,8 +1050,7 @@ class DeviceCapturer : public ExprMutator { // match. return VisitExpr(device_copy_props.body); } else { - return VisitChild(/*lexical_se_scope=*/ - dst_se_scope, + return VisitChild(/*lexical_se_scope=*/dst_se_scope, /*expected_se_scope=*/dst_se_scope, /*child_se_scope=*/src_se_scope, device_copy_props.body); } diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index 56875f6c16a1..f449d6c3b011 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -79,7 +79,7 @@ class LetList { * * \return a Var that hold the inserted expr. */ - Var Push(Expr expr, Type ty) { return Push(Var("x", ty), expr); } + Var Push(Expr expr, Type ty) { return Push(Var::GenSym(ty), expr); } /*! * \brief insert a binding. 
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 8f5e9e146d54..28d1aa5532bf 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -842,7 +842,7 @@ class PartialEvaluator : public ExprFunctor }); } - PStatic VisitFunc(const Function& func, LetList* ll, const Var& name = Var("x", Type())) { + PStatic VisitFunc(const Function& func, LetList* ll, const Var& name) { Func f = VisitFuncStatic(func, name); Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func)))); // TODO(@M.K.): we seems to reduce landin knot into letrec. @@ -851,7 +851,7 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { - return VisitFunc(GetRef(op), ll); + return VisitFunc(GetRef(op), ll, Var::GenSym()); } struct ReflectError : Error { diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index c955269e3412..741de6d7ea9b 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -219,7 +219,7 @@ class Fill : ExprFunctor, private transform::Lexi // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly Expr Compound(const Expr& orig, const Expr& now, const Var& v) { Expr annotated_expr = MaybeOnDeviceFixed(now, GetSEScope(orig)); - Var var = v.defined() ? v : Var(String("x"), Type()); + Var var = v.defined() ? 
v : Var::GenSym(); bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); if (!v.defined() && not_included) { return annotated_expr; diff --git a/src/runtime/contrib/verilator/verilator_runtime.h b/src/runtime/contrib/verilator/verilator_runtime.h index 588b3f172a3e..9ef17d7481ab 100644 --- a/src/runtime/contrib/verilator/verilator_runtime.h +++ b/src/runtime/contrib/verilator/verilator_runtime.h @@ -94,7 +94,7 @@ class VerilatorRuntime : public JSONRuntimeBase { ~VerilatorRuntime(); - const char* type_key() const { return "verilator"; } + const char* type_key() const final { return "verilator"; } /*! \brief set verilator library */ void SetLibrary(const std::string& lib_name); diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index fb5255664af3..2eea869af516 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -207,7 +207,7 @@ class InThreadReducerMaker : private StmtMutator { if (res->thread_binding.defined()) { return res->body; } else { - return res; + return std::move(res); } } else { return Stmt{nullptr}; @@ -564,7 +564,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } } } - return new_block; + return std::move(new_block); } Stmt VisitStmt_(const BlockRealizeNode* realize) final { diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 6bf103ea0c2a..ee9cfc909585 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -23,6 +23,7 @@ import tvm from tvm import relay +from tvm.script import tir as T import tvm.testing import numpy as np @@ -1580,6 +1581,110 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[( exercise(input(), expected(), None, None) +def test_lowered(): + """ + Tests propagation of memory scopes from PrimFuncs and 
insertion + of device_copies to mediate any scope changes. + """ + + @T.prim_func + def input_gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128], scope="scopeA") # will flow out + B = T.match_buffer(b, [128, 128], scope="") # will flow in + C = T.match_buffer(c, [128, 128], scope="scopeB") # will flow out + D = T.match_buffer(d, [128, 128], scope="scopeA") # will flow out + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + @T.prim_func + def expected_gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128], scope="scopeA") + B = T.match_buffer(b, [128, 128], scope="scopeB") # flowed in + C = T.match_buffer(c, [128, 128], scope="scopeB") + D = T.match_buffer(d, [128, 128], scope="scopeA") + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + metatable = { + "SEScope": [ + CPU, # meta[SEScope][0], no memory scope + CPU_SCOPE_A, # meta[SEScope][1], "scopeA" + CPU_SCOPE_B, + ] + } # meta[SEScope][2], "scopeB" + gem_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + gem_gv = relay.GlobalVar("gem", type_annot=gem_ty) + + def input(): + mod = tvm.ir.IRModule() + mod[gem_gv] = input_gem + # - %x on CPU, no memory scope constraint, so will be constrained by first param of gem to "scopeA". + # - %y on CPU "scopeB", so will flow in to second param of gem. + # - %z on CPU "scopeA", so will clash with third param of gem and will need device_copy. 
+ # - result on CPU "scopeB", but result of gem on "scopeA" so will need device_copy + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x : Tensor[(128, 128), float32], + %y : Tensor[(128, 128), float32], + %z : Tensor[(128, 128), float32], + param_se_scopes=[meta[SEScope][0], meta[SEScope][2], meta[SEScope][1]], + result_se_scope=meta[SEScope][2]) { + call_lowered(@gem, (%x, %y, %z)) + } + """, + "from_string", + mod, + metatable, + ) + + def expected(): + mod = tvm.ir.IRModule() + mod[gem_gv] = expected_gem + # - %x now on CPU "scopeA", no device_copy needed. + # - %y still on CPU "scopeB", no device_copy needed. + # - %z still on CPU "scopeA", needs device_copy to "scopeB". + # - result still on CPU "scopeB", needs device_copy from "scopeA". + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x : Tensor[(128, 128), float32], + %y : Tensor[(128, 128), float32], + %z : Tensor[(128, 128), float32], + param_se_scopes=[meta[SEScope][1], meta[SEScope][2], meta[SEScope][1]], + result_se_scope=meta[SEScope][2]) { + %0 = device_copy(%z, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]); + %1 = on_device(%0, se_scope=meta[SEScope][2], constrain_result=True); + %2 = call_lowered(@gem, (%x, %y, %1)); + %3 = on_device(%2, se_scope=meta[SEScope][1], constrain_result=True); + device_copy(%3, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]) + } + """, + "from_string", + mod, + metatable, + ) + + exercise(input(), expected(), None, None) + + if __name__ == "__main__": import sys import pytest