diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index cdd4c9c1dbd2f..e740776d6d4f4 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -444,6 +444,17 @@ TVM_DLL Pass RelayToTIRTargetHook(); */ TVM_DLL Pass ManifestAlloc(Target target_host, Map targets); +/*! + * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + * every Relay sub-expression should run (and the result stored). Captures the result of that + * analysis using new "on_device" and "device_copy" CallNodes. See + * tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} + * for help recovering the device for an arbitrary sub-expression in downstream transformations. + * + * \param default_device_type DLDeviceType for default device. + */ +TVM_DLL Pass PlanDevices(DLDeviceType default_device_type); + } // namespace transform /*! diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 7c79464bdd303..bb91afc061953 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1167,6 +1167,16 @@ def SimplifyExpr(): return _ffi_api.SimplifyExpr() +def PlanDevices(default_device): + """ + Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + every Relay sub-expression should run (and the result stored). Captures the result of that + analysis using new "on_device" and "device_copy" CallNodes. Note that the device_id of + the default_device is ignored. + """ + return _ffi_api.PlanDevices(default_device) + + def FoldExplicitPadding(): """ FoldExplicitPadding finds explict padding before an op that can support diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index af2f5d8572930..4ce8881701346 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -525,7 +525,7 @@ def hexagon(cpu_ver="v66", **kwargs): # LLVM target string def create_llvm_target(cpu_ver, config): - """ Create LLVM target string. """ + """Create LLVM target string.""" target = " -mtriple=hexagon" mcpu = " -mcpu=hexagon" + cpu_ver @@ -547,7 +547,7 @@ def create_target_features(config): # Simulator options string def create_sim_options(cpu_ver, config): - """ Create simulator option string. """ + """Create simulator option string.""" def validate_hvx_length(codegen_hvx, sim_options): if sim_options and "--hvx_length" in sim_options: @@ -606,7 +606,7 @@ def validate_hvx_length(codegen_hvx, sim_options): # LLVM options string def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument - """ Create LLVM options string. """ + """Create LLVM options string.""" llvm_options = config["llvm_options"] @@ -620,7 +620,7 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument # TVM target attributes string def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument - """ Create TVM target features string. """ + """Create TVM target features string.""" features = { "link_params": "link-params", diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc new file mode 100644 index 0000000000000..3d6c59d881258 --- /dev/null +++ b/src/relay/transforms/device_planner.cc @@ -0,0 +1,1986 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_planner.cc + * \brief Determines a unique device to hold the result of every Relay sub-expression. + * + * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. + * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the + * specific target associated with D (this is recovered independently via a TargetMap), and we + * do not track the storage scope within D (this is yet to be implemented). + * + * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', + * see below. + * + * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes: + * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and + * 'dst_dev_type' device type, which constrain the argument and context of the call + * respectively. It is ok if source and destination devices are the same, such no-op copies + * will be removed after accounting for the device preference. + * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which + * constrains the argument of the call, but (usually, see below) leaves the context + * unconstrained. These are called 'annotations' in the rest of the code, have no operational + * significance by themselves, but may trigger the insertion of a new "device_copy". + * - In two situations the result of an "on_device" CallNode may also be constrained to the + * given device: + * - The "on_device" call occurs at the top-level of a function body, or occurs as an + * immediately let-bound expression. In this situation the extra degree of freedom in + * the function result and let-binding leads to surprising device copies, so we simply + * force the function result or let-bound variable to the given device. + * - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted + * it ourselves during an earlier invocation of this pass. This helps make this pass + * idempotent. + * + * We proceed in four phases: + * + * Phase 0 + * ------- + * We rewrite the programs to handle some special cases: + * - "on_device" calls at the top-level of function or immediately let-bound are rewritten + * to have \code is_fixed=true \endcode. + * - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written + * \code on_device(expr.0, device_type_d) \endcode. I.e. we prefer to copy the projection from + * the tuple rather than project from a copy of the tuple. We'll do this by rewriting. + * + * Phase 1 + * ------- + * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see + * below) to all other Relay sub-expressions. (For idempotence we also respect any existing + * "on_device" function attributes we introduce below.) 
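+ *
+ * As a small illustrative sketch of how constraints flow (device names symbolic, Relay syntax
+ * approximate rather than taken from a real test), consider:
+ * \code
+ *   def @main(%x, %y, %z) {
+ *     %0 = add(%x, %y);
+ *     %1 = on_device(%0, device_type=GPU);  // constrains %0, and hence %x and %y, to the GPU
+ *     multiply(%1, %z)
+ *   }
+ * \endcode
+ * Since "add" requires its arguments and result to share a device, the annotation on %0 also
+ * fixes %x and %y to the GPU, while %z and the overall result remain unconstrained and are
+ * left to the defaulting heuristics of Phase 2.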
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" ops. Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified back into the module as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result
+ *    are on the same device.
+ *    We signal this by setting the 'is_fixed' OnDeviceAttrs field to true, which helps make
+ *    this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the
+ * device for any expression for their own use, e.g. during memory planning. All downstream
+ * passes must preserve the lexical scoping of the "on_device" CallNodes. In particular
+ * conversion to ANF must respect the lexical scoping convention:
+ * \code
+ *   f(on_device(g(h(a, b), c), device_type=CPU))
+ *   ==>
+ *   let %x0 = on_device(h(a, b), device_type=CPU)
+ *   let %x1 = on_device(g(%x0), device_type=CPU)
+ *   f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that it can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (e.g. an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to
+ * heuristically minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D ::= <device_type>    -- first-order
+ *       | fn(D,...,D):D    -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameters and result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the
+ * function's result, using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<device_type>) = <device_type>
+ *   result_domain(fn(D1,...,Dn):Dr) = result_domain(Dr)
+ * \endcode
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ *  -------
+ * | AOR |  This pass supports all of Relay.
+ *  -------
+ *     ^
+ *     |
+ *     `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ *  * Though on_device is the identity for all types we can't wrap it around functions/constructors
+ *    taking type args (or at least not without changing type_infer.cc to see through them).
+ *    This is not currently handled generally.
+ *  * Proper diagnostics for unification failure using spans.
+ *  * Make sure the pass is idempotent even after FuseOps etc.
+ *  * Support application of constructors properly. Are they device polymorphic?
+ *  * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'.
+ *  * Support running the pass post FuseOps (so need to understand primitive functions, both
+ *    outlined and inlined) and post the VM transforms (probably need to support more intrinsic
+ *    forms?).
+ *  * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default
+ *    device for primitives vs the default device for the rest of Relay.
+ *  * We'll probably need some support for partial 'device polymorphism' for functions once we
+ *    incorporate targets and memory scopes into the domain. For example it's ok for the function
+ *    body to be executed on different device ids provided they have the same target and memory
+ *    scope.
+ *  * Might be simpler to just let every type have a device annotation rather than work in
+ *    a separate domain?
+ *  * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies.
+ *  * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls
+ *    in tuples at the top level of function bodies or the main expression, irrespective of the
+ *    "on_device" body. What's up with that?
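+ *
+ * Worked example
+ * --------------
+ * The following before/after sketch is illustrative only -- the devices are symbolic, the
+ * syntax approximate, and the attribute spelling merely suggestive of the real
+ * \p FunctionOnDeviceAttrs. Given a default device of CPU and the input:
+ * \code
+ *   def @main(%x, %y) {
+ *     negative(on_device(add(%x, %y), device_type=GPU))
+ *   }
+ * \endcode
+ * the pass plans %x, %y and the add for the GPU, defaults the rest to the CPU, and produces
+ * (roughly):
+ * \code
+ *   def @main(%x, %y, param_device_types=[GPU, GPU], result_device_type=CPU) {
+ *     %0 = device_copy(on_device(add(%x, %y), device_type=GPU, is_fixed=True),
+ *                      src_dev_type=GPU, dst_dev_type=CPU);
+ *     negative(%0)
+ *   }
+ * \endcode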
+ */ + +#include "./device_planner.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" + +namespace tvm { +namespace relay { +namespace transform { + +namespace { + +/*! + * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather + * than the original "device_copy" operator. + * + * See te_compiler.cc for where this rewriting occurs. + */ +DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { + auto tir_call_attrs = call_node->attrs.as(); + if (tir_call_attrs == nullptr) { + return {}; + } + if (tir_call_attrs->metadata.count("source_device") != 1 || + tir_call_attrs->metadata.count("dst_device") != 1) { + return {}; + } + ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1"; + return { + call_node->args[0], + static_cast( + Downcast(tir_call_attrs->metadata["source_device"])->value), + static_cast(Downcast(tir_call_attrs->metadata["dst_device"])->value)}; +} + +class DeviceDomain; +using DeviceDomainPtr = std::shared_ptr; + +/****** +****** Domains +******/ + +/*! + * \brief Represents the domain over which we collect equality constraints. + * + * \code + * D ::= ?x? -- first order, free + * | -- first order, bound + * | fn(D1, ..., Dn):Dr -- higher order + * \endcode + * + * We require a function value to be on the same device as its result. To support that we need + * a notion of the 'result domain' of a domain: + * \code + * result_domain(?x?) = ?x? + * result_domain() = + * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) + * \endcode + */ +class DeviceDomain { + public: + /*! + * \brief Constructs a first-order domain of \p device_type, which may be + * \p kInvalidDeviceType to indicate the domain is free. + */ + explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {} + + /*! + * \brief Constructs a higher-order domain, where \p args_and_result contain the + * function argument and result domains in order. + */ + explicit DeviceDomain(std::vector args_and_result) + : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {} + + /*! \brief Returns true if domain is first-order and free. */ + bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); } + + /*! \brief Returns true if domain is higher-order. */ + bool is_higher_order() const { return !args_and_result_.empty(); } + + DLDeviceType first_order_device_type() const { + ICHECK(args_and_result_.empty()); + return device_type_; + } + + size_t function_arity() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.size() - 1UL; + } + + DeviceDomainPtr function_param(size_t i) const { + ICHECK(!args_and_result_.empty()); + ICHECK_LT(i + 1, args_and_result_.size()); + return args_and_result_[i]; + } + + DeviceDomainPtr function_result() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.back(); + } + + private: + /*! + * \brief If this is a function domain then always kInvalidDevice. Otherwise will be + * kInvalidDevice if the domain is still free, or the specific concrete device if the domain is + * bound. + */ + const DLDeviceType device_type_; + + /*! + * \brief If this is a function domain then the sub-domains for each of the function's + * arguments, and the domain for its result. Otherwise empty. 
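+ * (E.g. a domain fn(D1, D2):Dr is stored as {D1, D2, Dr}, i.e. with the result domain last.)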
+ */ + const std::vector args_and_result_; + + friend struct DeviceDomainHash; + friend struct DeviceDomainEqual; + friend class DeviceDomains; +}; + +// Ye olde boost hash mixer. +constexpr size_t mix(size_t h1, size_t h2) { + return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); +} + +// The following hash and equality helpers give each free first-order domain pointer its own +// distinct identity. +struct DeviceDomainHash { + size_t operator()(const DeviceDomainPtr& domain) const { + if (domain->is_free()) { + // Give each free first-order domain its own identity. + return static_cast(reinterpret_cast(domain.get())); + } else { + size_t h = domain->args_and_result_.size(); + h = mix(h, std::hash()(static_cast(domain->device_type_))); + for (const auto& sub_domain_ptr : domain->args_and_result_) { + h = mix(h, DeviceDomainHash()(sub_domain_ptr)); + } + return h; + } + } +}; + +struct DeviceDomainEqual { + public: + bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { + if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) { + // Mismatched arities are never equal. + // (Though we'll never ask to do such a comparison explicitly, the hash map + // may do so implicitly due to hash collisions.) + return false; + } + if (lhs->is_free() && rhs->is_free()) { + // Compare first-order free domains by their address. + return lhs.get() == rhs.get(); + } + if (lhs->args_and_result_.empty()) { + // Compare first-order domains by their device type -- free vs bound will compare as false. + return lhs->device_type_ == rhs->device_type_; + } else { + // Compare higher-order domains pointwise. + for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) { + return false; + } + } + return true; + } + } +}; + +/*! + * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation + * built up by calls to \p Unify. + */ +class DeviceDomains { + public: + DeviceDomains() = default; + + /*! + * \brief Returns a domain appropriate for \p type who's result domain is bound + * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain + * will be free. + */ + static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type) { + if (const auto* func_type_node = type.as()) { + std::vector args_and_result; + args_and_result.reserve(func_type_node->arg_types.size() + 1); + for (const auto& arg_type : func_type_node->arg_types) { + args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType)); + } + args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type)); + return std::make_shared(std::move(args_and_result)); + } else { + return std::make_shared(device_type); + } + } + + /*! + * \brief Returns a higher-order domain with \p args_and_results. + */ + static DeviceDomainPtr MakeDomain(std::vector arg_and_results) { + return std::make_shared(std::move(arg_and_results)); + } + + /*! \brief Returns a domain with the given result device type appropriate \p device_type. */ + static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) { + ICHECK_NE(device_type, kInvalidDeviceType); + return MakeDomain(type, device_type); + } + + /*! \brief Returns a free domain appropriate for \p type. */ + static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); } + + /*! \brief Returns the domain representing the equivalence class containing \p domain. 
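+ * (This is the 'find' half of a union-find over domains, with path compression; \p Unify below
+ * supplies the 'union' half.)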
*/ + DeviceDomainPtr Lookup(DeviceDomainPtr domain) { + DeviceDomainPtr root = domain; + while (true) { + auto itr = domain_to_equiv_.find(root); + if (itr == domain_to_equiv_.end()) { + break; + } + ICHECK_NE(itr->second, root); + root = itr->second; + ICHECK_NOTNULL(root); + } + // Path compression. + while (domain != root) { + auto itr = domain_to_equiv_.find(domain); + ICHECK(itr != domain_to_equiv_.end()); + domain = itr->second; + ICHECK_NOTNULL(domain); + itr->second = root; + } + return root; + } + + /*! + * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs. + * + * Throws \p Error on failure. + */ + DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + // TODO(mbs): Proper diagnostics. + ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size()) + << "Device domains:" << std::endl + << ToString(lhs) << std::endl + << "and" << std::endl + << ToString(rhs) << std::endl + << "do not have the same kind and can't be unified."; + if (rhs->is_free()) { + return lhs; + } else if (lhs->is_free()) { + return rhs; + } else if (lhs->args_and_result_.empty()) { + // Must have consistent device types for first order domains. + if (lhs->device_type_ != rhs->device_type_) { + // TODO(mbs): Proper diagnostics. + std::ostringstream os; + os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_; + throw Error(os.str()); + } + return lhs; + } else { + // Recurse for higher-order. + std::vector args_and_result; + args_and_result.reserve(lhs->args_and_result_.size()); + for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i])); + } + return MakeDomain(std::move(args_and_result)); + } + } + + /*! + * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p + * rhs disagree on bound device type. + * + * Throws \p Error on failure. + */ + // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but + // given we have refs to functions I'm prepared to be surprised. + DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { + lhs = Lookup(lhs); + rhs = Lookup(rhs); + auto joined_domain = Join(lhs, rhs); + if (!DeviceDomainEqual()(lhs, joined_domain)) { + domain_to_equiv_.emplace(lhs, joined_domain); + } + if (!DeviceDomainEqual()(rhs, joined_domain)) { + domain_to_equiv_.emplace(rhs, joined_domain); + } + return joined_domain; + } + + /*! + * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order, + * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as + * \p Unify. + * + * Throws \p Error on failure. + */ + void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + if (!lhs->is_higher_order() && rhs->is_higher_order()) { + Collapse(lhs, rhs); + } else { + Unify(lhs, rhs); + } + } + + /*! \brief Returns true if a domain is known for \p expr. */ + bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); } + + /*! \brief Returns the domain representing \p expr. */ + DeviceDomainPtr DomainFor(const Expr& expr) { + ICHECK(expr.defined()); + auto itr = expr_to_domain_.find(expr.get()); + if (itr != expr_to_domain_.end()) { + return Lookup(itr->second); + } + auto domain = Free(expr->checked_type()); + expr_to_domain_.emplace(expr.get(), domain); + return domain; + } + + /*! 
+ * \brief Returns the domain representing the callee (ie 'op') in \p call expression. If the + * callee is a primitive or special operation we handle it specially. Otherwise defers to \p + * DomainFor(call->op). + * + * This special handling is needed: + * - To handle the "on_device" and "device_copy" ops which constrain devices to the given devices. + * - To handle some special ops which constrain devices to the CPU. + * - To allow the same primitive to be called on different devices at different call sites. + * Since each call to the op can have a different domain we index the ops by the call expression + * rather than the op itself. + */ + DeviceDomainPtr DomainForCallee(const Call& call) { + auto itr = call_to_callee_domain_.find(call.get()); + if (itr != call_to_callee_domain_.end()) { + return Lookup(itr->second); + } + std::vector args_and_result; + + auto on_device_props = GetOnDeviceProps(call.get()); + auto device_copy_props = GetDeviceCopyProps(call.get()); + if (!device_copy_props.body.defined()) { + device_copy_props = GetPrimitiveDeviceCopyProps(call.get()); + } + + if (on_device_props.body.defined()) { + // on_device(expr, device_type=, is_fixed=false) + // on_device : fn():?x? + // + // on_device(expr, device_type=, is_fixed=true) + // on_device: fn(): + args_and_result.emplace_back( + ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type)); + if (on_device_props.is_fixed) { + args_and_result.emplace_back(args_and_result.front()); + } else { + args_and_result.emplace_back(Free(on_device_props.body->checked_type())); + } + } else if (device_copy_props.body.defined()) { + // device_copy(expr, src_dev_type=, dst_dev_type=) + // device_copy: fn(): + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type)); + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type)); + } else if (call->op == alloc_storage_op) { + ICHECK_EQ(call->args.size(), 2U); + // alloc_storage(size, alignment, device_type=) + // alloc_storage: fn(, ): + const auto* attrs = call->attrs.as(); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back( + ForDeviceType(call->checked_type(), static_cast(attrs->device_type))); + } else if (call->op == alloc_tensor_op) { + ICHECK_EQ(call->args.size(), 3U); + // alloc_tensor(storage, offset, shape) + // alloc_tensor: fn(?x?, , ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op == shape_func_op) { + ICHECK_EQ(call->args.size(), 3U); + // shape_func(func, inputs, outputs, is_inputs=[...]) + // shape_func: fn(..., , ): + // where ... is a free domain appropriate for func's type + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + // TODO(mbs): I think this should be on the cpu only when is_input = [false], but + // what do we do when we have multiple arguments with different is_input values? 
+ args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == shape_of_op) { + ICHECK_EQ(call->args.size(), 1U); + // shape_of(tensor) + // shape_of: fn(?x?): + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == invoke_tvm_op) { + ICHECK_EQ(call->args.size(), 3U); + // invoke_tvm_op(op, inputs, outputs) + // invoke_tvm_op: fn(..., ?x?, ?x?):?x? + // where ... is a free domain appropriate for op's type + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + } else if (call->op == reshape_tensor_op) { + ICHECK_EQ(call->args.size(), 2U); + // reshape_tensor(data, shape) + // reshape_tensor: fn(?x?, ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op->IsInstance()) { + // (arg1, ..., argn) + // : fn(?x?, ..., ?x?):?x? + // (all args and result must be first-order). + auto free_domain = Free(arb_); + for (size_t i = 0; i < call->args.size(); ++i) { + args_and_result.emplace_back(free_domain); + } + args_and_result.emplace_back(free_domain); + } else { + // Defer to normal case where op can be an arbitrary expression. + return DomainFor(call->op); + } + auto domain = MakeDomain(std::move(args_and_result)); + call_to_callee_domain_.emplace(call.get(), domain); + return domain; + } + + /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */ + void UnifyExprExact(const Expr& lhs, const Expr& rhs) { + auto lhs_domain = DomainFor(lhs); + auto rhs_domain = DomainFor(rhs); + try { + Unify(lhs_domain, rhs_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expressions:" << std::endl + << PrettyPrint(lhs) << std::endl + << "with device:" << std::endl + << ToString(lhs_domain) << "and:" << std::endl + << PrettyPrint(rhs) << std::endl + << "with device:" << std::endl + << ToString(rhs_domain) << std::endl + << e.what(); + } + } + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + */ + void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + Unify(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } + } + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + * If \p expected_domain is higher-order but \p expr is first-order, require all arguments + * and the result of \p expected_domain to have the same domain as for \p expr. + */ + void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + UnifyCollapsed(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. 
+ LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } + } + + /*! \brief Returns true if \p domain contains any free sub-domains. */ + bool AnyFree(DeviceDomainPtr domain) { + domain = Lookup(domain); + if (domain->is_free()) { + return true; + } + for (const auto& sub_domain : domain->args_and_result_) { + if (AnyFree(sub_domain)) { + return true; + } + } + return false; + } + + /* + * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. + * This can be used to handle functions within tuples, references and ADTs since we don't + * attempt to track anything beyond 'the device' for expressions of those first-order types. + * + * Throws \p Error on failure. + */ + void Collapse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain) { + for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { + Unify(higher_order_domain->function_param(i), first_order_domain); + } + Unify(higher_order_domain->function_result(), first_order_domain); + } + + /*! \brief Force all free domains in \p domain to default to \p default_device_type. */ + void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) { + ICHECK_NE(default_device_type, kInvalidDeviceType); + domain = Lookup(domain); + if (domain->is_free()) { + // Will never throw since lhs is free. + Unify(domain, std::make_shared(default_device_type)); + } else if (!domain->args_and_result_.empty()) { + for (const auto& sub_domain : domain->args_and_result_) { + SetDefault(sub_domain, default_device_type); + } + } + } + + /*! + * \brief If \p domain is higher-order and its result domain is free, force it to + * \p default_device_type. Then force any remaining free domains to the result domain + * (freshly defaulted or original). If \p domain is first-order same as \p SetDefault. + */ + void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type) { + if (!domain->is_higher_order()) { + SetDefault(domain, default_device_type); + return; + } + DLDeviceType result_device_type = ResultDeviceType(domain); + if (result_device_type == kInvalidDeviceType) { + // If the function result device is still free use the given default. + result_device_type = default_device_type; + } + // Default any remaining free parameters to the function result device. + SetDefault(domain, result_device_type); + } + + /*! \brief Returns one-line description of \p domain for debugging. */ + std::string ToString(DeviceDomainPtr domain) { + domain = Lookup(domain); + std::ostringstream os; + if (domain->is_free()) { + // first-order free + os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; + } else if (domain->args_and_result_.empty()) { + // first-order bound + os << "<" << domain->device_type_ << ">"; + } else { + // higher-order + os << "fn("; + for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) { + if (i > 0) { + os << ","; + } + os << ToString(domain->args_and_result_[i]); + } + os << "):" << ToString(domain->args_and_result_.back()); + } + return os.str(); + } + + /*! 
\brief Returns description of entire system of constraints for debugging */ + std::string ToString() { + std::ostringstream os; + for (const auto& pair : expr_to_domain_) { + os << "expression:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + for (const auto& pair : call_to_callee_domain_) { + os << "call:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "callee domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + return os.str(); + } + + /*! + * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment). + */ + DeviceDomainPtr ResultDomain(DeviceDomainPtr domain) { + domain = Lookup(domain); + while (!domain->args_and_result_.empty()) { + domain = Lookup(domain->args_and_result_.back()); + } + return domain; + } + + /*! + * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain + * comment). + */ + DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) { + return ResultDomain(domain)->first_order_device_type(); + } + + private: + /*! \brief Intrinsics we need to handle specially. */ + const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); + const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); + const Op& shape_of_op = Op::Get("vm.shape_of"); + const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); + const Op& shape_func_op = Op::Get("vm.shape_func"); + const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); + /*! \brief The CPU device type for special operators such as dynamic shape functions. */ + const DLDeviceType cpu_device_type_ = kDLCPU; + /*! \brief Placeholder for any first-order type. */ + Type arb_ = TupleType(); + /*! \brief The domain for first-order expressions on the CPU. */ + DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_); + + /*! \brief Maps expressions to their domains as determined during analysis. */ + std::unordered_map expr_to_domain_; + + /*! + * \brief Maps call expressions to the domains for their callee where the callee is a primitive. + */ + std::unordered_map call_to_callee_domain_; + + /*! \brief Maps device domains to their equivalent domains as determined during unification. */ + std::unordered_map + domain_to_equiv_; +}; + +/****** +****** Phase 0 +******/ + +/*! + * \brief Rewrites "on_device" calls to handle some special cases. + */ +class RewriteOnDevices : public ExprMutator { + public: + RewriteOnDevices() = default; + + private: + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + Expr tuple = VisitExpr(tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy. 
+ Expr tuple_get_item = + TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + auto props = GetOnDeviceProps(tuple); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "wrapping tuple get item:" << std::endl + << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl + << "with \"on_device\" for device " << props.device_type; + return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false); + } else { + return tuple_get_item; + } + } + + Expr VisitExpr_(const LetNode* let_node) final { + auto expr = GetRef(let_node); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + Expr value = VisitExpr(inner_let_node->value); + auto props = GetOnDeviceProps(value); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising let-bound expression of let:" << std::endl + << PrettyPrint(expr) << std::endl + << "to be fixed to device " << props.device_type; + value = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + expr = VisitExpr(expr); + // TODO(mbs): Avoid copy. + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr, + /*span=*/std::get<2>(*itr)); + } + return expr; + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + Expr body = VisitExpr(function_node->body); + auto props = GetOnDeviceProps(body); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising body of function:" << std::endl + << PrettyPrint(GetRef(function_node)) << std::endl + << "to be fixed to device " << props.device_type; + body = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + // TODO(mbs): Avoid copy + return Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + } +}; + +/****** +****** Phase 1 +******/ + +/* + * \brief Collects the system of device constraints for all sub-expressions in a module. + * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter. + */ +class DeviceAnalyzer : public ExprVisitor { + public: + explicit DeviceAnalyzer(IRModule mod) + : mod_(std::move(mod)), domains_(std::make_unique()) {} + + /*! + * \brief Returns the expression-to-device-domain map for all expressions in all the global + * function definitions in the module. Expressions may have free domains, these will be resolved + * by \p DeviceDefaulter below. + */ + std::unique_ptr Analyze() { + VLOG_CONTEXT << "DeviceAnalyzer"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; + domains_->UnifyExprExact(pair.first, pair.second); + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + + // Find the higher-order domain for the callee. See DomainForCallee for the special rules + // for primitives. + VisitExpr(call_node->op); + auto func_domain = domains_->DomainForCallee(call); // higher-order + + // Build the domain for the function implied by its arguments and call context. 
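+    // I.e. the higher-order domain fn(D_arg1, ..., D_argn):D_call, which we then unify with the
+    // callee's domain below.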
+ ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + std::vector args_and_result_domains; + args_and_result_domains.reserve(call_node->args.size() + 1); + for (const auto& arg : call_node->args) { + args_and_result_domains.emplace_back(domains_->DomainFor(arg)); + VisitExpr(arg); + } + args_and_result_domains.emplace_back(domains_->DomainFor(call)); + auto implied_domain = + DeviceDomains::MakeDomain(std::move(args_and_result_domains)); // higher-order + + VLOG(1) << "initial call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied domain:" << std::endl + << domains_->ToString(implied_domain) << "for call:" << std::endl + << PrettyPrint(call); + + // The above must match. + try { + domains_->Unify(func_domain, implied_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:" + << std::endl + << PrettyPrint(call) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied call devices:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << e.what(); + } + + VLOG(1) << "final call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // Let var must be same device as value it is bound to. + domains_->UnifyExprExact(let->var, let->value); // may be higher-order + // Let body must be same device as overall let. + domains_->UnifyExprExact(let, let->body); // may be higher-order + + VisitExpr(let->var); + VisitExpr(let->value); + + expr = let->body; + } + + // Visit the last body + VisitExpr(expr); + } + + void VisitExpr_(const FunctionNode* function_node) final { + // No need to step into fused primitive functions as they are lowered individually according + // to the devices of all their call sites. + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + + // The function body domain must match the function result domain. + domains_->UnifyExprExact(function_node->body, + func_domain->function_result()); // may be higher-order + + VLOG(1) << "initial function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + // The parameter domains must match the function argument domains. + domains_->UnifyExprExact(function_node->params[i], + func_domain->function_param(i)); // may be higher-order + VisitExpr(function_node->params[i]); + } + + // If the function already has an "on_device" attribute then we can further + // constrain the function's domain to match it. 
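+    // (Such an attribute will have been added by an earlier run of this pass, so respecting it
+    // here is part of keeping the pass idempotent.)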
+ Optional opt_attrs = + function_node->GetAttr(FunctionOnDeviceAttrs::kFunctionAttrsKey); + if (opt_attrs) { + std::vector args_and_result; + for (size_t i = 0; i < function_node->params.size(); ++i) { + args_and_result.emplace_back( + domains_->ForDeviceType(function_node->params[i]->checked_type(), + GetFunctionParamDeviceType(function_node, i))); + } + args_and_result.emplace_back(domains_->ForDeviceType( + function_node->body->checked_type(), GetFunctionResultDeviceType(function_node))); + auto annotation_domain = domains_->MakeDomain(std::move(args_and_result)); + try { + domains_->Unify(func_domain, annotation_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) + << "Function devices are incompatible with its \"on_device\" annotation. Function:" + << std::endl + << PrettyPrint(function) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and annotation devices:" << std::endl + << domains_->ToString(annotation_domain) << std::endl + << e.what(); + } + } + + VisitExpr(function_node->body); + + VLOG(1) << "final function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + } + + void VisitExpr_(const TupleNode* tuple_node) final { + Tuple tuple = GetRef(tuple_node); + for (size_t i = 0; i < tuple->fields.size(); i++) { + auto domain = domains_->DomainFor(tuple->fields[i]); // may be higher-order + domains_->UnifyExprCollapsed(tuple, domain); // collapse to first-order if needed + VisitExpr(tuple->fields[i]); + } + } + + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + TupleGetItem tuple_get_item = GetRef(tuple_get_item_node); + auto domain = domains_->DomainFor(tuple_get_item); // may be higher-order + domains_->UnifyExprCollapsed(tuple_get_item_node->tuple, + domain); // collapse to first-order if needed + VisitExpr(tuple_get_item_node->tuple); + } + + class DevicePatternAnalyzer : public PatternVisitor { + public: + DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node) + : domains_(domains), adt_node_(adt_node) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto var_domain = domains_->DomainFor(pattern_var_node->var); // may be higher order + domains_->UnifyExprCollapsed(GetRef(adt_node_), + var_domain); // collapse to first-order if needed + } + + /*! \brief (Mutable borrow of) the domains for all expressions processed so far. */ + DeviceDomains* domains_; + /*! \brief The expression for the ADT we are matching over. 
*/ + const ExprNode* adt_node_; + }; + + void VisitPattern(const Pattern& pattern) final {} + + void VisitExpr_(const MatchNode* match_node) final { + // For match node, we unify the value and the rhs of each clause + Match match = GetRef(match_node); + auto match_domain = domains_->DomainFor(match); // may be higher-order + DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get()); + domains_->UnifyExprCollapsed(match->data, match_domain); // collapse to first-order if needed + for (const auto& clause : match->clauses) { + pattern_analyzer.VisitPattern(clause->lhs); + domains_->UnifyExprExact(clause->rhs, match_domain); + VisitExpr(clause->rhs); + } + VisitExpr(match_node->data); + } + + void VisitExpr_(const GlobalVarNode* global_var_node) final { + domains_->DomainFor(GetRef(global_var_node)); + } + + void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef(var_node)); } + + void VisitExpr_(const ConstantNode* constant_node) final { + domains_->DomainFor(GetRef(constant_node)); + } + + void VisitExpr_(const ConstructorNode* constructor_node) final { + // Probably needs to be device polymorphic. + domains_->DomainFor(GetRef(constructor_node)); + } + + void VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + auto domain = domains_->DomainFor(ife); // may be higher-order + domains_->UnifyExprCollapsed(if_node->cond, domain); // collapse to first-order if needed + domains_->UnifyExprExact(if_node->true_branch, domain); + domains_->UnifyExprExact(if_node->false_branch, domain); + VisitExpr(if_node->cond); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + } + + void VisitExpr_(const OpNode* op) final { + // no-op, primitive operators are handled at their call-sites. + } + + void VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + auto domain = domains_->DomainFor(ref_create_node->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_create, domain); // collapse to first-order if needed + VisitExpr(ref_create_node->value); + } + + void VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + auto domain = domains_->DomainFor(ref_read); // may be higher-order + domains_->UnifyExprCollapsed(ref_read_node->ref, domain); // collapse to first-order if needed + VisitExpr(ref_read_node->ref); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + auto domain = domains_->DomainFor(ref_write->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_write->ref, domain); // collapse to first-order if needed + domains_->UnifyExprCollapsed(ref_write, domain); // collapse to first-order if needed + VisitExpr(ref_write_node->ref); + VisitExpr(ref_write_node->value); + } + + /*! \brief The module we are analyzing. */ + IRModule mod_; + /*! \brief The domains for all expressions processed so far. */ + std::unique_ptr domains_; +}; + +/****** +****** Phase 2 +******/ + +/*! + * \brief Ensures every sub-expression in a module has a device type, using both the global + * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes. + * + * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap + * order. 
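+ *
+ * For example (an illustrative sketch with a symbolic default device of CPU): a function whose
+ * domain is still fn(?a?, <GPU>):?r? after constraint flowing is defaulted result-first to
+ * fn(<CPU>, <GPU>):<CPU> -- the free result takes the global default, and the remaining free
+ * parameter then takes the (now defaulted) result device.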
+ */ +class DeviceDefaulter : public ExprVisitor { + public: + DeviceDefaulter(IRModule mod, std::unique_ptr domains, + DLDeviceType default_device_type) + : mod_(std::move(mod)), + domains_(std::move(domains)), + default_device_type_(default_device_type) {} + + std::unique_ptr Default() { + VLOG_CONTEXT << "DeviceDefaulter"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + if (domains_->AnyFree(func_domain)) { + VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); + } + VisitExpr(function_node->body); + } + + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + auto func_domain = domains_->DomainForCallee(call); // higher-order + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + if (domains_->AnyFree(func_domain)) { + // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) + // above. But for calls to primitives we may still need to force free domains to be + // defaulted. + VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + } + return ExprVisitor::VisitExpr_(call_node); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // If the let-var device is still free force it to match the overall let. + auto let_domain = domains_->DomainFor(let); // may be higher-order + DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain); + ICHECK_NE(let_device_type, kInvalidDeviceType); + auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order + if (domains_->AnyFree(let_var_domain)) { + VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + domains_->SetDefault(let_var_domain, let_device_type); + VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + } + VisitExpr(let->var); + VisitExpr(let->value); + expr = let->body; + } + VisitExpr(expr); + } + + /*! \brief The module we are processing. */ + IRModule mod_; + /*! \brief The domains for all expressions. */ + std::unique_ptr domains_; + /*! \brief The default device type. */ + DLDeviceType default_device_type_; +}; + +/****** +****** Phase 3 +******/ + +/*! + * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every + * sub-expression in a module can be easily recovered by a later transformation using simple + * lexical scoping rules (e.g. for memory planning). + * + * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard + * any existing "device_copy" CallNodes which are no-ops. 
+ *
+ * - Functions are given an "on_device" attribute bound to a FunctionOnDeviceAttrs to capture
+ *   the device type for its parameters and result.
+ *
+ * - Additional "device_copy" CallNodes are inserted wherever there's a transition between
+ *   storage device types. Since the DeviceAnalyzer phase succeeded this can only happen
+ *   where the original program explicitly allowed a transition using an "on_device" CallNode.
+ *   That is, we do not try to 'fix' a program with inconsistent devices.
+ *
+ * - Additional "on_device" CallNodes are inserted so that a later transform can discover
+ *   the device for an arbitrary sub-expression by looking only for the lexically enclosing
+ *   "on_device" CallNode or "on_device" function attribute. In particular, since function
+ *   arguments and let-bound expressions can be on a device different from the function
+ *   or let body itself we will insert "on_device" CallNodes to spell out any differences. This
+ *   applies even to the argument to a "device_copy" CallNode, which may look pedantic but
+ *   keeps downstream processing simple. The "on_device" calls should be removed before code gen,
+ *   which is easily done on-the-fly.
+ */
+class DeviceCapturer : public ExprMutator {
+ public:
+  DeviceCapturer(IRModule mod, std::unique_ptr<DeviceDomains> domains)
+      : mod_(std::move(mod)), domains_(std::move(domains)) {}
+
+  IRModule Capture() {
+    VLOG_CONTEXT << "CaptureDevices";
+    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map);
+    for (const auto& pair : mod_->functions) {
+      VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'";
+      result->Add(pair.first, Downcast<Function>(Mutate(pair.second)));
+    }
+    return result;
+  }
+
+ private:
+  // Nothing interesting for VarNode, ConstantNode, GlobalVarNode and OpNode.
+
+  Expr VisitExpr_(const TupleNode* tuple_node) final {
+    auto tuple = GetRef<Tuple>(tuple_node);
+    Array<Expr> fields;
+    fields.reserve(tuple_node->fields.size());
+    for (const auto& field : tuple_node->fields) {
+      fields.push_back(VisitChild(tuple, field));
+    }
+    // TODO(mbs): Avoid copy
+    return Tuple(std::move(fields), tuple_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) final {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return GetRef<Function>(function_node);
+    }
+
+    auto function = GetRef<Function>(function_node);
+    auto func_domain = domains_->DomainFor(function);  // higher-order
+    VLOG(1) << "capturing function:" << std::endl
+            << PrettyPrint(function) << std::endl
+            << "with domain:" << std::endl
+            << domains_->ToString(func_domain);
+
+    // Gather the parameter and result device types for the "on_device" function attribute.
+    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
+    DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain);
+    ICHECK_NE(result_device_type, kInvalidDeviceType);
+    Array<Integer> param_device_types;
+    param_device_types.reserve(function_node->params.size());
+    for (size_t i = 0; i < function_node->params.size(); ++i) {
+      DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i));
+      ICHECK_NE(param_device_type, kInvalidDeviceType);
+      param_device_types.push_back(param_device_type);
+    }
+
+    // Rewrite the body. Note that the body may have begun with an "on_device" so
+    // be prepared to insert a "device_copy".
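+    // For example: if the body was written as on_device(e, device_type=CPU) but the function
+    // result was planned for the GPU, then VisitChild below yields
+    //   device_copy(on_device(e', device_type=CPU), src_dev_type=CPU, dst_dev_type=GPU)
+    // (see the VisitChild comment further down for the general rule).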
+ Expr body = VisitChild( + /*lexical_device_type=*/result_device_type, + /*expected_device_type=*/result_device_type, + /*child_device_type=*/GetDeviceType(function_node->body), function_node->body); + + // TODO(mbs): Avoid copy + Function func = Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + return FunctionOnDevice(func, param_device_types, result_device_type); + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + DLDeviceType call_device_type = GetDeviceType(call); + + auto on_device_props = GetOnDeviceProps(call_node); + if (on_device_props.body.defined()) { + // We're done with the original "on_device" calls and can pinch them out. + // Note that this step has already been simulated by GetDeviceType. + return VisitExpr(on_device_props.body); + } + + auto device_copy_props = GetDeviceCopyProps(call_node); + if (device_copy_props.body.defined()) { + DLDeviceType src_device_type = device_copy_props.src_dev_type; + ICHECK_EQ(call_device_type, device_copy_props.dst_dev_type); + if (call_device_type == src_device_type) { + // We can pinch out existing "device_copy" CallNodes if their source and destinations + // match. + return VisitExpr(device_copy_props.body); + } + // else: handle as for any other call. + } + + auto func_domain = domains_->DomainForCallee(call); // higher-order + VLOG(1) << "considering call:" << std::endl + << PrettyPrint(call) << std::endl + << "on device " << call_device_type << " with function domain:" << std::endl + << domains_->ToString(func_domain); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + + // The callee is on the current device. + Expr op = VisitChild( + /*lexical_device_type=*/call_device_type, + /*expected_device_type=*/call_device_type, + /*child_device_type=*/result_device_type, call_node->op); + + // Each argument can be on the device for the corresponding function parameter. However if + // any of those differ from the overall call device then wrap them in an "on_device" to + // help downstream transforms track devices lexically. + Array args; + args.reserve(call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), call->args.size()); + for (size_t i = 0; i < call_node->args.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType) + << "for parameter " << i << " for call:" << std::endl + << PrettyPrint(call); + args.push_back(VisitChild(/*lexical_device_type=*/call_device_type, + /*expected_device_type=*/param_device_type, + /*child_device_type=*/GetDeviceType(call_node->args[i]), + call_node->args[i])); + } + // TODO(mbs): Avoid copy + return Call(std::move(op), std::move(args), call_node->attrs, call_node->type_args, + call_node->span); + } + + Expr VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iterate through chained lets, provided they all agree on their device type. + DLDeviceType let_device_type = GetDeviceType(expr); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + if (GetDeviceType(inner_let) != let_device_type) { + // We have a device transition which needs to be handled. + break; + } + // The let-bound value can be on a different device than the overall let. 
However if those + // devices don't agree wrap the let-bound value in an "on_device" to help downstream + // transforms track devices lexically. + Expr value = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/GetDeviceType(inner_let_node->var), + /*child_device_type=*/GetDeviceType(inner_let_node->value), + inner_let_node->value); + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + Expr body = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/let_device_type, + /*child_device_type=*/GetDeviceType(expr), expr); + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, + /*span=*/std::get<2>(*itr)); + } + return body; + } + + Expr VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + Expr cond = VisitChild(ife, if_node->cond); + Expr true_branch = VisitChild(ife, if_node->true_branch); + Expr false_branch = VisitChild(ife, if_node->false_branch); + // TODO(mbs): Avoid copy + return If(cond, true_branch, false_branch, if_node->span); + } + + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + auto tuple_get_item = GetRef(tuple_get_item_node); + Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy + return TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + } + + Expr VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + Expr value = VisitChild(ref_create, ref_create_node->value); + // TODO(mbs): Avoid copy + return RefCreate(value, ref_create_node->span); + } + + Expr VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + Expr ref = VisitChild(ref_read, ref_read_node->ref); + // TODO(mbs): Avoid copy + return RefRead(ref, ref_read_node->span); + } + + Expr VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + Expr ref = VisitChild(ref_write, ref_write_node->ref); + Expr value = VisitChild(ref_write, ref_write_node->value); + // TODO(mbs): Avoid copy + return RefWrite(ref, value, ref_write_node->span); + } + + Expr VisitExpr_(const ConstructorNode* constructor_node) final { + auto constructor = GetRef(constructor_node); + // check we have a device type. + (void)GetDeviceType(constructor); + return constructor; + } + + Expr VisitExpr_(const MatchNode* match_node) final { + auto match = GetRef(match_node); + Expr data = VisitChild(match, match_node->data); + Array clauses; + clauses.reserve(match_node->clauses.size()); + for (const auto& clause : match_node->clauses) { + Pattern lhs = VisitPattern(clause->lhs); // actually a no-op, so we're not checking vars + Expr rhs = VisitChild(match, clause->rhs); + clauses.push_back(Clause(lhs, rhs)); + } + // TODO(mbs): Avoid copy + return Match(data, std::move(clauses), match_node->complete, match_node->span); + } + + DLDeviceType GetDeviceType(const Expr& expr) { + // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. + auto props = GetOnDeviceProps(expr); + Expr true_expr = props.body.defined() ? props.body : expr; + ICHECK(domains_->contains(true_expr)); + // If expr is higher order we'll return only the result domain's device type. 
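+    // (For example, for a variable bound to a function this is the device planned for that
+    // function's result.)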
+ DLDeviceType device_type = domains_->ResultDeviceType(domains_->DomainFor(true_expr)); + ICHECK_NE(device_type, kInvalidDeviceType) + << "no device type was determined for expression:" << std::endl + << PrettyPrint(true_expr); + return device_type; + } + + /*! + * \brief Reconcile the \p child_device_type for \p child with both the \p expected_device_type + * (as required by the expression context the \p child is in) and the \p lexical_device_type + * (as a downstream transform would infer based only on lexically enclosing "on_device" + * CallNodes and function attributes.) Generally \p lexical_device_type and \p + * expected_device_type are the same by definition, but may differ in arguments to functions + * and let-bound expressions. + * + * If \p child_device_type differs from \p expected_device_type, wrap it as: + * \code + * device_copy(on_device(child', device_type=child_device_type), + * src_dev_type=child_device_type, dst_dev_type=expected_device_type) + * \endcode + * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the + * child. + * + * If \p expected_device_type differs from \p lexical_device_type, then (also) wrap + * the expression as: + * \code + * on_device(..., device_type=expected_device_type) + * \endcode + * + * TODO(mbs): There's no attempt at sharing here. If usage of child's node could be wrapped + * by a "device_copy", even though those copies will generally all be to the same destination + * device. + */ + Expr VisitChild(DLDeviceType lexical_device_type, DLDeviceType expected_device_type, + DLDeviceType child_device_type, const Expr& child) { + ICHECK_NE(lexical_device_type, kInvalidDeviceType); + ICHECK_NE(expected_device_type, kInvalidDeviceType); + if (child->IsInstance()) { + // Primitive operators don't need to be rewritten and can have a different domain for + // each call site. + return child; + } + Expr result = VisitExpr(child); + if (child_device_type != expected_device_type) { + VLOG(1) << "creating " << DeviceCopyOp()->name << " from device type " << child_device_type + << " to device type " << expected_device_type << " for:" << std::endl + << PrettyPrint(result); + // Also wrap the child in an "on_device" so downstream transforms can track devices + // lexically. + result = OptOnDevice(result, child_device_type, /*is_fixed=*/true); + result = DeviceCopy(result, child_device_type, expected_device_type); + } + if (expected_device_type != lexical_device_type) { + VLOG(1) << "creating " << OnDeviceOp()->name << " for device type " << expected_device_type + << " for:" << std::endl + << PrettyPrint(result); + result = OptOnDevice(result, expected_device_type, /*is_fixed=*/true); + } + return result; + } + + /*! + * Common case of visiting a direct \p child of \p parent where by default the \p child + * is expected to be on the same device as the \p parent. + */ + Expr VisitChild(const Expr& parent, const Expr& child) { + DLDeviceType expected_device_type = GetDeviceType(parent); + DLDeviceType child_device_type = GetDeviceType(child); + return VisitChild(expected_device_type, expected_device_type, child_device_type, child); + } + + /*! \brief Module we are rewriting, so we can lookup global variables. */ + IRModule mod_; + /*! \brief Device domain for every expression from DeviceAnalyzer. */ + std::unique_ptr domains_; +}; + +/*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). 
*/ +tvm::transform::Pass Rewrite() { + auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { + return Downcast(RewriteOnDevices().Mutate(f)); + }; + return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {}); +} + +/*! \brief Run the remaining phases. */ +tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { + return tvm::transform::CreateModulePass( + [=](IRModule mod, tvm::transform::PassContext pass_cnxt) -> IRModule { + // Collect the system of constraints for every sub-expression using existing "on_device" + // and "device_copy" calls. + std::unique_ptr domains = DeviceAnalyzer(mod).Analyze(); + VLOG(1) << "Domains after analysis:" << std::endl << domains->ToString(); + + // Choose sensible default devices for every sub-expression if otherwise unconstrained + // by existing "on_device" or "device_copy" calls. + domains = DeviceDefaulter(mod, std::move(domains), default_device_type).Default(); + VLOG(1) << "Domains after defaulting: " << std::endl << domains->ToString(); + + // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture + // the above map, and attach additional "on_device" attributes to all function + // definitions. + return DeviceCapturer(mod, std::move(domains)).Capture(); + }, + /*opt_level=*/0, "PlanDevicesCore", {}); +} + +} // namespace + +/****** +****** Pass +******/ + +TVM_DLL tvm::transform::Pass PlanDevices(DLDeviceType default_device_type) { + std::vector passes; + passes.emplace_back(Rewrite()); + passes.emplace_back(PlanDevicesCore(default_device_type)); + return tvm::transform::Sequential(std::move(passes), "PlanDevices"); +} + +TVM_REGISTER_GLOBAL("relay._transform.PlanDevices") + .set_body_typed([](const Device& default_device) { + return PlanDevices(default_device.device_type); + }); + +/****** +****** Visitor/Mutator Helpers +******/ + +// TODO(mbs): These have grown to be pretty substantial and should be hoisted out. +// TODO(mbs): Probably less code-dup if we redefine the memoizing mutator on top +// of the generic Functor. + +DLDeviceType LexicalOnDeviceMixin::GetInScopeDeviceType(const Expr& expr) const { + auto props = GetOnDeviceProps(expr); + if (props.body.defined() && props.is_fixed) { + // Look through any fixed "on_device" annotations. + return props.device_type; + } + if (expr->IsInstance()) { + // Lookup variable binding. + auto itr = var_device_types_.find(Downcast(expr)); + if (itr == var_device_types_.end()) { + return kInvalidDeviceType; + } else { + return itr->second; + } + } + // Otherwise use the currently in-scope device type. 
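+  // (That is, the device pushed by the innermost enclosing fixed "on_device" call or
+  // "on_device" function attribute; see PushDeviceType.)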
+ if (expr_device_types_.empty()) { + return kInvalidDeviceType; + } else { + return expr_device_types_.back(); + } +} + +void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } + +void LexicalOnDeviceMixin::ExitFunctionBody() { + ICHECK_GT(function_nesting_, 0); + --function_nesting_; +} + +void LexicalOnDeviceMixin::PushDeviceType(DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + expr_device_types_.emplace_back(device_type); +} + +void LexicalOnDeviceMixin::PopDeviceType() { + if (expr_device_types_.empty()) { + return; + } + expr_device_types_.pop_back(); +} + +void LexicalOnDeviceMixin::PushBoundVar(Var var, DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + ICHECK(var_device_types_.find(var) == var_device_types_.end()); + var_device_types_.emplace(std::move(var), device_type); +} + +void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { + auto itr = var_device_types_.find(var); + if (itr == var_device_types_.end()) { + return; + } + var_device_types_.erase(itr); +} + +void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + } +} + +void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec). + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(inner_let_node); + expr = inner_let_node->body; + } + + VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + PopBoundVar((*itr)->var); + PostVisitLet_(*itr); + } + PostVisitLetBlock_(let_node); +} + +void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. 
+ PopDeviceType(); + } else { + DeviceAwareVisitExpr_(call_node); + } +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + ExprVisitor::VisitExpr_(function_node); +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const CallNode* call_node) { + ExprVisitor::VisitExpr_(call_node); +} + +void DeviceAwareExprVisitor::PreVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PreVisitLetBinding_(const Var& var, const Expr& value) { + VisitExpr(var); + VisitExpr(value); +} + +void DeviceAwareExprVisitor::PostVisitLet_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PostVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + return DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + Expr result = DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + + return result; + } +} + +Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector> bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec.) + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + std::pair pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node); + expr = inner_let_node->body; + } + + expr = VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + const LetNode* pre_let_node = std::get<3>(*itr); + PopBoundVar(pre_let_node->var); + Let post_let = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), + /*body=*/expr, /*span=*/std::get<2>(*itr)); + expr = PostVisitLet_(pre_let_node, post_let.get()); + } + return PostVisitLetBlock_(let_node, expr.as()); +} + +Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + Expr expr = VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. 
+ PopDeviceType(); + return OnDevice(expr, props.device_type, props.is_fixed); + } else { + return DeviceAwareVisitExpr_(call_node); + } +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + return ExprMutator::VisitExpr_(function_node); +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const CallNode* call_node) { + return ExprMutator::VisitExpr_(call_node); +} + +void DeviceAwareExprMutator::PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ +} + +std::pair DeviceAwareExprMutator::PreVisitLetBinding_(const Var& var, + const Expr& value) { + return std::make_pair(Downcast(VisitExpr(var)), VisitExpr(value)); +} + +Expr DeviceAwareExprMutator::PostVisitLet_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +Expr DeviceAwareExprMutator::PostVisitLetBlock_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/device_planner.h b/src/relay/transforms/device_planner.h new file mode 100644 index 0000000000000..e2ec84863904a --- /dev/null +++ b/src/relay/transforms/device_planner.h @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RELAY_TRANSFORMS_DEVICE_PLANNER_H_ +#define TVM_RELAY_TRANSFORMS_DEVICE_PLANNER_H_ + +#include +#include +#include +#include + +#include +#include +#include + +#include "../op/annotation/annotation.h" + +namespace tvm { +namespace relay { +namespace transform { + +// PlanDevices() is declared in the public . + +/*! + * \brief Helper class for expression transformers which need to keep track of the device + * holding the results of expressions and bound variables. This is recovered from the + * "on_device" function attributes and fixed "on_device" CallNodes added by the PlanDevices + * pass. + * + * \sa \p DeviceAwareExpr{Visitor,Mutator}. + */ +class LexicalOnDeviceMixin { + protected: + /*! + * \brief Returns the device type on which the result of \p expr should/will be stored, assuming + * Push/Pop DeviceType/BoundVar have been correctly called. Returns \p kInvalidDeviceType if + * stack is empty and no bound vars have device types. + */ + DLDeviceType GetInScopeDeviceType(const Expr& expr) const; + + /*! \brief Indicate a function body is being entered. 
*/ + void EnterFunctionBody(); + + /*! \brief Indicate a function body has been processed. */ + void ExitFunctionBody(); + + /*! \brief Push a device type onto the lexical device stack. Ignore if \p kInvalidDeviceType. */ + void PushDeviceType(const DLDeviceType device_type); + + /*! \brief Pop a device type from the lexical device stack. Ignore if stack is empty. */ + void PopDeviceType(); + + /*! \brief Remember that \p var will be stored on \p device_type. Ignore if \p kInvalidDeviceType. + * + * CAUTION: Despite the name we don't support re-entering the same function body. + */ + void PushBoundVar(Var var, DLDeviceType device_type); + + /*! \brief Remove the binding for \p var to it's device type. Ignore if var is not bound. */ + void PopBoundVar(const Var& var); + + /*! + * \brief Returns the number of function definitions wrapping the currently visited expression. + */ + int function_nesting() const { return function_nesting_; } + + private: + /*! + * \brief The number of function bodies entered. Since many transforms need to distinguish global + * functions from local functions this supports the mixin's \p is_global() helper method. + */ + int function_nesting_ = 0; + + /*! + * \brief The stack of lexically enclosing "on_device" devices types, from outermost to innermost. + * When visiting an expression other than a variable we can assume the expression result is + * to be stored on device_type_.back(). + */ + std::vector expr_device_types_; + /*! + * \brief A map from in-scope variable to their device types. We may assume the variable is only + * ever bound to a value stored on this device at runtime. + */ + std::unordered_map + var_device_types_; +}; + +template +class DeviceAwareExprFunctor; + +/*! + * \brief ExprFunctor which tracks devices. We only support 'visitor' style implementation + * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without + * any memoization. + */ +template <> +class DeviceAwareExprFunctor : public ExprFunctor, + public LexicalOnDeviceMixin { + private: + using TSuper = ExprFunctor; + + public: + void VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + return DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + } + } + + void VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec.) 
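+      // I.e. the variable is pushed before its value is visited, so a recursive reference to
+      // the variable inside its own value already sees the binding's device.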
+ PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(inner_let_node); + expr = inner_let_node->body; + } + + VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + const LetNode* visited_let_node = *itr; + PopBoundVar(visited_let_node->var); + PostVisitLet_(visited_let_node); + } + PostVisitLetBlock_(let_node); + } + + void VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + } else { + DeviceAwareVisitExpr_(call_node); + } + } + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + + virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node) { + return TSuper::VisitExpr_(function_node); + } + + virtual void DeviceAwareVisitExpr_(const CallNode* call_node) { + return TSuper::VisitExpr_(call_node); + } + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ + } + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual void PreVisitLetBinding_(const Var& var, const Expr& value) { + VisitExpr(var); + VisitExpr(value); + } + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLet_(const LetNode* let_node) { /* no-op */ + } + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLetBlock_(const LetNode* let_node) {} +}; + +/*! \brief ExprVisitor which tracks devices. */ +class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { + public: + using ExprVisitor::VisitExpr_; + + void VisitExpr_(const FunctionNode* function_node) final; + void VisitExpr_(const LetNode* let_node) final; + void VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual void DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. 
Default implementation just visits var and + * value. + */ + virtual void PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLet_(const LetNode* let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLetBlock_(const LetNode* let_node); +}; + +/*! \brief ExprMutator which tracks devices. */ +class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { + public: + Expr VisitExpr_(const FunctionNode* function_node) final; + Expr VisitExpr_(const LetNode* let_node) final; + Expr VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual Expr DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual std::pair PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation just returns a reference to the post-visited node. + */ + virtual Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation returns reference to let node. + */ + virtual Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node); +}; + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_DEVICE_PLANNER_H_ diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py new file mode 100644 index 0000000000000..6c3d2e266b8d2 --- /dev/null +++ b/tests/python/relay/test_pass_plan_devices.py @@ -0,0 +1,1405 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License + + +"""Unit tests for the PlanDevices pass. 
We check:
+  - The pass alone gives the expected AST (InferType must be run manually on both the input and
+    expected modules).
+  - The pass is idempotent."""

+
+# TODO(mbs): All the input/expected programs should be directly quoted using @script
+# TODO(mbs): Not testing Match and Constructor since not supported by Python bindings?
+# TODO(mbs): Add back reference implementation tests once VM is ready.
+
+import tvm
+from tvm import relay
+import tvm.testing
+import numpy as np
+
+N = 5
+M = 7
+CPU = tvm.device("cpu")  # device_type=1
+GPU = tvm.device("cuda")  # device_type=2
+DEFAULT = GPU
+
+
+def rewrite_and_assert(in_mod, expected_mod):
+    """Manually run the pass and assert the result is structurally equal to the expected module."""
+    actual_mod = relay.transform.InferType()(in_mod)
+    actual_mod = relay.transform.PlanDevices(DEFAULT)(actual_mod)
+    actual_mod = relay.transform.InferType()(actual_mod)
+    expected_mod = relay.transform.InferType()(expected_mod)
+    if not tvm.ir.structural_equal(actual_mod, expected_mod):
+        # Print everything in full so we can see what's going on when things fail.
+        print("Input module:")
+        print(in_mod)
+        print("Expected module:")
+        print(expected_mod)
+        print("Actual module:")
+        print(actual_mod)
+        # Assert again so as to see the actual disagreeing sub-expressions.
+        tvm.ir.assert_structural_equal(actual_mod, expected_mod)
+
+
+def rand(shape):
+    return np.random.rand(*shape).astype("float32")
+
+
+def rands(shape, n):
+    return [rand(shape) for i in range(n)]
+
+
+def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, args):
+    """Test in_mod against expected_mod and reference_func using args."""
+    # Correctness
+    rewrite_and_assert(in_mod, expected_mod)
+    # Idempotence
+    rewrite_and_assert(expected_mod, expected_mod)
+    # TODO(mbs): Add back compiling and comparing to reference implementation once VM is ready.
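+
+
+# A minimal sketch (not exercised by the tests below) of how a module flows through the helpers
+# above: annotate one sub-expression with "on_device", then run InferType and PlanDevices exactly
+# as rewrite_and_assert does. The names demo_a/demo_b and the helper itself are illustrative only.
+def demo_plan_devices_sketch():
+    demo_a = relay.var("demo_a", shape=(N, M))
+    demo_b = relay.var("demo_b", shape=(N, M))
+    # Ask for the add to be computed on the CPU; the rest of the function is then planned around
+    # it, defaulting to DEFAULT (the GPU) where unconstrained.
+    body = relay.negative(relay.annotation.on_device(relay.add(demo_a, demo_b), CPU))
+    mod = tvm.IRModule.from_expr(relay.Function([demo_a, demo_b], body))
+    mod = relay.transform.InferType()(mod)
+    return relay.transform.PlanDevices(DEFAULT)(mod)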
+ + +# +# Annotation shorthands +# + + +def on_cpu(expr: relay.Expr): + return relay.annotation.on_device(expr, CPU) + + +def on_gpu(expr: relay.Expr): + return relay.annotation.on_device(expr, GPU) + + +def cpu_to_gpu(expr: relay.Expr): + return relay.op.device_copy(expr, CPU, GPU) + + +def gpu_to_cpu(expr: relay.Expr): + return relay.op.device_copy(expr, GPU, CPU) + + +def fixed_cpu(expr: relay.Expr): + return relay.annotation.on_device(expr, CPU, True) + + +def fixed_gpu(expr: relay.Expr): + return relay.annotation.on_device(expr, GPU, True) + + +def test_plain(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + + # def @main(a, b, c, d) { subtract(add(a, b), add(c, d)) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], relay.subtract(relay.add(a, b), relay.add(c, d))) + ) + + # def @main(a, b, c, d, on_device={param_device_types=[2,2,2,2], result_device_type=2}) { + # subtract(add(a, b), add(c, d)) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b, c, d], relay.subtract(relay.add(a, b), relay.add(c, d))), + [GPU, GPU, GPU, GPU], + GPU, + ) + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_left_add_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + + # def @main(a, b, c, d) { subtract(on_cpu(add(a, b)), add(c, d)) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], relay.subtract(on_cpu(relay.add(a, b)), relay.add(c, d))) + ) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) { + # subtract(cpu_to_gpu(fixed_cpu(add(a, b)), add(c, d)) + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [a, b, c, d], + relay.subtract(cpu_to_gpu(fixed_cpu(relay.add(a, b))), relay.add(c, d)), + ), + [CPU, CPU, GPU, GPU], + GPU, + ) + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_left_add_on_cpu_via_copy(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + + # def @main(a, b, c, d) { subtract(cpu_to_gpu(add(a, b)), add(c, d)) } + def input(): + return tvm.IRModule.from_expr( + relay.Function( + [a, b, c, d], relay.subtract(cpu_to_gpu(relay.add(a, b)), relay.add(c, d)) + ) + ) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) { + # subtract(cpu_to_gpu(fixed_cpu(add(a, b)), add(c, d)) + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [a, b, c, d], + relay.subtract(cpu_to_gpu(fixed_cpu(relay.add(a, b))), relay.add(c, d)), + ), + [CPU, CPU, GPU, GPU], + GPU, + ) + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_both_adds_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + + # def @main(a, b, c, d) { subtract(on_cpu(add(a, b)), on_cpu(add(c, d))) } + def input(): + 
return tvm.IRModule.from_expr( + relay.Function( + [a, b, c, d], relay.subtract(on_cpu(relay.add(a, b)), on_cpu(relay.add(c, d))) + ) + ) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,1,1], result_device_type=2}) { + # subtract(cpu_to_gpu(fixed_cpu(add(a, b)), cpu_to_gpu(fixed_cpu(add(c, d)))) + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [a, b, c, d], + relay.subtract( + cpu_to_gpu(fixed_cpu(relay.add(a, b))), + cpu_to_gpu(fixed_cpu(relay.add(c, d))), + ), + ), + [CPU, CPU, CPU, CPU], + GPU, + ) + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_sharing(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + + # def @main(a, b) { + # %0 = add(a, b) + # subtract(on_cpu(%0), %0) } + def input(): + add = relay.add(a, b) + return tvm.IRModule.from_expr( + relay.Function([a, b], relay.subtract(on_cpu(add), on_cpu(add))) + ) + + # def @main(a, b, on_device={param_device_types=[1,1], result_device_type=2}) { + # %0 = add(a, b) + # subtract(cpu_to_gpu(fixed_cpu(%0), cpu_to_gpu(fixed_cpu(%0))) + def expected(): + add = relay.add(a, b) + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [a, b], relay.subtract(cpu_to_gpu(fixed_cpu(add)), cpu_to_gpu(fixed_cpu(add))) + ), + [CPU, CPU], + GPU, + ) + ) + + def ref(a, b): + x = np.add(a, b) + return np.subtract(x, x) + + exercise(input(), expected(), ref, rands(shape, 2)) + + +def test_let_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + l = relay.Var("l") + r = relay.Var("r") + + # def @main(a, b, c, d) { + # let l = add(a, b); + # let r = add(c, d); + # subtract(on_cpu(l), r) + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function( + [a, b, c, d], + relay.Let( + l, relay.add(a, b), relay.Let(r, relay.add(c, d), relay.subtract(on_cpu(l), r)) + ), + ) + ) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) { + # let l = fixed_cpu(add(a, b)); + # let r = add(c, d); + # subtract(cpu_to_gpu(l), r) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [a, b, c, d], + relay.Let( + l, + fixed_cpu(relay.add(a, b)), + relay.Let(r, relay.add(c, d), relay.subtract(cpu_to_gpu(l), r)), + ), + ), + [CPU, CPU, GPU, GPU], + GPU, + ) + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_func_param_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + f = relay.Var("f") + x = relay.Var("x") + y = relay.Var("y") + + # def @main(a, b, c, d) { + # let f = fn(x, y) { on_cpu(add(x, y)) } -- forces both body and result on CPU + # subtract(f(a, b), add(c, d)) + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function( + [a, b, c, d], + relay.Let( + f, + relay.Function([x, y], on_cpu(relay.add(x, y))), + relay.subtract(relay.Call(f, [a, b]), relay.add(c, d)), + ), + ) + ) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,1,1], result_device_type=1}) { + # let f = fn(x, y, on_device={param_device_types[1,1], result_device_type=1}) { + # add(x, y) + # }; + # subtract(f(a, b), 
add(c, d)) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [a, b, c, d], + relay.Let( + f, + relay.annotation.function_on_device( + relay.Function([x, y], relay.add(x, y)), [CPU, CPU], CPU + ), + relay.subtract(relay.Call(f, [a, b]), relay.add(c, d)), + ), + ), + [CPU, CPU, CPU, CPU], + CPU, + ) + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_func_result_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + f = relay.Var("f") + x = relay.Var("x") + y = relay.Var("y") + + # def @main(a, b, c, d) { + # let f = fn(x, y) { add(x, y) } + # subtract(on_cpu(f(a, b)), add(c, d)) + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function( + [a, b, c, d], + relay.Let( + f, + relay.Function([x, y], relay.add(x, y)), + relay.subtract(on_cpu(relay.Call(f, [a, b])), relay.add(c, d)), + ), + ) + ) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) { + # let f = fixed_cpu(fn(x, y, on_device={param_device_types=[1,1], result_device_type=1}) { + # add(x, y) + # }); + # subtract(cpu_to_gpu(fixed_cpu(f(a, b))), add(c, d)) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [a, b, c, d], + relay.Let( + f, + fixed_cpu( + relay.annotation.function_on_device( + relay.Function([x, y], relay.add(x, y)), [CPU, CPU], CPU + ) + ), + relay.subtract( + cpu_to_gpu(fixed_cpu(relay.Call(f, [a, b]))), relay.add(c, d) + ), + ), + ), + [CPU, CPU, GPU, GPU], + GPU, + ) + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_higher_order(): + shape = (N, M) + x = relay.var("x", shape=shape) + y = relay.var("y", shape=shape) + f = relay.Var("f") + g = relay.Var("g") + a = relay.Var("a") + h = relay.Var("h") + b = relay.Var("b") + + # The constraint on a flows back to y via f and h + # def @main(x, y) { + # let f = fn(g) { fn(a) { add(g(on_cpu(a)), x) } } + # let h = fn(b) { relu(b) } + # subtract(x, f(h)(y)) + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function( + [x, y], + relay.Let( + f, + relay.Function( + [g], relay.Function([a], relay.add(relay.Call(g, [on_cpu(a)]), x)) + ), + relay.Let( + h, + relay.Function([b], relay.negative(b)), + relay.subtract(x, relay.Call(relay.Call(f, [h]), [y])), + ), + ), + ) + ) + + # def @main(x, y, on_device={param_device_types=[GPU, CPU], result_device_type=GPU}) { + # let f = fn(g, on_device={param_device_types=[GPU], result_device_type=GPU}) { + # fn(a, on_device={param_device_types=[CPU], result_device_type=GPU}) { + # add(g(cpu_to_gpu(a)), x) + # } + # } + # let h = fn(b, on_device={param_device_types=[GPU], result_device_type=GPU}) { negative(b) } + # subtract(x, f(h)(y)) + # } + def expected(): + # Yeah, this is illegible. 
+ return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [x, y], + relay.Let( + f, + relay.annotation.function_on_device( + relay.Function( + [g], + relay.annotation.function_on_device( + relay.Function( + [a], relay.add(relay.Call(g, [cpu_to_gpu(a)]), x) + ), + [CPU], + GPU, + ), + ), + [GPU], + GPU, + ), + relay.Let( + h, + relay.annotation.function_on_device( + relay.Function([b], relay.negative(b)), [GPU], GPU + ), + relay.subtract(x, relay.Call(relay.Call(f, [h]), [y])), + ), + ), + ), + [GPU, CPU], + GPU, + ) + ) + + def ref(x, y): + def f(g): + return lambda a: np.add(g(a), x) + + def h(b): + return np.negative(b) + + return np.subtract(x, f(h)(y)) + + exercise(input(), expected(), ref, rands(shape, 2)) + + +def test_function_in_tuple(): + shape = (N, M) + x = relay.var("x", shape=shape) + y = relay.var("y", shape=shape) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + y = relay.var("y", shape=shape) + f = relay.Var("f") + t = relay.Var("t") + + # Since f end up in a tuple its argument and result is forced to be on the CPU + # def @main(x, y) { + # let f = fn(a, b) { add(a, on_cpu(b)) } + # let t = (f, x) + # t.0(t.1, y) + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function( + [x, y], + relay.Let( + f, + relay.Function([a, b], relay.add(a, on_cpu(b))), + relay.Let( + t, + relay.Tuple([f, x]), + relay.Call(relay.TupleGetItem(t, 0), [relay.TupleGetItem(t, 1), y]), + ), + ), + ) + ) + + # def @main(x, y, on_device={param_device_types=[1,1], result_device_type=1}) { + # let f = fn(a, b, on_device={param_device_types=[1,1], result_device_type=1}) { add(a, b) } + # let t = (f, x) + # t.0(t.1, y) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [x, y], + relay.Let( + f, + relay.annotation.function_on_device( + relay.Function([a, b], relay.add(a, b)), [CPU, CPU], CPU + ), + relay.Let( + t, + relay.Tuple([f, x]), + relay.Call(relay.TupleGetItem(t, 0), [relay.TupleGetItem(t, 1), y]), + ), + ), + ), + [CPU, CPU], + CPU, + ) + ) + + def ref(x, y): + return np.add(x, y) + + exercise(input(), expected(), ref, rands(shape, 2)) + + +def test_device_copy(): + shape = (N, M) + x = relay.var("x", shape=shape) + const = relay.const(rand(shape)) + + # def @main(x) { add(cpu_to_gpu(x), const) } + def input(): + return tvm.IRModule.from_expr(relay.Function([x], relay.add(cpu_to_gpu(x), const))) + + # def @main(x, on_device={param_device_types=[1], result_device_type=2}) { + # add(cpu_to_gpu(x), constant) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x], relay.add(cpu_to_gpu(x), const)), [CPU], GPU + ) + ) + + def ref(x): + return np.add(x, const.data.numpy()) + + exercise(input(), expected(), ref, rands(shape, 1)) + + +def test_shape_func(): + p = relay.var("p") + data_shape = (relay.Any(),) + x = relay.var("x", shape=data_shape) + y = relay.var("y", shape=data_shape) + s = relay.var("s", shape=(1,), dtype="int64") + + # def @main(x, s) { + # let p = fixed_gpu(fn(y) { relu(y) }) -- simulates a primitive post FuseOps + # shape_func(p, + # (shape_of(fixed_gpu(x)),), -- shape of primitive input tensor + # (s,), -- space for output shape + # [False]) -- calling with input shapes not tensors + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function( + [x, s], + relay.Let( + p, + fixed_gpu(relay.Function([y], relay.nn.relu(y))), + relay.op.vm.shape_func( + p, + 
relay.Tuple([relay.op.vm.shape_of(fixed_gpu(x))]), + relay.Tuple([s]), + [False], + ), + ), + ) + ) + + # def @main(x, s, on_device={param_device_types=[2,1], result_device_type=1}) { + # let p = fixed_gpu(fn(y, param_device_types=[2], result_device_type=2) { relu(y) }) + # shape_func(p, + # (shape_of(x),), + # (s,), + # [False]) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [x, s], + relay.Let( + p, + fixed_gpu( + relay.annotation.function_on_device( + relay.Function([y], relay.nn.relu(y)), [GPU], GPU + ) + ), + relay.op.vm.shape_func( + p, relay.Tuple([relay.op.vm.shape_of(x)]), relay.Tuple([s]), [False] + ), + ), + ), + [GPU, CPU], + CPU, + ) + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_shape_of(): + compiletime_shape = (relay.Any(), relay.Any()) + runtime_shape = (N, M) + x = relay.var("x", shape=compiletime_shape) + + # We need to use fixed_gpu since the result of on_gpu will default to the result device of @main which is cpu, + # which then forces a copy. + # TODO(mbs): Perhaps the defaulting heuristics are being too clever? + # def @main(x) { shape_of(fixed_gpu(x)) } + def input(): + return tvm.IRModule.from_expr(relay.Function([x], relay.op.vm.shape_of(fixed_gpu(x)))) + + # def @main(x, on_device={param_device_types=[2], result_dev_type=1}) { + # shape_of(x) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x], relay.op.vm.shape_of(x)), [GPU], CPU + ) + ) + + def ref(x): + return x.shape + + exercise(input(), expected(), ref, rands(runtime_shape, 1)) + + +def test_alloc_storage(): + size = relay.Var("size", relay.scalar_type("int64")) + alignment = relay.Var("alignment", relay.scalar_type("int64")) + main = relay.GlobalVar("main") + stdlib = tvm.IRModule() + stdlib.import_from_std("core.rly") + + # def @main(size, alignment) { alloc_storage(size, alignment, GPU) } + def input(): + mod = tvm.IRModule() + mod.update(stdlib) + mod[main] = relay.Function( + [size, alignment], relay.op.memory.alloc_storage(size, alignment, GPU) + ) + return mod + + # def @main(size, alignment, on_device={param_device_types=[1,1], result_device_type=2}) { + # alloc_storage(size, alignment, GPU) + # } + def expected(): + mod = tvm.IRModule() + mod.update(stdlib) + mod[main] = relay.annotation.function_on_device( + relay.Function([size, alignment], relay.op.memory.alloc_storage(size, alignment, GPU)), + [CPU, CPU], + GPU, + ) + return mod + + # Don't try to execute, too fiddly to setup. 
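+    # Note the devices in expected(): the size and alignment parameters are planned for the CPU
+    # while the allocated storage itself is on the GPU.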
+ exercise(input(), expected(), None, None) + + +def test_alloc_tensor(): + stdlib = tvm.IRModule() + stdlib.import_from_std("core.rly") + sto_type = relay.TypeCall(stdlib.get_global_type_var("Storage"), []) + sto = relay.Var("sto", sto_type) + main = relay.GlobalVar("main") + shape = relay.const(np.array([3, 2]), dtype="int64") + + # def @main(sto) { alloc_tensor(sto, 0, [3, 2]) } + def input(): + mod = tvm.IRModule() + mod.update(stdlib) + mod[main] = relay.Function( + [sto], relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), shape) + ) + return mod + + # def @main(sto, on_device={param_device_types=[2], result_device_type=2}) { + # alloc_tensor(sto, fixed_cpu(0), fixed_cpu([3, 2])) + # } + def expected(): + mod = tvm.IRModule() + mod.update(stdlib) + mod[main] = relay.annotation.function_on_device( + relay.Function( + [sto], + relay.op.memory.alloc_tensor( + sto, fixed_cpu(relay.const(0, dtype="int64")), fixed_cpu(shape) + ), + ), + [GPU], + GPU, + ) + return mod + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_reshape_tensor(): + shape = (2, 8) + x = relay.var("x", shape=shape, dtype="float32") + newshape_expr = relay.const([2, 4, 2], dtype="int64") + newshape_prim = [2, 4, 2] + + # def @main(x) { reshape_tensor(x, shape, newshape=[2,4,2]) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([x], relay.op.vm.reshape_tensor(x, newshape_expr, newshape_prim)) + ) + + # def @main(x, on_device={param_device_types=[2], result_device_type=2}) { + # reshape_tensor(x, fixed_cpu(shape), newshape=[2,4,2]) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [x], relay.op.vm.reshape_tensor(x, fixed_cpu(newshape_expr), newshape_prim) + ), + [GPU], + GPU, + ) + ) + + def ref(x): + return np.reshape(x, newshape_prim) + + exercise(input(), expected(), ref, rands(shape, 1)) + + +def test_dynamic_input(): + compiletime_shape = (relay.Any(), relay.Any()) + runtime_shape = (N, M) + x0 = relay.var("x0", shape=compiletime_shape) + x1 = relay.var("x1", shape=compiletime_shape) + + # def @main(x0, x1) { add(x0, x1) } + def input(): + return tvm.IRModule.from_expr(relay.Function([x0, x1], relay.add(x0, x1))) + + # def @main(x0, x1), on_device={param_device_types=[2,2], result_device_type=2}) { + # add(x0, x1) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x0, x1], relay.add(x0, x1)), [GPU, GPU], GPU + ) + ) + + def ref(x0, x1): + return np.add(x0, x1) + + exercise(input(), expected(), ref, rands(runtime_shape, 2)) + + +def test_redundant_annotation(): + shape = (N, M) + x = relay.var("x", shape=shape) + y = relay.var("y", shape=shape) + z = relay.var("z", shape=shape) + + # def @main(x, y, z) { + # %0 = add(x, y) + # add(subtract(on_cpu(%0), z), on_cpu(%0)) + # } + def input(): + a = relay.add(x, y) + return tvm.IRModule.from_expr( + relay.Function([x, y, z], relay.add(relay.subtract(on_cpu(a), z), on_cpu(a))) + ) + + # def @main(x, y, z, on_device={param_device_types=[1,1,2], result_device_type=2}) { + # %0 = add(x, y) + # add(subtract(cpu_to_gpu(fixed_cpu(%0)), z), cpu_to_gpu(fixed_cpu(%0))) + # } + def expected(): + a = relay.add(x, y) + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function( + [x, y, z], + relay.add( + relay.subtract(cpu_to_gpu(fixed_cpu(a)), z), cpu_to_gpu(fixed_cpu(a)) + ), + ), + [CPU, CPU, GPU], + GPU, + ) + ) + + def ref(x, y, z): + a = 
+        return np.add(np.subtract(a, z), a)
+
+    exercise(input(), expected(), ref, rands(shape, 3))
+
+
+def test_annotate_expr():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    z = relay.var("z", shape=shape)
+
+    # def @main(x, y, z) { on_cpu(subtract(on_gpu(add(x, y)), z)) } -- forces function result also on cpu
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function([x, y, z], on_cpu(relay.subtract(on_gpu(relay.add(x, y)), z)))
+        )
+
+    # def @main(x, y, z, on_device={param_device_types=[2,2,1], result_device_type=1}) {
+    #   subtract(gpu_to_cpu(fixed_gpu(add(x, y))), z)
+    # }
+    def expected():
+        add = relay.add(x, y)
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, y, z], relay.subtract(gpu_to_cpu(fixed_gpu(add)), z)
+                ),
+                [GPU, GPU, CPU],
+                CPU,
+            )
+        )
+
+    def ref(x, y, z):
+        return np.subtract(np.add(x, y), z)
+
+    exercise(input(), expected(), ref, rands(shape, 3))
+
+
+def test_annotate_all():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    z = relay.var("z", shape=shape)
+
+    # def @main(x, y, z) { on_cpu(subtract(on_cpu(add(x, y)), z)) } -- top-level also forces result to be CPU
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function([x, y, z], on_cpu(relay.subtract(on_cpu(relay.add(x, y)), z)))
+        )
+
+    # def @main(x, y, z, on_device={param_device_types=[CPU, CPU, CPU], result_device_type=CPU}) {
+    #   subtract(add(x, y), z)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([x, y, z], relay.subtract(relay.add(x, y), z)), [CPU, CPU, CPU], CPU
+            )
+        )
+
+    def ref(x, y, z):
+        return np.subtract(np.add(x, y), z)
+
+    exercise(input(), expected(), ref, rands(shape, 3))
+
+
+def test_conv_network():
+    r"""The network and devices are as follows:
+    data1     data2    <--- CPU
+      |         |
+    conv2d    conv2d   <--- CPU
+       \       /
+        \     /
+          add          <--- GPU
+           |
+         conv2d        <--- CPU
+           |
+                       <--- CPU
+    """
+    batch_size = 1
+    dshape = (batch_size, 64, 56, 56)
+    wshape = (64, 64, 3, 3)
+    weight = relay.var("weight", shape=wshape)
+    data1 = relay.var("data1", shape=dshape)
+    data2 = relay.var("data2", shape=dshape)
+
+    def input():
+        conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        add = relay.add(on_cpu(conv2d_1), on_cpu(conv2d_2))
+        conv2d_3 = relay.nn.conv2d(
+            on_gpu(add), weight, channels=64, kernel_size=(3, 3), padding=(1, 1)
+        )
+        return tvm.IRModule.from_expr(relay.Function([data1, data2, weight], on_cpu(conv2d_3)))
+
+    def expected():
+        conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        add = relay.add(cpu_to_gpu(fixed_cpu(conv2d_1)), cpu_to_gpu(fixed_cpu(conv2d_2)))
+        conv2d_3 = relay.nn.conv2d(
+            gpu_to_cpu(fixed_gpu(add)), weight, channels=64, kernel_size=(3, 3), padding=(1, 1)
+        )
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([data1, data2, weight], conv2d_3), [CPU, CPU, CPU], CPU
+            )
+        )
+
+    # Don't try to execute, we don't have a reference conv2d.
+    exercise(input(), expected(), None, None)
+
+
+def test_tuple_get_item():
+    shape = (3, 3, 4)
+    x = relay.Var("x", relay.ty.TensorType(shape, "float32"))
+    t = relay.Var("t")
+
+    # We'll device copy after projection, not before.
+    # def @main(x) {
+    #   let t = split(x, 3);
+    #   on_gpu(subtract(on_cpu(t).0, on_cpu(t).1))
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [x],
+                relay.Let(
+                    t,
+                    relay.op.split(x, 3).astuple(),
+                    on_gpu(
+                        relay.subtract(
+                            relay.TupleGetItem(on_cpu(t), 0), relay.TupleGetItem(on_cpu(t), 1)
+                        )
+                    ),
+                ),
+            )
+        )
+
+    # def @main(x, on_device={param_device_types=[1], result_device_type=2}) {
+    #   let t = fixed_cpu(split(x, 3))
+    #   subtract(cpu_to_gpu(fixed_cpu(t.0)), cpu_to_gpu(fixed_cpu(t.1)))
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x],
+                    relay.Let(
+                        t,
+                        fixed_cpu(relay.op.split(x, 3).astuple()),
+                        relay.subtract(
+                            cpu_to_gpu(fixed_cpu(relay.TupleGetItem(t, 0))),
+                            cpu_to_gpu(fixed_cpu(relay.TupleGetItem(t, 1))),
+                        ),
+                    ),
+                ),
+                [CPU],
+                GPU,
+            )
+        )
+
+    def ref(x):
+        t = np.split(x, 3)
+        return np.subtract(t[0], t[1])
+
+    exercise(input(), expected(), ref, rands(shape, 1))
+
+
+def test_propagation():
+    r"""The network and devices are as follows:
+           x          <--- CPU
+           |
+          log         <--- CPU
+         /   \
+      log2   log10    <--- GPU
+         \   /
+          add         <--- GPU
+           |
+          tan         <--- CPU
+           |
+                      <--- CPU
+    """
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+
+    def input():
+        log = relay.log(x)
+        log2 = relay.log2(on_cpu(log))
+        log10 = relay.log10(on_cpu(log))
+        add = relay.add(on_gpu(log2), on_gpu(log10))
+        tan = relay.tan(on_gpu(add))
+        return tvm.IRModule.from_expr(relay.Function([x], on_cpu(tan)))
+
+    def expected():
+        log = relay.log(x)
+        log2 = relay.log2(cpu_to_gpu(fixed_cpu(log)))
+        log10 = relay.log10(cpu_to_gpu(fixed_cpu(log)))
+        add = relay.add(log2, log10)
+        tan = relay.tan(gpu_to_cpu(fixed_gpu(add)))
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(relay.Function([x], tan), [CPU], CPU)
+        )
+
+    def ref(x):
+        y = np.log(x)
+        return np.tan(np.add(np.log2(y), np.log10(y)))
+
+    exercise(input(), expected(), ref, rands(shape, 1))
+
+
+def test_fusible_network():
+    r"""The network is as follows:
+        x    y        <--- GPU
+         \  /
+          add         <--- GPU
+         /   \
+    negative  \       <--- CPU
+         \     \
+          \  negative <--- GPU
+           \   /
+            add       <--- GPU
+             |
+         negative     <--- CPU
+             |
+                      <--- CPU
+    """
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+
+    def input():
+        add = relay.add(x, y)
+        neg_a = relay.negative(on_gpu(add))
+        neg_b = relay.negative(add)
+        add2 = relay.add(on_cpu(neg_a), neg_b)
+        neg_c = relay.negative(on_gpu(add2))
+        return tvm.IRModule.from_expr(relay.Function([x, y], on_cpu(neg_c)))
+
+    def expected():
+        add = relay.add(x, y)
+        neg_a = relay.negative(gpu_to_cpu(fixed_gpu(add)))
+        neg_b = relay.negative(add)
+        add2 = relay.add(cpu_to_gpu(fixed_cpu(neg_a)), neg_b)
+        neg_c = relay.negative(gpu_to_cpu(fixed_gpu(add2)))
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(relay.Function([x, y], neg_c), [GPU, GPU], CPU)
+        )
+
+    def ref(x, y):
+        z = np.add(x, y)
+        return np.negative(np.add(np.negative(z), np.negative(z)))
+
+    exercise(input(), expected(), ref, rands(shape, 2))
+
+
+def test_unpropagatable_graph():
+    r"""The network is as follows:
+    a     b             <--- CPU
+     \   /
+      \ /      c    d   <--- GPU
+       \        \  /
+      add        \/     <--- CPU
+        \         \
+         \     multiply <--- GPU
+          \      /
+         subtract       <--- CPU
+             |
+                        <--- CPU
+    """
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    c = relay.var("c", shape=shape)
+    d = relay.var("d", shape=shape)
+
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [a, b, c, d],
+                on_cpu(relay.subtract(on_cpu(relay.add(a, b)), on_gpu(relay.multiply(c, d)))),
+            )
+        )
+
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [a, b, c, d],
+                    relay.subtract(relay.add(a, b), gpu_to_cpu(fixed_gpu(relay.multiply(c, d)))),
+                ),
+                [CPU, CPU, GPU, GPU],
+                CPU,
+            )
+        )
+
+    def ref(a, b, c, d):
+        return np.subtract(np.add(a, b), np.multiply(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_conditional():
+    shape = (N, M)
+    x = relay.Var("x", relay.ty.scalar_type("bool"))
+    y = relay.var("y", shape=shape)
+    z = relay.var("z", shape=shape)
+    f = relay.Var("f")
+    g = relay.Var("g")
+    h = relay.Var("h")
+    a1 = relay.Var("a")
+    a2 = relay.Var("a")
+
+    # def @main(x, y, z) {
+    #   let f = fn(a) { add(a, fixed_cpu(y)) }
+    #   let g = fn(a) { subtract(a, y) }
+    #   let h = if (x) {
+    #     f
+    #   } else {
+    #     g
+    #   }
+    #   h(z)
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [x, y, z],
+                relay.Let(
+                    f,
+                    relay.Function([a1], relay.add(a1, fixed_cpu(y))),
+                    relay.Let(
+                        g,
+                        relay.Function([a2], relay.subtract(a2, y)),
+                        relay.Let(h, relay.If(x, f, g), relay.Call(h, [z])),
+                    ),
+                ),
+            )
+        )
+
+    # def @main(x, y, z, on_device={param_device_types=[1,1,1], result_device_type=1}) {
+    #   let f = fn(a, on_device={param_device_types=[1], result_device_type=1}) { add(a, y) }
+    #   let g = fn(a, on_device={param_device_types=[1], result_device_type=1}) { subtract(a, y) }
+    #   let h = if (x) {
+    #     f
+    #   } else {
+    #     g
+    #   }
+    #   h(z)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, y, z],
+                    relay.Let(
+                        f,
+                        relay.annotation.function_on_device(
+                            relay.Function([a1], relay.add(a1, y)), [CPU], CPU
+                        ),
+                        relay.Let(
+                            g,
+                            relay.annotation.function_on_device(
+                                relay.Function([a2], relay.subtract(a2, y)), [CPU], CPU
+                            ),
+                            relay.Let(h, relay.If(x, f, g), relay.Call(h, [z])),
+                        ),
+                    ),
+                ),
+                [CPU, CPU, CPU],
+                CPU,
+            )
+        )
+
+    def ref(x, y, z):
+        def f(a):
+            return np.add(a, y)
+
+        def g(a):
+            return np.subtract(a, y)
+
+        h = f if x else g
+        return h(z)
+
+    exercise(input(), expected(), ref, [True, rand(shape), rand(shape)])
+
+
+def test_global():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    f = relay.GlobalVar("f")
+    main = relay.GlobalVar("main")
+
+    # def @f(a, b) { add(a, on_cpu(b)) }
+    # def @main(x, y) { @f(y, x) }
+    def input():
+        mod = tvm.IRModule()
+        mod[f] = relay.Function(
+            [a, b], relay.add(a, on_cpu(b)), relay.ty.TensorType(shape, "float32")
+        )
+        mod[main] = relay.Function(
+            [x, y], relay.Call(f, [y, x]), relay.ty.TensorType(shape, "float32")
+        )
+        return mod
+
+    # def @f(a, b, on_device={param_device_types=[2,1], result_device_type=2}) { add(a, cpu_to_gpu(b)) }
+    # def @main(x, y, on_device={param_device_types=[1,2], result_device_type=2}) { @f(y, x) }
+    def expected():
+        mod = tvm.IRModule()
+        mod[f] = relay.annotation.function_on_device(
+            relay.Function(
+                [a, b], relay.add(a, cpu_to_gpu(b)), relay.ty.TensorType(shape, "float32")
+            ),
+            [GPU, CPU],
+            GPU,
+        )
+        mod[main] = relay.annotation.function_on_device(
+            relay.Function([x, y], relay.Call(f, [y, x]), relay.ty.TensorType(shape, "float32")),
+            [CPU, GPU],
+            GPU,
+        )
+        return mod
+
+    def ref(x, y):
+        def f(a, b):
+            return np.add(a, b)
+
+        return f(x, y)
+
+    exercise(input(), expected(), ref, rands(shape, 2))
+
+
+# Note that match and ADTs don't appear to be supported for direct AST
+# construction.
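+#
+# One possible workaround (a rough, untested sketch, not exercised by this test suite) is to
+# go via the text parser rather than building the AST by hand. The Relay source below and the
+# use of CPU as the default device are illustrative assumptions only, not tested behaviour.
+def sketch_plan_devices_over_adt():
+    # Parse a tiny ADT + match program; the List type is defined inline so no prelude is needed.
+    mod = tvm.parser.fromtext(
+        """
+        #[version = "0.0.5"]
+        type List[A] { Cons(A, List[A]), Nil }
+
+        def @main(%x: int32, %xs: List[int32]) -> int32 {
+          match (%xs) {
+            Cons(%y, _) => add(%x, %y),
+            Nil => %x,
+          }
+        }
+        """
+    )
+    mod = relay.transform.InferType()(mod)
+    # With no "on_device"/"device_copy" calls, everything should be planned onto the
+    # default device passed here (CPU is just an example choice).
+    return relay.transform.PlanDevices(CPU)(mod)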
+
+
+def test_ref():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    r = relay.var("r")
+    dummy = relay.var("dummy")
+
+    # def @main(x, y) {
+    #   let r = ref(x)
+    #   ref_write(r, on_cpu(y))
+    #   add(x, ref_read(r))
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [x, y],
+                relay.Let(
+                    r,
+                    relay.RefCreate(x),
+                    relay.Let(dummy, relay.RefWrite(r, on_cpu(y)), relay.add(x, relay.RefRead(r))),
+                ),
+            )
+        )
+
+    # def @main(x, y, on_device={param_device_types=[GPU, CPU], result_device_type=GPU}) {
+    #   let r = ref(x)
+    #   ref_write(r, cpu_to_gpu(y))
+    #   add(x, ref_read(r))
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, y],
+                    relay.Let(
+                        r,
+                        relay.RefCreate(x),
+                        relay.Let(
+                            dummy, relay.RefWrite(r, cpu_to_gpu(y)), relay.add(x, relay.RefRead(r))
+                        ),
+                    ),
+                ),
+                [GPU, CPU],
+                GPU,
+            )
+        )
+
+    def ref(x, y):
+        r = {"value": x}
+        r["value"] = y
+        return np.add(x, r["value"])
+
+    # Don't try to execute, no backend currently supports both cross-device execution and references.
+    exercise(input(), expected(), None, None)
+
+
+if __name__ == "__main__":
+    test_plain()
+    test_left_add_on_cpu()
+    test_left_add_on_cpu_via_copy()
+    test_both_adds_on_cpu()
+    test_sharing()
+    test_let_on_cpu()
+    test_func_param_on_cpu()
+    test_func_result_on_cpu()
+    test_higher_order()
+    test_function_in_tuple()
+    test_device_copy()
+    test_shape_func()
+    test_shape_of()
+    test_alloc_storage()
+    test_alloc_tensor()
+    test_reshape_tensor()
+    test_dynamic_input()
+    test_redundant_annotation()
+    test_annotate_expr()
+    test_annotate_all()
+    test_conv_network()
+    test_tuple_get_item()
+    test_propagation()
+    test_fusible_network()
+    test_unpropagatable_graph()
+    test_conditional()
+    test_global()
+    test_ref()