From 779a50629f9b322cf9a26277574fac5bbaff3a24 Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Mon, 4 Oct 2021 12:18:18 -0700 Subject: [PATCH] [Relay] Remove DeviceMap from LowerTE (#8788) * [Relay] Switch the graph, VM and AOT executors to use the merged device_planner.cc from #9038, and finally remove DeviceMap from the LowerTE Pass. - We retire analysis/context_analysis.cc and transforms/device_annotation.cc (and their tests). That includes the CollectDeviceInfo, CollectDeviceAnnotationOps and ContextAnalysis entry points. These are all subsumed by the PlanDevices pass and the device aware visitors. - The following passes now use the new 'Device Aware' visitors to recover the device for every Relay sub-expression: - backend/aot_executor_codegen.cc (AOTOnDemandAllocator) - backend/graph_plan_memory.cc (StorageAllocaBaseVisitor etc) - backend/te_compiler.cc (LowerTensorExprMutator) - transforms/memory_alloc.cc (DialectRewriter) - backend/vm/compiler.cc (VMFunctionCompiler) - The following passes/utils must maintain the device information encoded by the device planner within "on_device" annotations and "param_device_types"/"result_device_type" function attributes: - backend/vm/lambda_lift.cc (LambdaLifter) - transforms/to_a_normal_form.cc (Fill) - ir/expr_functor.cc (Bind) - Remove a lot of ad-hoc 'homogeneous' vs 'heterogeneous' conditionals in favor of just asking for the device. Also removed a lot of ad-hoc encodings of the 'default' device. - We no longer need to run device-planning twice (before and after lowering). Device planning is also decoupled from memory planning. - The LowerTE Pass no longer needs an expression-to-device side table (which was the problem which kicked this series of PRs off in the first place). 
* [checkpoint] Revert unnecessary changes - Started down multi-target handling in interpreter but didn't finish - Some one-off debug stuff * [checkpoint] TODO's for default device logic --- include/tvm/relay/analysis.h | 29 - include/tvm/relay/transform.h | 3 +- include/tvm/runtime/vm/vm.h | 4 +- python/tvm/relay/analysis/analysis.py | 49 -- python/tvm/relay/op/_tensor.py | 1 - python/tvm/relay/transform/transform.py | 21 - src/parser/parser.cc | 48 +- src/parser/tokenizer.h | 4 +- src/relay/analysis/context_analysis.cc | 719 ------------------ src/relay/backend/aot_executor_codegen.cc | 144 ++-- src/relay/backend/build_module.cc | 166 ++-- src/relay/backend/graph_executor_codegen.cc | 79 +- src/relay/backend/graph_plan_memory.cc | 177 +++-- src/relay/backend/interpreter.cc | 89 ++- src/relay/backend/te_compiler.cc | 166 ++-- src/relay/backend/te_compiler.h | 11 +- src/relay/backend/utils.cc | 20 + src/relay/backend/utils.h | 5 + src/relay/backend/vm/compiler.cc | 244 +++--- src/relay/backend/vm/compiler.h | 12 +- src/relay/backend/vm/lambda_lift.cc | 91 +-- src/relay/ir/expr_functor.cc | 26 +- src/relay/op/annotation/annotation.cc | 9 +- src/relay/op/annotation/annotation.h | 14 +- src/relay/transforms/device_annotation.cc | 581 -------------- src/relay/transforms/device_aware_visitors.cc | 54 +- src/relay/transforms/device_aware_visitors.h | 44 +- src/relay/transforms/device_planner.cc | 1 + src/relay/transforms/fold_scale_axis.cc | 4 + src/relay/transforms/fuse_ops.cc | 7 +- src/relay/transforms/higher_order_gradient.cc | 2 +- src/relay/transforms/let_list.h | 2 +- src/relay/transforms/memory_alloc.cc | 156 ++-- src/relay/transforms/pass_utils.h | 65 +- src/relay/transforms/pattern_utils.h | 4 - src/relay/transforms/split_args.cc | 3 +- src/relay/transforms/to_a_normal_form.cc | 417 ++++++---- .../transforms/to_basic_block_normal_form.cc | 36 +- src/relay/transforms/type_infer.cc | 5 +- src/runtime/vm/serialize_utils.h | 4 +- src/runtime/vm/vm.cc | 4 +- 
.../transforms/device_domains_test.cc | 0 tests/python/relay/test_pass_annotation.py | 663 ---------------- .../relay/test_pass_context_analysis.py | 205 ----- tests/python/relay/test_pass_plan_devices.py | 43 +- tests/python/relay/test_vm.py | 2 + 46 files changed, 1157 insertions(+), 3276 deletions(-) delete mode 100644 src/relay/analysis/context_analysis.cc delete mode 100644 src/relay/transforms/device_annotation.cc rename tests/cpp/relay/{relay => }/transforms/device_domains_test.cc (100%) delete mode 100644 tests/python/relay/test_pass_annotation.py delete mode 100644 tests/python/relay/test_pass_context_analysis.py diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 264f2609a4b6..0f85587262ac 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -211,24 +211,6 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const IRModule& mod); */ TVM_DLL tvm::Array AllTypeVars(const Type& t, const IRModule& mod); -/*! - * \brief Collect the device mapping information of each expression. - * - * \param expr The expression. - * - * \return The device mapping. - */ -TVM_DLL Map CollectDeviceInfo(const Expr& expr); - -/*! - * \brief Collect the device anntation operators. - * - * \param expr The expression. - * - * \return The annotated expression to device type mapping for annotation ops. - */ -TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); - /*! * \brief Finds cases that the given match expression does not catch, if any. * @@ -268,17 +250,6 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod); */ TVM_DLL Map> GetCalibrateOutputMap(const IRModule& mod); -/*! - * \brief Analyze the device context of each IR node in a given relay module. - * - * \param mod The module for analysis. - * \param default_device The default device used by unassigned IR nodes. - * - * \return The mapping between an IR node and its associated device. 
- */ -TVM_DLL std::unordered_map -ContextAnalysis(const IRModule& mod, const Device& default_device); - } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index e740776d6d4f..91f731410863 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -165,11 +165,12 @@ TVM_DLL Pass ToANormalForm(); /*! * \brief ToANormalForm but on incomplete graph. * + * \param maybe_mod optional module holding definitions for global vars in \p expr * \param expr the graph. * * \return The transformed program. */ -TVM_DLL Expr ToANormalForm(const Expr& expr); +TVM_DLL Expr ToANormalForm(const Optional& maybe_mod, const Expr& expr); /*! * \brief Turn an expression into continuation passing style(CPS). diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 2fdfec9452af..831336b9dbfe 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -84,11 +84,11 @@ struct VMFunction { /*! \brief The size of the frame for this function */ Index register_file_size; /*! \brief The device type of each parameter for this function. */ - std::vector params_device_type; + std::vector params_device_type; VMFunction(const std::string& name, std::vector params, const std::vector& instructions, Index register_file_size, - const std::vector params_device_type = {}) + const std::vector params_device_type = {}) : name(name), params(params), instructions(instructions), diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 524f69bcdd13..b62700573581 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -28,21 +28,6 @@ from .feature import Feature -def context_analysis(mod, default_device): - """Analyze the device context information of each IR node in a Relay - program. - - Parameters - ---------- - mod : tvm.IRModule - The input module. 
- - default_device : tvm.runtime.Device - The default context allocated to an IR node. - """ - return _ffi_api.ContextAnalysis(mod, default_device) - - def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, apply fvisit. Each node is guaranteed to be visited @@ -268,40 +253,6 @@ def all_dtypes(expr): return set(_ffi_api.all_dtypes(expr)) -def collect_device_info(expr): - """Collect the device allocation map for the given expression. The device - ids are propagated from the `device_copy` operators. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - ret : Dict[tvm.relay.ir.expr, int] - A dictionary mapping tvm.relay.Expr to device type. - """ - return _ffi_api.CollectDeviceInfo(expr) - - -def collect_device_annotation_ops(expr): - """Collect the device annotation ops for the given expression. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - ret : Dict[tvm.relay.Expr, int] - A dictionary mapping tvm.relay.Expr to device type where the keys are - annotation expressions. - """ - return _ffi_api.CollectDeviceAnnotationOps(expr) - - def get_total_mac_number(expr): """ Count the number of MACs (multiply-accumulate) of a model diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index d7d99c017b2b..18ce93322f43 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -91,7 +91,6 @@ register_broadcast_schedule("fast_erf") # a fake on_device schedule. 
# this will not be used in actual computation -# as on_device will be removed during DeviceAnnotation pass register_injective_schedule("on_device") diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index bb91afc06195..0dc07944836d 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -544,27 +544,6 @@ def MergeCompilerRegions(): return _ffi_api.MergeCompilerRegions() -def RewriteAnnotatedOps(fallback_device): - """Rewrite the annotated program where annotation operators, e.g. - `on_device`, mark which device an expression should be scheduled to. - This pass helps heterogeneous execution where different operators may need - to be allocated on various devices. - - Parameters - ---------- - fallback_device : int - The fallback device type. It is also used as the default device for - operators with no annotated device. - - Returns - ------- - ret: tvm.transform.Pass - The registered pass that rewrites an expression with annotated - `on_device` operators. - """ - return _ffi_api.RewriteDeviceAnnotation(fallback_device) - - def ToANormalForm(): """Turn Graph Normal Form expression into A Normal Form Expression. The scope of the root expression is the global scope. diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 5eec716cc20c..ebd6566889dc 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -417,7 +417,7 @@ class Parser { * Useful for matching optional tokens, effectively looksahead by one. 
*/ bool WhenMatch(const TokenType& token_type) { - VLOG(1) << "Parser::WhenMatch: Peek() == " << Peek(); + VLOG(9) << "Parser::WhenMatch: Peek() == " << Peek(); if (Peek()->token_type == token_type) { Consume(token_type); return true; @@ -594,7 +594,7 @@ class Parser { template R WithSpan(std::function parser) { auto start_span = Peek()->span; - VLOG(0) << "WithSpan: start_span = " << start_span; + VLOG(9) << "WithSpan: start_span = " << start_span; R ast = parser(); if (ast.defined()) { // The token at the head of the stream is now 1 past where we parsed. So we find its start @@ -608,7 +608,7 @@ class Parser { span_pos--; } auto end_token = tokens.at(span_pos); - VLOG(0) << "WithSpan: end_span = " << end_token->span; + VLOG(9) << "WithSpan: end_span = " << end_token->span; ast->span = start_span.Merge(end_token->span); } return ast; @@ -668,7 +668,7 @@ class Parser { template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, std::function before_stop = nullptr) { - VLOG(0) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep) + VLOG(9) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep) << " stop=" << ToString(stop); Match(start); @@ -686,7 +686,7 @@ class Parser { if (WhenMatch(stop)) { return Array(); } else { - VLOG(0) << "Parser::ParseSequence: parse first"; + VLOG(9) << "Parser::ParseSequence: parse first"; auto data = parse(); Array elements = {data}; @@ -695,7 +695,7 @@ class Parser { // parse '( expr ',' * ')' } else if (WhenMatch(sep)) { while (true) { - VLOG(0) << "Parser::ParseSequence: parse element"; + VLOG(9) << "Parser::ParseSequence: parse element"; if (WhenMatch(stop)) { break; } else { @@ -893,12 +893,12 @@ class Parser { /*! \brief Parse a single Relay expression. 
*/ Expr ParseExpr() { - VLOG(0) << "Parser::ParseExpr"; + VLOG(9) << "Parser::ParseExpr"; return WithSpan([this] { std::vector exprs; while (true) { - VLOG(0) << "Parser::ParseExpr: parsing a single expression"; + VLOG(9) << "Parser::ParseExpr: parsing a single expression"; auto next = Peek(); switch (next->token_type) { // For graph or let, match first rhs, then invoke ParseBindingExpr @@ -1011,7 +1011,7 @@ class Parser { // This ensures for n sequential bindings // the call depth will be the same before // and after parsing the n bindings. - VLOG(0) << "Parser::ParseBindingExpr"; + VLOG(9) << "Parser::ParseBindingExpr"; std::vector> bindings; int scopes = 0; @@ -1085,7 +1085,7 @@ class Parser { * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }. */ Function ParseFunctionDef() { - VLOG(0) << "Parser::ParseFunctionDef"; + VLOG(9) << "Parser::ParseFunctionDef"; return WithSpan([&]() { PushScope(); PushTypeScope(); @@ -1147,7 +1147,7 @@ class Parser { /*! \brief Parse an if-expression. */ Expr ParseIf() { return WithSpan([&]() { - VLOG(0) << "Parser::ParseIf"; + VLOG(9) << "Parser::ParseIf"; Consume(TokenType::kIf); auto guard = WithSpan([&] { return Parens([&] { return ParseExpr(); }); }); @@ -1186,7 +1186,7 @@ class Parser { * This function recursively parses a pattern. 
*/ Pattern ParsePattern() { - VLOG(0) << "Parser::ParsePattern"; + VLOG(9) << "Parser::ParsePattern"; auto next = Peek(); switch (next->token_type) { case TokenType::kUnderscore: { @@ -1249,7 +1249,7 @@ class Parser { } Expr ParseExprBinOp() { - VLOG(0) << "Parser::ParseExprBinOp"; + VLOG(9) << "Parser::ParseExprBinOp"; return WithSpan([this] { // We must parse at least one expression, the default // case is that there is no operator and we will fall @@ -1333,7 +1333,7 @@ class Parser { } ObjectRef ParseAttributeValue() { - VLOG(0) << "Parser::ParseAttributeValue"; + VLOG(9) << "Parser::ParseAttributeValue"; auto next = Peek(); switch (next->token_type) { case TokenType::kFloat: @@ -1375,7 +1375,7 @@ class Parser { } Map ParseAttrs() { - VLOG(0) << "Parser::ParseAttrs"; + VLOG(9) << "Parser::ParseAttrs"; Map kwargs; while (Peek()->token_type == TokenType::kIdentifier) { auto key = GetHierarchicalName(ParseHierarchicalName().data); @@ -1387,14 +1387,14 @@ class Parser { kwargs.Set(key, value); WhenMatch(TokenType::kComma); } - VLOG(0) << "Parser::ParseAttrs: kwargs=" << kwargs; + VLOG(9) << "Parser::ParseAttrs: kwargs=" << kwargs; return kwargs; } Expr ParseCallArgs(Expr op) { ICHECK(op.defined()) << "the operator must be defined"; - VLOG(0) << "Parser::ParseCallArgs"; + VLOG(9) << "Parser::ParseCallArgs"; Attrs attrs; std::string op_key; bool is_op = false; @@ -1471,7 +1471,7 @@ class Parser { } Expr ParseCallExpr() { - VLOG(0) << "Parser::ParseCallExpr"; + VLOG(9) << "Parser::ParseCallExpr"; return WithSpan([this] { Expr expr = ParseAtomicExpr(); // Parse as many call args as possible, building up expression @@ -1500,7 +1500,7 @@ class Parser { } Expr GetOp(const std::string& op_name, const Span& span) { - VLOG(0) << "op_name=" << op_name << " span=" << span; + VLOG(9) << "op_name=" << op_name << " span=" << span; try { return Op::Get(op_name); } catch (const Error& e) { @@ -1513,7 +1513,7 @@ class Parser { } Expr ParseAtomicExpr() { - VLOG(0) << 
"Parser::ParseAtomicExpr"; + VLOG(9) << "Parser::ParseAtomicExpr"; Expr expr = WithSpan([this] { auto next = Peek(); switch (next->token_type) { @@ -1649,7 +1649,7 @@ class Parser { auto token = Match(TokenType::kInteger); auto index = token.ToNumber(); auto span = token->span.Merge(expr->span); - VLOG(0) << "Parser::ParseAtomicExpr: tuple get item"; + VLOG(9) << "Parser::ParseAtomicExpr: tuple get item"; return relay::TupleGetItem(expr, index, span); } else { return expr; @@ -1870,7 +1870,7 @@ class Parser { Parser InitParser(const std::string& file_name, const std::string& file_content, const Optional& init_module, const MetaTable& init_meta_table) { - VLOG(0) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size(); + VLOG(9) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size(); SourceName src_name = SourceName::Get(file_name); Source source(src_name, file_content); @@ -1909,7 +1909,7 @@ Parser InitParser(const std::string& file_name, const std::string& file_content, IRModule ParseModule(const std::string& file_name, const std::string& file_content, const Optional& init_module, const MetaTable& init_meta_table) { - VLOG(0) << "ParseModule"; + VLOG(9) << "ParseModule"; auto parser = InitParser(file_name, file_content, init_module, init_meta_table); auto mod = parser.ParseModule(); ICHECK(mod.defined()) << "The parser must return a non-null module."; @@ -1923,7 +1923,7 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte } Expr ParseExpr(const std::string& file_name, const std::string& file_content) { - VLOG(0) << "ParseExpr"; + VLOG(9) << "ParseExpr"; auto parser = InitParser(file_name, file_content, Optional(), MetaTable()); parser.ParseSemVer(false); parser.PushScope(); diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 8f197db45318..f8098cf94100 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -339,7 +339,7 @@ struct Tokenizer 
{ int line = this->line; int col = this->col; auto next = Peek(); - VLOG(1) << "tvm::parser::TokenizeOnce: next=" << next; + VLOG(9) << "tvm::parser::TokenizeOnce: next=" << next; if (next == '\n') { auto token = NewToken(TokenType::kNewline); Next(); @@ -550,7 +550,7 @@ struct Tokenizer { } void Tokenize() { - VLOG(0) << "tvm::parser::Tokenize"; + VLOG(9) << "tvm::parser::Tokenize"; while (this->More()) { auto token = TokenizeOnce(); ICHECK(token.defined()); diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc deleted file mode 100644 index 35813f67d094..000000000000 --- a/src/relay/analysis/context_analysis.cc +++ /dev/null @@ -1,719 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relay/analysis/context_analysis.cc - * \brief A pass for analyzing device attribute of each IR node. - * - * We use union-find data structures to analyze the context information of each - * sub-expression in a Relay program in this pass. Only the device copy node in - * Relay directly contains bidiretional device information. We use it to - * bidirectionally propagate the device info of its inputs and outputs. 
- * - * However, to support dynamism (e.g dynamic inputs), Relay introduces several - * concepts to compute the shape of tensors and operators at runtime, i.e. - * shape_of, shape_func, and reshape_tensor. These nodes are also referred to as - * VM dialects as we have native VM instructions for them. These dialects are - * intrinsically CPU friendly, therefore, they are only designed to be - * executed on CPU. We, hence, unify their inputs and outputs to CPU as well. - * Note the input of shape_of is a tensor and we only need the tensor shape. - * Therefore, the input could be sitting on GPU as well since no real data is - * needed. The context of the input would be propagated from its other - * consumers or fallback to the default device. - * - * Another type of dialect is used fo memory allocation, namely, alloc_storage - * and alloc_tensor. alloc_storage contains a context field to indicate where - * the chunk of memory is allocated. Therefore, we unify the context of - * alloc_storage with the context field. Other inputs, such as size and - * alignment, are left on CPU. - * - * Based on the above rules, we keep unifying the connected expressions and - * propagating their device information. An error will be raised whenever there - * is a unification conflict. All IR nodes that are not propagated with device - * context will fallback to the specified device. 
- */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace relay { - -using PackedAnalysisResultMap = Map>; -using AnalysisResultMap = - std::unordered_map; - -namespace analysis { - -// Cache ops -static const Op& device_copy_op = Op::Get("device_copy"); -static const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); -static const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); -static const Op& shape_of_op = Op::Get("vm.shape_of"); -static const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); -static const Op& shape_func_of = Op::Get("vm.shape_func"); -static const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); - -class DeviceDomain; -using DeviceDomainPtr = std::shared_ptr; - -/* - * \brief A class to represent the device of a domain, i.e. a segment of relay program. - */ -class DeviceDomain { - public: - // Construct an empty domain. - DeviceDomain() { - device_.device_type = static_cast(-1); - device_.device_id = -1; - } - - // Construct a domain based on a given context. - explicit DeviceDomain(const Device& dev) : device_(dev) {} - - // Check if the current domain is empty. - bool IsEmptyDomain() const { - return static_cast(device_.device_type) == -1 && device_.device_id == -1; - } - - // Check if the current domain equals the other one. - bool operator==(const DeviceDomain& other) const { - return device_.device_type == other.device_.device_type && - device_.device_id == other.device_.device_id; - } - - bool operator!=(const DeviceDomain& other) const { return !(*this == other); } - - private: - // Create a hash for a domain. 
- struct Hash { - size_t operator()(const DeviceDomainPtr& domain) const { - if (domain->IsEmptyDomain()) { - return static_cast(reinterpret_cast(domain.get())); - } else { - size_t const h1(std::hash()(static_cast(domain->device_.device_type))); - size_t const h2(std::hash()(domain->device_.device_id)); - return h1 ^ (h2 << 1); - } - } - }; - - // Create an equality for domains. - struct Equal { - public: - bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { - // We compare the pointer for empty domains. - if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) return lhs.get() == rhs.get(); - - // Otherwise device type and id are used to check equality. - return (*lhs.get() == *rhs.get()); - } - }; - - /* \brief The device to be assigned to the current domain. */ - Device device_; - - friend DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); - friend class ContextAnalyzer; -}; - -// Join two domains. -DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { - if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) { - return lhs; - } else if (lhs->IsEmptyDomain()) { - return rhs; - } else if (rhs->IsEmptyDomain()) { - return lhs; - } else { - ICHECK(*lhs.get() == *rhs.get()) << "All expressions must have a singular device to unify"; - return lhs; - } -} - -/* - * \brief Compute on which device each sub-expression will execute. A union find - * algorithm is used to assign and merge the context domains. - */ -class ContextAnalyzer : public MixedModeVisitor { - public: - ContextAnalyzer(const IRModule& mod, const GlobalVar& current_func, - const Device& default_device) - : MixedModeVisitor(9), // the number of repeated visits a node can perform - mod_(mod), - current_func_(current_func), - default_device_(default_device) { - cpu_dev_.device_type = kDLCPU; - cpu_dev_.device_id = 0; - } - - // Create an empty domain. - // This usually happens when we enter a new scope, i.e. Function. 
- DeviceDomainPtr Bottom() { return std::make_shared(DeviceDomain()); } - - // Create a domain with the given device context. - DeviceDomainPtr DeviceType(const Device& dev) { - return std::make_shared(DeviceDomain(dev)); - } - - // Find the root of a device. - DeviceDomainPtr Lookup(DeviceDomainPtr device) { - while (device_uf_.count(device) && device != device_uf_[device]) { - // Path compression - if (device_uf_.count(device_uf_[device])) { - device_uf_[device] = device_uf_[device_uf_[device]]; - } - device = device_uf_[device]; - } - return device; - } - - // Unify two domains. - DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { - lhs = Lookup(lhs); - rhs = Lookup(rhs); - auto unified_device = Join(lhs, rhs); - if (lhs != unified_device) { - device_uf_[lhs] = unified_device; - } - - if (rhs != unified_device) { - device_uf_[rhs] = unified_device; - } - - return unified_device; - } - - // Unify the domain for two IR nodes. - DeviceDomainPtr UnifyExpr(const Expr& lhs, const Expr& rhs) { - auto lhs_dom = DeviceFor(lhs); - auto rhs_dom = DeviceFor(rhs); - return Unify(lhs_dom, rhs_dom); - } - - // Lookup or insert an IR node to device domain map. - DeviceDomainPtr DeviceFor(const Expr& expr) { - auto it = expr_to_device_.find(expr); - if (it == expr_to_device_.end()) { - auto bottom = Bottom(); - expr_to_device_[expr] = bottom; - return bottom; - } else { - return it->second; - } - } - - // Unify the device context for a device copy node. Device copy node is - // the only node that carries bidirectional devices in the input program. The device - // attribute of other nodes can be propagated from it. 
- void UnifyDeviceCopy(const std::vector& inps, const std::vector& outputs, - DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { - Device src_dev; - src_dev.device_type = src_dev_type; - src_dev.device_id = 0; - auto src_domain = DeviceType(src_dev); - for (const auto& it : inps) { - auto lhs = DeviceFor(it); - Unify(lhs, src_domain); - } - - Device dst_dev; - dst_dev.device_type = dst_dev_type; - dst_dev.device_id = 0; - auto dst_domain = DeviceType(dst_dev); - for (const auto& it : outputs) { - auto lhs = DeviceFor(it); - Unify(lhs, dst_domain); - } - } - - // Unify the domain of inputs and outputs of a relay call. - // - // For most call nodes, the op, inputs, and outputs should all be in the - // same domain, i.e. having the same context. However, device_copy call node - // needs to be handled differently as it copies data from one device to - // another. - DeviceDomainPtr UnifyCall(const Expr& call_op, const Array& inps, - const Array& outputs, DeviceDomainPtr device) { - device = Unify(device, DeviceFor(call_op)); - - for (const auto& it : inps) { - device = Unify(device, DeviceFor(it)); - } - - for (const auto& it : outputs) { - device = Unify(device, DeviceFor(it)); - } - - return device; - } - - void VisitExpr_(const CallNode* cn) final { - Call call = GetRef(cn); - - if (IsDeviceCopy(call)) { - UnifyDeviceCopyCall(cn); - } else if (call->op == alloc_storage_op) { - UnifyAllocStorageCall(cn); - } else if (call->op == alloc_tensor_op) { - UnifyAllocTensorCall(cn); - } else if (call->op == shape_func_of) { - UnifyShapeFuncCall(cn); - } else if (call->op == shape_of_op) { - UnifyShapeOfCall(cn); - } else if (call->op == invoke_tvm_op) { - UnifyInvokeTVMOpCall(cn); - } else if (call->op == reshape_tensor_op) { - UnifyReshapeTensorCall(cn); - } else if (call->op.as()) { - UnifyFunctionCall(cn); - } else if (call->op.as()) { - UnifyGlobalVarCall(cn); - } else if (call->op.as()) { - UnifyVarCall(cn); - } else { - UnifyCall(call, cn->args, {call}, Bottom()); 
- MixedModeVisitor::VisitExpr_(cn); - } - } - - void VisitExpr_(const LetNode* ln) final { - Expr expr = GetRef(ln); - // Iteratively visit let nodes to avoid stack overflow. - while (expr->IsInstance()) { - Let let = Downcast(expr); - // Save currying/closures since they will be invoked later - auto ty = let->value->checked_type(); - if (ty->IsInstance()) { - auto gv = ExtractClosure(let); - ICHECK(gv.defined() && gv->IsInstance()); - closures_[let->var] = Downcast(gv); - } - - // Unify let var, value, and body - Unify(DeviceFor(let->var), DeviceFor(let->value)); - UnifyExpr(let, let->body); - MixedModeVisitor::VisitExpr(let->value); - expr = let->body; - } - // Visit the last body - MixedModeVisitor::VisitExpr(expr); - } - - void VisitExpr_(const FunctionNode* fn) final { - auto func = GetRef(fn); - // No need to step into fused primitive functions as they are handled as - // a whole. - if (fn->HasNonzeroAttr(attr::kPrimitive)) { - return; - } - - auto device = Unify(DeviceFor(func), DeviceFor(fn->body)); - for (const auto& it : fn->params) { - DeviceFor(it); - } - MixedModeVisitor::VisitExpr(fn->body); - } - - void VisitExpr_(const TupleNode* tn) final { - // We only support tuple with the same of device. 
- Tuple tup = GetRef(tn); - if (tn->fields.size() > 0) { - auto device = DeviceFor(tup->fields[0]); - for (size_t i = 1; i < tup->fields.size(); i++) { - device = Unify(device, DeviceFor(tup->fields[i])); - } - Unify(device, DeviceFor(tup)); - } - MixedModeVisitor::VisitExpr_(tn); - } - - void VisitExpr_(const TupleGetItemNode* tn) final { - TupleGetItem item = GetRef(tn); - - Unify(DeviceFor(item), DeviceFor(item->tuple)); - - MixedModeVisitor::VisitExpr_(tn); - } - - void VisitExpr_(const MatchNode* mn) final { - // For match node, we unify the value and the rhs of each clause - Match m = GetRef(mn); - auto device = Unify(DeviceFor(m), DeviceFor(m->data)); - for (const auto& c : m->clauses) { - device = Unify(device, DeviceFor(c->rhs)); - } - MixedModeVisitor::VisitLeaf(mn->data); - for (const Clause& c : mn->clauses) { - this->VisitClause(c); - MixedModeVisitor::VisitLeaf(c->rhs); - } - } - - void VisitExpr_(const GlobalVarNode* gvn) final { DeviceFor(GetRef(gvn)); } - - void VisitExpr_(const VarNode* vn) { DeviceFor(GetRef(vn)); } - - void VisitExpr_(const ConstantNode* cn) final { DeviceFor(GetRef(cn)); } - - // Return the analysis results. - AnalysisResultMap Results() { - AnalysisResultMap ret; - for (const auto& it : expr_to_device_) { - auto device = Lookup(it.second); - if (device->IsEmptyDomain()) { - ret[it.first] = default_device_; - } else { - ret[it.first] = device->device_; - } - } - - return ret; - } - - private: - Expr ExtractClosure(Expr expr) const { - while (expr->IsInstance()) { - Let let = Downcast(expr); - expr = let->value; - if (expr->IsInstance()) { - return expr; - } else { - const auto* cn = expr.as(); - if (cn && cn->op->IsInstance()) { - return cn->op; - } - } - } - return Expr(nullptr); - } - - // Check if an expression is a device copy call. 
- bool IsDeviceCopy(const Expr& expr) const { - if (!expr->IsInstance()) return false; - - Call call = Downcast(expr); - if (call->op == device_copy_op) return true; - - // Fused function with device copy op as the body - // device copy op is opaque therefore the fused function only has one node. - if (const FunctionNode* fn = call->op.as()) { - if (const CallNode* cn = fn->body.as()) { - return cn->op == device_copy_op; - } - } - - return false; - } - - // Check if a function is a closure. - bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } - - // Check if a function is a currying function. - bool IsCurrying(const Function& func) { - if (const auto* let = func->body.as()) { - return closures_.find(let->var) != closures_.end(); - } - return false; - } - - // Process device copy call node - void UnifyDeviceCopyCall(const CallNode* call) { - ICHECK_EQ(call->args.size(), 1U); - - std::vector inps{call->args[0]}; - std::vector outs{GetRef(call)}; - DLDeviceType src_dev_type, dst_dev_type; - const DeviceCopyAttrs* attrs = nullptr; - if (const auto* fn = call->op.as()) { - // device_copy is fused, propagate device to the fused function. - inps.push_back(fn->params[0]); - outs.push_back(call->op); - Expr body = fn->body; - ICHECK(body->IsInstance() && IsDeviceCopy(body)); - Call call_body = Downcast(body); - attrs = call_body->attrs.as(); - } else { - attrs = call->attrs.as(); - } - ICHECK(attrs != nullptr); - src_dev_type = static_cast(attrs->src_dev_type); - dst_dev_type = static_cast(attrs->dst_dev_type); - - // Device copy op only has one input which is now annotated with the - // same device to the source device type of the device copy op. - // The call itself has the same device type to the destination. 
- UnifyDeviceCopy(inps, outs, src_dev_type, dst_dev_type); - MixedModeVisitor::VisitExpr_(call); - } - - void UnifyAllocStorageCall(const CallNode* call) { - // [size, alignment] - ICHECK_EQ(call->args.size(), 2U); - - // The arguments of alloc storage should be on CPU. - for (int i = 0; i < 2; i++) { - Unify(DeviceFor(call->args[i]), DeviceType(cpu_dev_)); - MixedModeVisitor::VisitExpr(call->args[i]); - } - Device dev; - const auto* attrs = call->attrs.as(); - dev.device_type = static_cast(attrs->device_type); - dev.device_id = attrs->device_id; - Unify(DeviceFor(GetRef(call)), DeviceType(dev)); - } - - void UnifyAllocTensorCall(const CallNode* call) { - // [storage, offset, shape] - ICHECK_EQ(call->args.size(), 3U); - - Expr storage = call->args[0]; - Expr shape = call->args[1]; - Unify(DeviceFor(storage), DeviceFor(GetRef(call))); - - // The shape for alloc_tensor should be on CPU. - Unify(DeviceFor(shape), DeviceType(cpu_dev_)); - MixedModeVisitor::VisitExpr(shape); - } - - void UnifyShapeFuncCall(const CallNode* call) { - // [func, inputs, outputs] - ICHECK_EQ(call->args.size(), 3U); - auto shape_func_domain = DeviceType(cpu_dev_); - - // No need to unify the op of a shape_func as shape_func doesn't - // invoke the op itself. It should be handled by invoke_tvm_op. - // Therefore, we skip call.args[0] here. 
- Tuple inps = Downcast(call->args[1]); - Tuple outputs = Downcast(call->args[2]); - UnifyCall(GetRef(call), inps->fields, outputs->fields, shape_func_domain); - for (const auto& it : inps->fields) { - MixedModeVisitor::VisitExpr(it); - } - - for (const auto& it : outputs->fields) { - MixedModeVisitor::VisitExpr(it); - } - } - - void UnifyInvokeTVMOpCall(const CallNode* call) { - // [op, inputs, outputs] - ICHECK_EQ(call->args.size(), 3U); - Tuple inps = Downcast(call->args[1]); - Tuple outputs = Downcast(call->args[2]); - UnifyCall(call->args[0], inps->fields, outputs->fields, Bottom()); - MixedModeVisitor::VisitExpr_(call); - } - - void UnifyShapeOfCall(const CallNode* call) { - // vm shape_of is always on the CPU. - ICHECK_EQ(call->args.size(), 1U); - MixedModeVisitor::VisitExpr(call->args[0]); - // Note we don't unify the input of a shape_of with the cpu domain. This is - // because vm.shape_of has a native instruction to compute the shape of - // a tensor regardless its device type. - // Instead, the device type of the input is left for its other consumers to - // unify or it will fallback to the default context. - Unify(DeviceFor(GetRef(call)), DeviceType(cpu_dev_)); - } - - void UnifyReshapeTensorCall(const CallNode* call) { - // [data, shape] - ICHECK_EQ(call->args.size(), 2U); - Expr data = call->args[0]; - Expr shape = call->args[1]; - Unify(DeviceFor(GetRef(call)), DeviceFor(data)); - - // The shape field of reshape_tensor is always on the CPU. - Unify(DeviceFor(shape), DeviceType(cpu_dev_)); - MixedModeVisitor::VisitExpr(data); - MixedModeVisitor::VisitExpr(shape); - } - - void UnifyFunctionCall(const CallNode* call) { - auto device = DeviceFor(GetRef(call)); - // Unify the arguments of the caller. - for (const auto& arg : call->args) { - device = Unify(device, DeviceFor(arg)); - MixedModeVisitor::VisitExpr(arg); - } - - // Unify the parameters of the callee. 
- if (!call->op->IsInstance()) return; - Function func = Downcast(call->op); - for (const auto& param : func->params) { - device = Unify(device, DeviceFor(param)); - MixedModeVisitor::VisitExpr(param); - } - - // Unify the function expression and its body - Unify(device, DeviceFor(call->op)); - Unify(device, DeviceFor(func->body)); - - // Step into the callee. It will be skipped if the callee if a primitive - // function - MixedModeVisitor::VisitExpr(call->op); - } - - // Invoke a global function. - void UnifyGlobalVarCall(const CallNode* call) { - auto device = DeviceFor(GetRef(call)); - ICHECK(mod_.defined()) << "Cannot analyze context on a globalvar without module"; - GlobalVar gv = Downcast(call->op); - auto func = Downcast(mod_->Lookup(gv)); - ICHECK_EQ(call->args.size(), func->params.size()) - << "The number of arguments doesn't match the number of parameters of the function."; - - for (size_t i = 0; i < call->args.size(); i++) { - Expr arg = call->args[i]; - Expr param = func->params[i]; - MixedModeVisitor::VisitExpr(arg); - - // Save the the arg to function mapping for closures as it will - // be invoked/unified later. - ICHECK(arg->checked_type().defined()) - << "Type inference is required to run the context analysis passes."; - if (arg->checked_type()->IsInstance()) { - auto it = closures_.find(arg); - if (it != closures_.end()) { - closures_[param] = it->second; - } else { - ICHECK(arg->IsInstance()); - closures_[param] = Downcast(arg); - } - } - Unify(DeviceFor(arg), DeviceFor(param)); - } - device = Unify(device, DeviceFor(call->op)); - device = Unify(device, DeviceFor(func)); - device = Unify(device, DeviceFor(func->body)); - - // Step into the callee. We need to skip recursive calls, otherwise, it - // would be a infinite loop. - // - // TODO(@zhiics) This may cause problem for mutual recursive calls as well. 
- auto cur_func = current_func_; - current_func_ = gv; - if (cur_func->name_hint != gv->name_hint) { - MixedModeVisitor::VisitExpr(func); - } - // Exit the frame. - current_func_ = cur_func; - } - - void UnifyVarCall(const CallNode* call) { - // It is a closure when we call a var. - // Unify the corresponding arguement and parameter. - auto device = DeviceFor(GetRef(call)); - auto it = closures_.find(call->op); - ICHECK(it != closures_.end()) << "Cannot find var: " << call->op; - auto glb_var = it->second; - ICHECK(mod_.defined()) << "Cannot analyze context on a globalvar without module"; - Function func = Downcast(mod_->Lookup(glb_var)); - // Unify the underlying function for clousre or currying functions. - while (IsClosure(func) || IsCurrying(func)) { - device = Unify(device, DeviceFor(func)); - if (IsClosure(func)) { - func = Downcast(func->body); - } else if (IsCurrying(func)) { - Let let = Downcast(func->body); - func = Downcast(mod_->Lookup(closures_[let->var])); - } else { - LOG(FATAL) << "func is expected to be a closure or a currying function"; - } - } - - ICHECK_EQ(call->args.size(), func->params.size()); - for (size_t i = 0; i < call->args.size(); i++) { - Unify(DeviceFor(call->args[i]), DeviceFor(func->params[i])); - MixedModeVisitor::VisitExpr(call->args[i]); - } - device = Unify(device, DeviceFor(call->op)); - device = Unify(device, DeviceFor(glb_var)); - device = Unify(device, DeviceFor(func)); - - // Step into the global function. - auto cur_func = current_func_; - current_func_ = glb_var; - if (cur_func->name_hint != glb_var->name_hint) { - MixedModeVisitor::VisitExpr(func); - } - current_func_ = cur_func; - } - - private: - /* \brief The cpu context. */ - Device cpu_dev_; - /* \brief The module that helps context analysis. */ - const IRModule& mod_; - /* \brief The current function that is being analyzed. */ - GlobalVar current_func_; - /* \brief The default device that could be attached to an expression. 
*/ - const Device& default_device_; - /* \brief The IR node to device domain mapping. */ - std::unordered_map - expr_to_device_; - /* \brief The domain map for union-find. */ - std::unordered_map - device_uf_; - /* - * \brief The expr to global var map. It saves the closures/currying that - * will be invoked lazily. - */ - std::unordered_map closures_; -}; - -} // namespace analysis - -AnalysisResultMap ContextAnalysis(const IRModule& mod, const Device& default_device) { - // TODO(@zhiics) Apply the pass to all functions/entries - auto entry = mod->GetGlobalVar("main"); - auto ca = analysis::ContextAnalyzer(mod, entry, default_device); - auto expr = mod->Lookup(entry); - ca.VisitExpr(expr); - return ca.Results(); -} - -// Unpack the device type and deivce id fields in Device for PackedFunc calls -// as Device is not in the object system. -PackedAnalysisResultMap ContextAnalysisPacked(const IRModule& mod, const Device& default_device) { - PackedAnalysisResultMap ret; - auto res = ContextAnalysis(mod, default_device); - for (const auto& it : res) { - Integer dev_ty = static_cast(it.second.device_type); - Integer dev_id = it.second.device_id; - ret.Set(it.first, {dev_ty, dev_id}); - } - - return ret; -} - -TVM_REGISTER_GLOBAL("relay.analysis.ContextAnalysis").set_body_typed(ContextAnalysisPacked); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index fc850e37379c..38eb6aa6a07e 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -38,6 +39,8 @@ #include #include +#include "../op/annotation/annotation.h" +#include "../transforms/device_aware_visitors.h" #include "./te_compiler.h" #include "./utils.h" @@ -53,18 +56,12 @@ using StorageMap = * This is an on demand allocator for AOT. 
A new temporary * (storage allocator identifier) is allocated for each operation. */ -class AOTOnDemandAllocator : public MixedModeVisitor { +class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { public: - // run the visitor on a function. - void Run(const Function& func) { - node_device_map_ = CollectDeviceInfo(func); + AOTOnDemandAllocator() : transform::DeviceAwareExprVisitor(Optional()) {} - for (Expr param : func->params) { - CreateStorage(param.operator->()); - } - - GetStorage(func->body); - } + // run the visitor on a global function. + void Run(const Function& func) { VisitExpr(func); } std::vector GetReturnIds() const { return return_ids_; } @@ -75,8 +72,9 @@ class AOTOnDemandAllocator : public MixedModeVisitor { AssignReturnSid(GetRef(op)); } - void VisitExpr_(const CallNode* op) final { + void DeviceAwareVisitExpr_(const CallNode* op) final { // create token for the call node. + VisitExpr(op->op); CreateStorage(op); for (Expr arg : op->args) { GetStorage(arg); @@ -86,8 +84,19 @@ class AOTOnDemandAllocator : public MixedModeVisitor { void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef(op)); } - void VisitExpr_(const FunctionNode* op) final { - // do not recurse into sub function. + void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { + if (function_nesting() > 1) { + // do not recurse into sub functions. + return; + } + if (func_node->HasNonzeroAttr(attr::kPrimitive)) { + // No storage needed for primitive functions. 
+ return; + } + for (const auto& param : func_node->params) { + CreateStorage(param.get()); + } + GetStorage(func_node->body); } void VisitExpr_(const GlobalVarNode* op) final { @@ -127,7 +136,9 @@ class AOTOnDemandAllocator : public MixedModeVisitor { void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } - void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "let is not supported."; } + void PreVisitLetBinding_(const Var& var, const Expr& value) final { + LOG(FATAL) << "let is not supported."; + } private: void AssignReturnSid(Expr e) { @@ -151,9 +162,10 @@ class AOTOnDemandAllocator : public MixedModeVisitor { * \brief Get the memory requirement. * \param prototype The prototype token. * \return The required memory size. + * + * TODO(mbs): Cf CalculateRelayExprSizeBytes in utils.cc */ - size_t GetMemorySizeBytes(const TensorTypeNode* ttype) { - ICHECK(ttype != nullptr); + size_t GetMemorySizeBytes(const TensorType& ttype) { size_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = tir::as_const_int(dim); @@ -170,45 +182,43 @@ class AOTOnDemandAllocator : public MixedModeVisitor { * \return The corresponding token. */ StorageInfo GetStorage(const Expr& expr) { - this->VisitExpr(expr); - auto it = storage_device_map_.find(expr); + auto props = GetOnDeviceProps(expr); + // See through "on_device" calls. + Expr true_expr = props.body.defined() ? props.body : expr; + VisitExpr(true_expr); + auto it = storage_device_map_.find(true_expr); ICHECK(it != storage_device_map_.end()); return it->second; } /*! * \brief Create storage for the expression. - * \param expr The expression. */ void CreateStorage(const ExprNode* op) { + Expr expr = GetRef(op); + return CreateStorage(expr, GetInScopeDeviceType(expr)); + } + + /*! + * \brief Create storage to hold the result of evaluating \p expr on \p device_type. 
+ */ + void CreateStorage(const Expr& expr, DLDeviceType device_type) { + ICHECK(device_type != kInvalidDeviceType) << "invalid device type for expr:" << std::endl + << PrettyPrint(expr); std::vector storage_ids; std::vector device_types; std::vector storage_sizes_in_bytes; - Expr expr = GetRef(op); - int device_type_int = - node_device_map_.count(GetRef(op)) ? node_device_map_[expr]->value : 0; - if (const auto* tuple_type = op->checked_type().as()) { - for (Type t : tuple_type->fields) { - const auto* ttype = t.as(); - ICHECK(ttype); - storage_ids.push_back(next_available_sid_++); - storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); - device_types.push_back(DLDeviceType(device_type_int)); - } - } else { - const auto* ttype = op->checked_type().as(); - ICHECK(ttype); + for (const auto& ttype : FlattenTupleType(expr->checked_type())) { storage_ids.push_back(next_available_sid_++); + device_types.push_back(device_type); storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); - device_types.push_back(DLDeviceType(device_type_int)); } storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes); } - /*! \brief mapping of expression -> storageInfo*/ + + /*! \brief mapping of expression -> storageInfo */ StorageMap storage_device_map_; - /*! \brief mapping of expression -> device type*/ - Map node_device_map_; - /*! \brief current id of the temporary allocated*/ + /*! \brief current id of the temporary allocated */ int next_available_sid_{0}; /*! 
\brief the set of intermediate tensors that are return variables */ std::vector return_ids_; @@ -560,29 +570,14 @@ class AOTExecutorCodegen : public MixedModeVisitor { use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))) {} LoweredOutput Codegen(relay::Function func, String mod_name) { - auto aot_allocator = AOTOnDemandAllocator(); - aot_allocator.Run(func); + AOTOnDemandAllocator initial_aot_allocator; + initial_aot_allocator.Run(func); // Pre-lowering storage map and memory plan - StorageMap initial_storage_map = aot_allocator.GetStorageMap(); + // TODO(mbs): Why plan memory and update workspace sizes before lowering? + StorageMap initial_storage_map = initial_aot_allocator.GetStorageMap(); StaticMemoryPlan memory_plan(initial_storage_map); - // Build a map from each operation to device. - tec::DeviceMap device_context_map; - for (const auto& it : memory_plan->expr_to_storage_info) { - auto expr = it.first; - auto storage_info = it.second; - auto device_types = storage_info->device_types; - // CHECK_EQ(device_types.size(), 1); - tvm::Device dev; - dev.device_id = 0; - dev.device_type = device_types[0]; - device_context_map.insert({expr, dev}); - } - - // This first phase moves from implicit use of compile engine, - // to instead explicitly lowering the incoming IRModule, and then - // performing the preexisting AOT executor code generation phase. IRModule mod = IRModule::FromExpr(func); backend::FunctionInfo func_info; @@ -593,29 +588,28 @@ class AOTExecutorCodegen : public MixedModeVisitor { mod = WithAttr(mod, "main_func_info", func_info); } - IRModule lowered_mod = - tec::LowerTEPass(targets_, device_context_map, mod_name, [this](Function func) { - // We need to maintain the constant map for external - // functions so we pass this processing function which - // allows us to process each function as we lower it. 
- if (func->GetAttr(attr::kCompiler).defined()) { - UpdateConstants(func, ¶ms_); - } - - // TODO(@areusch, @jroesch): We should refactor this to - // execute as a further pass, instead writing data to the - // lowering process directly. - tec::UpdateFunctionMetadata(func, this->function_metadata_); - })(mod); + IRModule lowered_mod = tec::LowerTEPass(targets_, mod_name, [this](Function func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } + + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); // Post-lowering storage map for writing main func - this should be the same map as previously // created, just referencing the new expressions created from lowering - auto new_allocator = AOTOnDemandAllocator(); - new_allocator.Run(lowered_main_func); - storage_device_map_ = new_allocator.GetStorageMap(); + AOTOnDemandAllocator final_aot_allocator; + final_aot_allocator.Run(lowered_main_func); + storage_device_map_ = final_aot_allocator.GetStorageMap(); for (auto input : lowered_main_func->params) { input_vars_.push_back(input); @@ -637,7 +631,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } // Retrieve the return sids - return_sid_ = aot_allocator.GetReturnIds(); + return_sid_ = final_aot_allocator.GetReturnIds(); for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) { main_signature_.push_back(tir::Var("output", DataType::Handle())); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 69dced36295e..ef82ed617508 100644 --- 
a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -280,6 +280,15 @@ class RelayBuildModule : public runtime::ModuleNode { */ void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host, const String executor, const String mod_name) { + for (const auto& pair : targets) { + VLOG(0) << "Build target " << pair.first << " = " << pair.second->str(); + } + if (target_host.defined()) { + VLOG(0) << "Build target_host = " << target_host->str(); + } + VLOG(0) << "Build executor = '" << executor << "'"; + VLOG(0) << "Build mod_name = '" << mod_name << "'"; + // Create protected variable targets_ from ground up targets_ = targets; target_host_ = target_host; @@ -302,6 +311,13 @@ class RelayBuildModule : public runtime::ModuleNode { */ IRModule Optimize(IRModule relay_module, const TargetsMap& targets, const std::unordered_map& params) { + targets_ = targets; + // No target_host setup it seems. + return OptimizeImpl(relay_module, params); + } + + IRModule OptimizeImpl(IRModule relay_module, + const std::unordered_map& params) { ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler."; if (params.size()) { @@ -313,40 +329,66 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module_ptr->Update(main_glb_var, new_main); } - Array pass_seqs = GetPassPrefix(targets, false); + Array pass_seqs = GetPassPrefix(targets_, false); + transform::PassContext pass_ctx = PassContext::Current(); + + // TODO(mbs): Centralize this logic and reconcile with similar in relay/backend/vm/compiler.cc + DLDeviceType default_device_type; + if (targets_.size() == 1) { + // Homogenous execution. + default_device_type = static_cast((*targets_.begin()).first->value); + const auto& target = (*targets_.begin()).second; - if (targets.size() == 1) { - const auto& target = (*targets.begin()).second; + // This pass currently only supports the homogeneous case. 
pass_seqs.push_back( transform::SplitArgs(target->GetAttr("max_function_args", -1).value())); + } else { + // Heterogeneous execution. + Optional opt_fallback_dev = + pass_ctx->GetConfig("relay.fallback_device_type"); + if (opt_fallback_dev) { + default_device_type = static_cast(opt_fallback_dev.value()->value); + Integer integer(static_cast(default_device_type)); + CHECK_GT(default_device_type, 0U) + << "The 'relay.fallback_device_type' is set to an invalid device type."; + if (targets_.count(integer) == 0) { + LOG(WARNING) + << "The 'relay.fallback_device_type' has been set to " << default_device_type + << " however no target has been given for that device type in the targets map. " + "Creating an appropriate default target."; + targets_.Set(integer, CreateDefaultTarget(default_device_type)); + } + } else { + default_device_type = kDLCPU; + Integer integer(static_cast(default_device_type)); + if (targets_.count(integer) == 0) { + LOG(WARNING) << "Using the default device type of kDLCPU, however no target has been " + "given for that device type in the targets map. Creating an appropriate " + "default target."; + targets_.Set(integer, CreateDefaultTarget(default_device_type)); + } + } } + // Always plan devices so the remaining passes don't need to distinguish homogeneous vs + // hetrogenous execution. + pass_seqs.push_back(transform::PlanDevices(default_device_type)); + + // Fuse the operations if it is needed. + pass_seqs.push_back(transform::FuseOps()); + // Create a sequential pass and perform optimizations. transform::Pass seq = transform::Sequential(pass_seqs); - if (targets.size() == 1) { - const auto& it = targets.begin(); - With tctx((*it).second); + if (targets_.size() == 1) { + With tctx((*targets_.begin()).second); relay_module = seq(relay_module); } else { relay_module = seq(relay_module); } - // Handle heterogeneous compilation. 
- transform::PassContext pass_ctx = PassContext::Current(); - if (targets_.size() > 1) { - Optional opt_fallback_dev = - pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); - auto fallback_dev = opt_fallback_dev.value(); - ICHECK_GT(fallback_dev->value, 0U); - relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value); - } - - // Fuse the operations if it is needed. - relay_module = transform::FuseOps()(relay_module); - // Do layout rewrite for auto-scheduler. - if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) { - const auto& target = (*targets.begin()).second; + if (backend::IsAutoSchedulerEnabled() && targets_.size() == 1) { + const auto& target = (*targets_.begin()).second; Pass major_pass = transform::AutoSchedulerLayoutRewrite(); bool enable_layout_rewrite_targets = target->kind->device_type == kDLCPU || target->GetAttr("device", "") == "mali"; @@ -378,83 +420,15 @@ class RelayBuildModule : public runtime::ModuleNode { } /*! - * \brief Create a default type. - * \param device_type The device type index. - * \return the default target for the device. + * \brief Returns a default target to represent \p device_type. */ - Target CreateDefaultTarget(int device_type) { + static Target CreateDefaultTarget(DLDeviceType device_type) { std::string name = runtime::DeviceName(device_type); - if (name == "cpu") return Target("llvm"); - if (name == "cuda") return Target("cuda"); - return Target(name); - } - - /*! - * \brief Update the target and fallback device required for heterogeneous - * compilation. CPU is used as the fallback device if it wasn't provided. - * Meanwhile, a CPU device type and "llvm" pair will be added to the target - * dictionary in this case. - * - * \param fallback_device The fallback device for heterogeneous execution. 
- */ - void UpdateHeterogeneousInputs(int fallback_device) { - std::unordered_map tmp_map; - for (const auto& kv : targets_) { - tmp_map[kv.first->value] = kv.second; - } - if (tmp_map.count(fallback_device) == 0) { - targets_.Set(fallback_device, CreateDefaultTarget(fallback_device)); - } - } - - /*! - * \brief Execute the device annotation passes to update the input program and - * target information. - * - * \param relay_module The input Relay module. - * \param fallback_device The fallback device for heterogeneous execution. - * - * \return updated_module The updated module after device annotation. - */ - IRModule RunDeviceAnnotationPass(const IRModule& relay_module, int fallback_device) { - UpdateHeterogeneousInputs(fallback_device); - auto rewrite = transform::RewriteAnnotatedOps(fallback_device); - auto updated_module = rewrite(relay_module); - ICHECK(updated_module.defined()); - - tvm::Map device_map; - for (const auto& it : updated_module->functions) { - device_map = relay::CollectDeviceInfo(it.second); - if (!device_map.empty()) break; - } - - if (device_map.empty()) { - tvm::Map annotation_map; - for (const auto& it : relay_module->functions) { - annotation_map = relay::CollectDeviceAnnotationOps(it.second); - if (!annotation_map.empty()) break; - } - // None op is annotated but they are fallen back to the default device. - if (annotation_map.empty()) { - targets_.Set(0, CreateDefaultTarget(fallback_device)); - } else { - // All ops are annotated to the same device type. - int64_t dev_type = -1; - for (auto kv : annotation_map) { - dev_type = kv.second->value; - break; - } - for (auto kv : annotation_map) { - ICHECK_EQ(kv.second->value, dev_type) << "Expressions in the function are " - << "annotated with various device types," - << "but not device copy operators " - << "found. 
Please check the " - << "RewriteAnnotation pass."; - } - targets_.Set(0, CreateDefaultTarget(dev_type)); - } + if (name == "cpu") { + return Target("llvm"); + } else { + return Target(name); } - return updated_module; } /*! @@ -476,7 +450,7 @@ class RelayBuildModule : public runtime::ModuleNode { CheckAndUpdateHostConsistency(&targets_, &target_host); // Relay IRModule -> IRModule optimizations. - relay_module = Optimize(relay_module, targets_, params); + relay_module = OptimizeImpl(relay_module, params); // Get the updated function. auto func = Downcast(relay_module->Lookup("main")); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 92e7568d9f38..dbe14b63293f 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -36,13 +36,17 @@ #include #include +#include "../op/annotation/annotation.h" +#include "../transforms/device_aware_visitors.h" #include "./te_compiler.h" #include "./utils.h" namespace tvm { namespace relay { + // TODO(@jroesch, @csullivan): declare directly elsewhere backend::StaticMemoryPlan GraphPlanMemory(const Function& func); + namespace backend { class GraphNode; @@ -196,30 +200,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorGetAttr(attr::kCompiler).defined()) { - UpdateConstants(func, ¶ms_); - } + IRModule lowered_mod = tec::LowerTEPass(targets_, mod_name_, [this](Function func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } - // TODO(@areusch, @jroesch): We should refactor this to - // execute as a further pass, instead writing data to the - // lowering process directly. 
- tec::UpdateFunctionMetadata(func, this->function_metadata_); - })(mod); + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); Optional main_func_info = lowered_mod->GetAttr("main_func_info"); @@ -454,18 +446,21 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator VisitExpr_(const CallNode* call_node) override { relay::Call call = GetRef(call_node); - if (auto global_node = call->op.as()) { - auto prim_fn_name = global_node->name_hint; - - return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs()); - } else { - ICHECK(false) << "Non-primitive-call nodes should have been transformed away.\n" - << "The graph executor code generator expects all calls to have their callee " - "normalized to a GlobalVar but found a " - << call->GetTypeKey() << "." - << "AST: " << PrettyPrint(call) << PrettyPrint(call) << std::endl; - return {}; + auto props = GetOnDeviceProps(call_node); + if (props.body.defined()) { + // See through "on_device" calls. 
+ return VisitExpr(props.body); } + + const auto* global_node = call->op.as(); + ICHECK(global_node) + << "Non-primitive-call nodes should have been transformed away.\n" + << "The graph executor code generator expects all calls to have their callee " + "normalized to a GlobalVar, but found:" + << std::endl + << PrettyPrint(call); + auto prim_fn_name = global_node->name_hint; + return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs()); } std::vector VisitExpr_(const LetNode* op) override { diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 93c823d8a007..7642f3ccf703 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -27,9 +27,13 @@ #include #include #include +#include #include #include "../../support/arena.h" +#include "../op/annotation/annotation.h" +#include "../op/memory/memory.h" +#include "../transforms/device_aware_visitors.h" #include "./utils.h" namespace tvm { @@ -39,44 +43,41 @@ using backend::StaticMemoryPlan; using backend::StorageInfo; using IntegerArray = Array; +/*! A representation of a block of memory required at runtime on some device. */ struct StorageToken { /*! \brief Reference counter */ int ref_counter{0}; /*! \brief number of bytes */ size_t max_bytes{0}; - /*! \brief The corresponding tensor type node. */ - const TensorTypeNode* ttype{nullptr}; - /*! \brief virtual device index that corresponds to the device_type in - * DLDevice. */ - int device_type{0}; + /*! \brief The corresponding tensor type. */ + TensorType ttype{nullptr}; + /*! \brief Device on which memory will reside. */ + Device device{kInvalidDeviceType, -1}; /*! 
\brief The storage id */ int64_t storage_id{-1}; + + bool is_valid() const { return device.device_type != kInvalidDeviceType; } + + bool is_compatible(const StorageToken& that) const { + return device.device_type == that.device.device_type; + } + + std::string ToString() const { + std::ostringstream os; + os << "{id: " << storage_id << ", bytes: " << max_bytes << ", type: " << PrettyPrint(ttype) + << ", device: " << device.device_type << "}"; + return os.str(); + } }; -std::ostream& operator<<(std::ostream& os, StorageToken tok) { - return os << "StorageToken: " << std::endl - << "ref_counter: " << tok.ref_counter << std::endl - << "max_bytes: " << tok.max_bytes << std::endl - << "tttype: " << tok.ttype - << std::endl - // ok idk how to print this properly - << "tttype shape: " << tok.ttype->shape << std::endl - << "device_type: " << tok.device_type << std::endl - << "storage_id: " << tok.storage_id << std::endl; -} - -class StorageAllocaBaseVisitor : public ExprVisitor { +class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { public: - // run the visitor on a function. - void Run(const Function& func) { - for (Var param : func->params) { - CreateToken(param.operator->(), false); - } - // must always keep output alive. - for (StorageToken* tok : GetToken(func->body)) { - tok->ref_counter += 1; - } - } + StorageAllocaBaseVisitor() : transform::DeviceAwareExprVisitor(Optional()) {} + + // run the visitor on a global function. + void Run(const Function& func) { VisitExpr(func); } + + using transform::DeviceAwareExprVisitor::VisitExpr_; void VisitExpr_(const ConstantNode* op) final { this->CreateToken(op, false); } @@ -84,8 +85,22 @@ class StorageAllocaBaseVisitor : public ExprVisitor { // Do nothing. } - void VisitExpr_(const FunctionNode* op) final { - // do not recurse into sub function. + void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { + if (function_nesting() > 1) { + // do not recurse into sub functions. 
+ return; + } + if (func_node->HasNonzeroAttr(attr::kPrimitive)) { + // No storage needed for primitive functions. + return; + } + for (const auto& param : func_node->params) { + CreateToken(param.get(), /*can_realloc=*/false); + } + // Process the function body, and make sure all result tokens are considered 'alive'. + for (StorageToken* tok : GetToken(func_node->body)) { + tok->ref_counter += 1; + } } void VisitExpr_(const GlobalVarNode* op) final { @@ -113,15 +128,17 @@ class StorageAllocaBaseVisitor : public ExprVisitor { void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } - void VisitExpr_(const LetNode* op) final { - auto token = GetToken(op->value); - token_map_[op->var.operator->()] = token; - token_map_[op] = GetToken(op->body); + void PreVisitLetBinding_(const Var& var, const Expr& value) final { + token_map_[var.get()] = GetToken(value); + } + + void PostVisitLet_(const LetNode* let_node) final { + token_map_[let_node] = GetToken(let_node->body); } protected: /*! \brief internal token map */ - std::unordered_map > token_map_; + std::unordered_map> token_map_; /*! * \brief Get the necessary token. @@ -130,27 +147,39 @@ class StorageAllocaBaseVisitor : public ExprVisitor { */ const std::vector& GetToken(const Expr& expr) { this->VisitExpr(expr); - auto it = token_map_.find(expr.operator->()); - ICHECK(it != token_map_.end()) - << "Expression: `" << PrettyPrint(expr) << "` not found in storage map."; + // See through on_device calls. + auto props = GetOnDeviceProps(expr); + Expr real_expr = props.body.defined() ? props.body : expr; + auto it = token_map_.find(real_expr.get()); + ICHECK(it != token_map_.end()) << "Expression not found in storage map:" << std::endl + << PrettyPrint(real_expr); return it->second; } + + /*! + * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding + * the result of evaluating \p op. 
+ */ + void CreateToken(const ExprNode* op, bool can_realloc) { + return CreateTokenOnDevice(op, GetInScopeDeviceType(GetRef(op)), can_realloc); + } + /*! - * \brief Populate the token map to set op's tokens - * \param op The node to be processed. - * \param can_realloc Whether we can re-allocate the memory. + * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding + * the result of evaluating \p op on \p device_type. */ - virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0; + virtual void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, + bool can_realloc) = 0; }; +/*! \brief Associate storage with every expression without any concern for sharing. */ class StorageAllocaInit : protected StorageAllocaBaseVisitor { public: explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {} /*! \return The internal token map */ - std::unordered_map > GetInitTokenMap( + std::unordered_map> GetInitTokenMap( const Function& func) { - node_device_map_ = CollectDeviceInfo(func); this->Run(func); return std::move(token_map_); } @@ -158,32 +187,24 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { protected: using StorageAllocaBaseVisitor::VisitExpr_; - void CreateToken(const ExprNode* op, bool can_realloc) final { + void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, + bool can_realloc) override { ICHECK(!token_map_.count(op)); std::vector tokens; - int device_type = - node_device_map_.count(GetRef(op)) ? 
node_device_map_[GetRef(op)]->value : 0; - if (const auto* tuple_type = op->checked_type().as()) { - for (Type t : tuple_type->fields) { - const auto* ttype = t.as(); - ICHECK(ttype); - StorageToken* token = arena_->make(); - token->ttype = ttype; - token->device_type = device_type; - tokens.push_back(token); - } - } else { - const auto* ttype = op->checked_type().as(); - ICHECK(ttype); + for (const auto& ttype : FlattenTupleType(op->checked_type())) { StorageToken* token = arena_->make(); token->ttype = ttype; - token->device_type = device_type; + // TODO(mbs): Should be TargetDevice. + token->device.device_type = device_type; + token->device.device_id = 0; tokens.push_back(token); } token_map_[op] = tokens; } - void VisitExpr_(const CallNode* op) final { + using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; + + void DeviceAwareVisitExpr_(const CallNode* op) final { // create token for the call node. CreateToken(op, true); @@ -198,13 +219,15 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { private: // allocator support::Arena* arena_; - Map node_device_map_; }; +/*! \brief Associate storage with every expression, reusing storage where possible. */ class StorageAllocator : public StorageAllocaBaseVisitor { public: + StorageAllocator() = default; + /*! - * \return totoal number of bytes allocated + * \return total number of bytes allocated */ size_t TotalAllocBytes() const { size_t total = 0; @@ -216,6 +239,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // Run storage allocation for a function. 
StaticMemoryPlan Plan(const Function& func) { + VLOG_CONTEXT << "StorageAllocator"; + VLOG(1) << "planning:" << std::endl << PrettyPrint(func); prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); this->Run(func); @@ -231,12 +256,13 @@ class StorageAllocator : public StorageAllocaBaseVisitor { std::vector sid_sizes_byte; for (StorageToken* tok : kv.second) { - if (tok->device_type) { + VLOG(1) << "token: " << tok->ToString(); + if (tok->is_valid()) { num_annotated_nodes++; } num_nodes++; storage_ids.push_back(tok->storage_id); - device_types.push_back(static_cast(tok->device_type)); + device_types.push_back(static_cast(tok->device.device_type)); sid_sizes_byte.push_back(GetMemorySize(tok)); } auto storage_info = backend::StorageInfo(storage_ids, device_types, sid_sizes_byte); @@ -253,21 +279,21 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } protected: - using StorageAllocaBaseVisitor::VisitExpr_; // override create token by getting token as prototype requirements. - void CreateToken(const ExprNode* op, bool can_realloc) final { + void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, bool can_realloc) final { ICHECK(!token_map_.count(op)); auto it = prototype_.find(op); ICHECK(it != prototype_.end()); std::vector tokens; for (StorageToken* tok : it->second) { + ICHECK_EQ(tok->device.device_type, device_type); if (can_realloc) { tokens.push_back(Request(tok)); } else { // Allocate a new token, StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok)); - allocated_tok->device_type = tok->device_type; + allocated_tok->device = tok->device; // ensure it never get de-allocated. 
allocated_tok->ref_counter += 1; tokens.push_back(allocated_tok); @@ -275,6 +301,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } token_map_[op] = tokens; } + // Mark op to reuse the input_token // tie the two memories together void ReuseInputToken(const ExprNode* op, StorageToken* input_token) { @@ -291,8 +318,10 @@ class StorageAllocator : public StorageAllocaBaseVisitor { token_map_[op] = {input_token}; } + using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; + // The call map - void VisitExpr_(const CallNode* op) final { + void DeviceAwareVisitExpr_(const CallNode* op) final { std::vector args; // for each input, visit argument token. for (Expr arg : op->args) { @@ -364,8 +393,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { * \return The required memory size. */ size_t GetMemorySize(StorageToken* prototype) { - const TensorTypeNode* ttype = prototype->ttype; - ICHECK(ttype != nullptr); + TensorType ttype = prototype->ttype; + ICHECK(ttype.defined()); size_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = tir::as_const_int(dim); @@ -394,7 +423,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { StorageToken* tok = it->second; - if (tok->device_type != prototype->device_type) continue; + if (!tok->is_compatible(*prototype)) continue; ICHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy tok->max_bytes = std::max(size, tok->max_bytes); @@ -407,7 +436,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { for (auto it = mid; it != begin;) { --it; StorageToken* tok = it->second; - if (tok->device_type != prototype->device_type) continue; + if (!tok->is_compatible(*prototype)) continue; ICHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy tok->max_bytes = std::max(size, tok->max_bytes); @@ -452,7 +481,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // all the storage 
resources available std::vector data_; /*! \brief internal prototype token map */ - std::unordered_map > prototype_; + std::unordered_map> prototype_; }; StaticMemoryPlan GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index d87cf9811bc7..ef89fd9c9c6c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -35,9 +35,9 @@ #include #include +#include "../op/annotation/annotation.h" #include "../transforms/pass_utils.h" -#include "compile_engine.h" -#include "te_compiler.h" +#include "./te_compiler.h" namespace tvm { namespace relay { @@ -439,7 +439,7 @@ class Interpreter : public ExprFunctor, const Array& all_prim_shape_fn_vars, const Array& prim_shape_fn_states, size_t num_shape_inputs, size_t num_shape_outputs, - const std::vector& args) { + Target prim_shape_target, const std::vector& args) { ICHECK(prim_shape_fn_var.defined()); ICHECK(prim_shape_fn_states.defined()); ICHECK(prim_shape_fn_var->checked_type().defined()); @@ -460,11 +460,10 @@ class Interpreter : public ExprFunctor, Device shape_device; shape_device.device_type = kDLCPU; shape_device.device_id = 0; - Target shape_target("llvm"); // 'Compile' the TIR shape function to appropriate callable form. PackedFunc packed_shape_func = - TIRToPackedFunc(prim_shape_fn_var, all_prim_shape_fn_vars, shape_target); + TIRToPackedFunc(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_target); size_t arity = num_shape_inputs + num_shape_outputs; std::vector values(arity); @@ -481,13 +480,13 @@ class Interpreter : public ExprFunctor, // flattened form of this arg. Does that match what lowering actually does? 
int64_t state = prim_shape_fn_states[i]->value; for (const auto& nd_array : FlattenADT(args[i])) { - if (state & kNeedInputData) { + if (state & tec::kNeedInputData) { auto arr = nd_array.CopyTo(shape_device); inputs[arg_counter] = arr; setter(arg_counter, arr); ++arg_counter; } - if (state & kNeedInputShape) { + if (state & tec::kNeedInputShape) { int64_t ndim = nd_array.Shape().size(); NDArray shape_arr; if (ndim == 0) { @@ -553,16 +552,17 @@ class Interpreter : public ExprFunctor, * @return Result of primitive. */ ObjectRef InvokePrimitiveOp(const GlobalVar& prim_fn_var, const Array all_prim_fn_vars, - const GlobalVar& prim_shape_fn_var, + Target prim_target, const GlobalVar& prim_shape_fn_var, const Array& all_prim_shape_fn_vars, const Array& prim_shape_fn_states, size_t num_shape_inputs, - size_t num_shape_outputs, const std::vector& args) { + size_t num_shape_outputs, Target prim_shape_target, + const std::vector& args) { ICHECK(prim_fn_var->checked_type().defined()); const FuncTypeNode* ftn = prim_fn_var->checked_type().as(); ICHECK(ftn); // 'Compile' the TIR primitive to appropriate callable form (on the desired target). - PackedFunc packed_func = TIRToPackedFunc(prim_fn_var, all_prim_fn_vars, target_); + PackedFunc packed_func = TIRToPackedFunc(prim_fn_var, all_prim_fn_vars, prim_target); // Argument tuples are flattened. 
std::vector arg_nd_arrays = FlattenADTs(args); @@ -596,7 +596,7 @@ class Interpreter : public ExprFunctor, ICHECK(prim_shape_fn_states.defined()); runtime_shapes = ComputeDynamicShape(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, - num_shape_inputs, num_shape_outputs, args); + num_shape_inputs, num_shape_outputs, prim_shape_target, args); ICHECK_EQ(runtime_shapes.size(), result_tensor_types.size()); } @@ -676,26 +676,33 @@ class Interpreter : public ExprFunctor, return WithFrame(Frame(locals), [&]() { return Eval(func->body); }); } - ObjectRef VisitExpr_(const CallNode* call) final { + ObjectRef VisitExpr_(const CallNode* call_node) final { std::vector args; - for (auto arg : call->args) { + for (auto arg : call_node->args) { args.push_back(Eval(arg)); } + if (call_node->op == OnDeviceOp()) { + // Special case: The call 'on_device(expr)' denotes that expr should be executed on + // a particular device. We can ignore this during interpretation. + ICHECK_EQ(call_node->args.size(), 1UL); + return args[0]; + } + // We should not find calls to operators after running fusion and lowering. - if (const OpNode* op_node = call->op.as()) { + if (const OpNode* op_node = call_node->op.as()) { LOG(FATAL) << "found " << op_node->name << "; operators should have been removed by previous passes; try " "fusing and lowering"; } - if (const ConstructorNode* con = call->op.as()) { + if (const ConstructorNode* con = call_node->op.as()) { // Special case: ADT constructor return ConstructorValue(con->tag, args, GetRef(con)); } - if (const GlobalVarNode* gvn = call->op.as()) { - if (const TIRCallAttrs* attrs = call->attrs.as()) { + if (const GlobalVarNode* gvn = call_node->op.as()) { + if (const TIRCallAttrs* attrs = call_node->attrs.as()) { // Special case: Call a lowered TIR function. // TODO(mbs): Make calling convention first-class in Relay. 
Array all_prim_fn_vars; @@ -727,15 +734,14 @@ class Interpreter : public ExprFunctor, Downcast(attrs->metadata.at("prim_shape_fn_num_outputs"))->value); } - // Special case: Call TIR primitive. - return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, prim_shape_fn_var, - all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, - num_shape_outputs, args); + return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, target_, + prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, + num_shape_inputs, num_shape_outputs, cpu_target_, args); } } // Now we just evaluate and expect to find a closure. - ObjectRef fn_val = Eval(call->op); + ObjectRef fn_val = Eval(call_node->op); if (const InterpreterClosureObj* closure_node = fn_val.as()) { auto closure = GetRef(closure_node); return Invoke(closure, args); @@ -883,6 +889,8 @@ class Interpreter : public ExprFunctor, Device device_; // Unique target describing how to compile for primitives (but not shape functions). Target target_; + // Default 'CPU' target for shape primitives. + Target cpu_target_{"llvm"}; // Call stack. Stack stack_; // The distinguished 'debug' operator, which is handled specially. @@ -898,21 +906,27 @@ IRModule Prepare(IRModule mod, Device device, Target target) { // Things to initialize to pass into tec::LowerTEPass // We only have one device-specific target. tec::TargetMap targets = {{device.device_type, target}}; - - // All calls to primitives will use the unique target. - tec::DeviceMap device_map; + if (device.device_type != kDLCPU) { + // However some primitives (eg dynamic shape functions) must always execute on the CPU, + // so make sure we have a target for that. + targets.emplace(kDLCPU, Target("llvm")); + } // Run minimal transforms on module to establish invariants needed by interpreter. - transform::Sequential seq({transform::SimplifyInference(), - // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' - // attribute. 
- transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(), - // eta expand to support constructors in argument position - transform::EtaExpand( - /*expand_constructor=*/true, /*expand_global_var=*/false), - transform::InferType(), - tec::LowerTEPass(targets, device_map, /*module_name=*/"intrp", - [](Function func) { /* no-op */ })}); + transform::Sequential seq( + {transform::SimplifyInference(), + // Figure out which devices should be used to execute. + transform::PlanDevices(device.device_type), + // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' + // attribute. + transform::FuseOps(/*fuse_opt_level=*/0), + // Use ANF to reduce number of cases to handle. + transform::ToANormalForm(), + // eta expand to support constructors in argument position. + transform::EtaExpand( + /*expand_constructor=*/true, /*expand_global_var=*/false), + transform::InferType(), + tec::LowerTEPass(targets, /*module_name=*/"intrp", [](Function func) { /* no-op */ })}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); @@ -964,6 +978,9 @@ class NeedsPreparationVisitor : public ExprVisitor { TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, Device device, Target target) { + VLOG_CONTEXT << "EvalFunction"; + VLOG(1) << "evaling module:\n" << PrettyPrint(mod) << "and expression:\n" << PrettyPrint(expr); + // // Step 1: Prepare mod. // @@ -1021,6 +1038,8 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De ICHECK(closure->func.defined()); return TypedPackedFunc)>([intrp, closure](Array args) { + VLOG_CONTEXT << "EvalFunction::Apply"; + VLOG(1) << "evaling closure with " << args.size() << " arguments"; // // Step 3: Apply closure to arguments. 
// diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index d37fbeabc277..445602540dbb 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -42,6 +42,8 @@ #include #include +#include "../op/annotation/annotation.h" +#include "../transforms/device_aware_visitors.h" #include "./te_compiler_cache.h" #include "./utils.h" @@ -367,14 +369,13 @@ std::tuple IsDeviceCopy(const Function& func) { * ... %p(...) ... * \endcode */ -class LowerTensorExprMutator : public ExprMutator { +class LowerTensorExprMutator : public DeviceAwareExprMutator { public: - LowerTensorExprMutator(const IRModule& module, const TargetMap& targets, - const DeviceMap& device_ctx_map, ProcessFn process_fn, + LowerTensorExprMutator(const IRModule& module, const TargetMap& targets, ProcessFn process_fn, const String& module_name, TECompiler compiler) - : module_(module), + : DeviceAwareExprMutator(module), + module_(module), targets_(targets), - device_context_map_(device_ctx_map), process_fn_(process_fn), module_name_(module_name), compiler_(compiler), @@ -471,6 +472,8 @@ class LowerTensorExprMutator : public ExprMutator { auto device_copy = IsDeviceCopy(func); if (std::get<0>(device_copy)) { + // Record that device copy source and destination devices so the device planner can + // still follow along. auto source_device = std::get<1>(device_copy); auto dst_device = std::get<2>(device_copy); tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device)); @@ -487,8 +490,8 @@ class LowerTensorExprMutator : public ExprMutator { // on the host cpu irrespective of where the primitive runs. // TODO(mbs): Cleanup target handling. 
Target shape_target("llvm"); - DLOG(INFO) << "lowering to target '" << shape_target->str() - << "' for dynamic shape function for primitive"; + VLOG(1) << "lowering to target '" << shape_target->str() + << "' for dynamic shape function for primitive"; CCacheKey shape_key(func, shape_target); CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); // Capture the shape function's global var and parameters 'states' in call @@ -514,50 +517,44 @@ class LowerTensorExprMutator : public ExprMutator { return {lowered_func->prim_fn_var, Attrs(tir_call_attrs)}; } - Expr VisitExpr_(const LetNode* let) override { - Var var = Downcast(Mutate(let->var)); - Expr value = Mutate(let->value); - BaseFunc prim_func = ResolveToPrimitive(value); + std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { + Var new_var = Downcast(Mutate(var)); + Expr new_value = Mutate(value); + BaseFunc prim_func = ResolveToPrimitive(new_value); - if (prim_func.defined()) { - // Already lowered by other means, no need to mutate the Let node - if (prim_func->IsInstance()) { - return GetRef(let); - } - - // Remember let var is bound to (possibly indirectly) to a primitive. + if (prim_func.defined() && !prim_func->IsInstance()) { + // Remember let var is bound to (possibly indirectly) to a non-tir primitive. Function func = Downcast(prim_func); - primitive_functions_.emplace(let->var, func); - } - Expr body = Mutate(let->body); - if (prim_func.defined()) { - // Leaving let var scope. 
- primitive_functions_.erase(let->var); - } - if (var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body)) { - return GetRef(let); - } else { - return Let(var, value, body, let->span); + primitive_functions_.emplace(var, func); } + return {new_var, new_value}; } - Expr VisitExpr_(const CallNode* call) override { - Call expr = GetRef(call); + Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) final { + BaseFunc prim_func = ResolveToPrimitive(post_let_node->value); + if (prim_func.defined() && !prim_func->IsInstance()) { + // Leaving let var scope + primitive_functions_.erase(pre_let_node->var); + } + return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); + } + Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { + Call call = GetRef(call_node); // Look for (indirect) calls to primitives. - BaseFunc prim_func = ResolveToPrimitive(call->op); + BaseFunc prim_func = ResolveToPrimitive(call_node->op); if (!prim_func.defined()) { - // Not a call to a primitive function. - if (const FunctionNode* fn = call->op.as()) { + // Not a call_node to a primitive function. + if (const FunctionNode* fn = call_node->op.as()) { this->process_fn_(GetRef(fn)); } - return ExprMutator::VisitExpr_(call); + return ExprMutator::VisitExpr_(call_node); } // Already lowered by other means so we don't need to mutate // the call if (prim_func->IsInstance()) { - return std::move(expr); + return std::move(call); } // Find the desired target device. @@ -565,17 +562,11 @@ class LowerTensorExprMutator : public ExprMutator { if (prim_func->GetAttr(attr::kCompiler).defined()) { // The generic 'external device' target. target = Target("ext_dev"); - } else if (device_context_map_.empty() && targets_.size() == 1) { - // The unique target. - target = GetTargetFromInteger(kDLCPU, targets_); } else { - // The target corresponding to the call expression's annotation. 
- auto itr = device_context_map_.find(expr); - ICHECK(itr != device_context_map_.end()) - << "Could not find an entry in the device context map for " << expr - << "The memory planning was either not performed for this precise node, or there is " - "bug in the memory planner."; - target = GetTargetFromInteger(itr->second.device_type, targets_); + // The target corresponding to the call_node expression's annotation. + DLDeviceType device_type = GetInScopeDeviceType(call); + // TODO(mbs): Replace device_type with target so this lookup is unnecessary. + target = GetTargetFromInteger(device_type, targets_); } // Lower the primitive function for that target. @@ -584,7 +575,7 @@ class LowerTensorExprMutator : public ExprMutator { // Similarly transform arguments. Array args; - for (const auto& arg : call->args) { + for (const auto& arg : call_node->args) { args.push_back(VisitExpr(arg)); } @@ -595,7 +586,6 @@ class LowerTensorExprMutator : public ExprMutator { IRModule module_; TargetMap targets_; - DeviceMap device_context_map_; ProcessFn process_fn_; // Map from in-scope let-bound variables to Relay functions known to be // primitive. 
We'll rewrite these to the fresh global vars bound to the lowered @@ -641,12 +631,11 @@ Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) { } } -Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, const String& module_name, - TECompiler compiler, std::function process_fn) { +Pass LowerTensorExpr(TargetMap targets, const String& module_name, TECompiler compiler, + std::function process_fn) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn, - module_name, compiler); + LowerTensorExprMutator lower_te(module, targets, process_fn, module_name, compiler); return Downcast(lower_te.Mutate(func)); }; return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); @@ -656,6 +645,12 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa Map storage_info_map) { Function func = Downcast(mod->Lookup("main")); + VLOG_CONTEXT << "UpdateMainWorkspaceSize"; + VLOG(1) << "calculating FunctionInfo for main:" << std::endl << PrettyPrint(func); + for (const auto& pair : targets) { + VLOG(1) << " target " << pair.first << " = " << pair.second->str(); + } + // This is a Map> std::unordered_map, backend::EnumClassHash> sid_workspace; @@ -695,34 +690,39 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa // In this final case there is only one allocation for all tensors which share storage // which will be the maximal size of all tensors which were assigned to it. 
for (const auto& kv : storage_info_map) { - Expr expr = kv.first; + const Expr& expr = kv.first; + const backend::StorageInfo& storage_info = kv.second; int64_t size_bytes = backend::CalculateRelayExprSizeBytes(expr->checked_type()); - backend::StorageInfo storage_info = kv.second; + VLOG(1) << "expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "of type:" << std::endl + << PrettyPrint(expr->checked_type()) << std::endl + << "has size " << size_bytes << " and storage info:" << std::endl + << storage_info; std::vector storage_ids = storage_info->storage_ids; std::vector devices = storage_info->device_types; if (expr->IsInstance()) { for (const auto& dev : devices) { + ICHECK_EQ(device_consts.count(dev), 1); device_consts[dev] += size_bytes; } - continue; } else if (expr->IsInstance() || expr.same_as(func->body)) { CHECK_GE(devices.size(), 1) << "must be at least one device"; for (const auto& dev : devices) { device_io[dev] += size_bytes; } - continue; - } - - // TODO(@electriclilies): This code is never being called which means sid_workspace is not - // updated.. This means that storage info is probably not being created correctly. Or is not - // equivalent to what was here previously - for (uint32_t i = 0; i < storage_ids.size(); i++) { - // Here we record the largest size of the tensor - // that share the same storage id, because storage_id will - // be shared between multiple tensors that are not live simultaneously. - if (size_bytes > sid_workspace[devices[i]][storage_ids[i]]) { - sid_workspace[devices[i]][storage_ids[i]] = size_bytes; + } else { + // TODO(@electriclilies): This code is never being called which means sid_workspace is not + // updated.. This means that storage info is probably not being created correctly. 
Or is not + // equivalent to what was here previously + for (uint32_t i = 0; i < storage_ids.size(); i++) { + // Here we record the largest size of the tensor + // that share the same storage id, because storage_id will + // be shared between multiple tensors that are not live simultaneously. + if (size_bytes > sid_workspace[devices[i]][storage_ids[i]]) { + sid_workspace[devices[i]][storage_ids[i]] = size_bytes; + } } } } @@ -762,11 +762,14 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa for (const auto& dev_and_size : device_consts) { auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets); + ICHECK_EQ(constant_sizes.count(tgt), 0); constant_sizes.Set(tgt, dev_and_size.second); } - return backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, tir_primfuncs, - relay_primfuncs); + backend::FunctionInfo func_info(workspace_sizes, io_sizes, constant_sizes, tir_primfuncs, + relay_primfuncs); + VLOG(1) << "func_info: " << func_info; + return std::move(func_info); } /*! @@ -777,13 +780,15 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa */ void UpdateFunctionMetadata(Function relay_func, Map& function_metadata) { // NOLINT(*) + VLOG_CONTEXT << "UpdateFunctionMetadata"; + VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(relay_func); // Originally UpdateFunctionMetadata took in CCachedFunc and looped through all the funcs stored // there Now the goal is to take only one func because process_fn should be controlling the - // iteration However, to do the workspace calculations we need the primfuncs. 
So process_fn needs - // to either access the cached funcs or be directly passed primfuncs This is bad and ideally we - // don't want process_fn to look at primfuncs There's also the question now of what the function - // metadatas are and how they are used if we can do something else to replicate the behavior of - // the function metadatas that might be good (ie annotating functions or something). + // iteration However, to do the workspace calculations we need the primfuncs. So process_fn + // needs to either access the cached funcs or be directly passed primfuncs This is bad and + // ideally we don't want process_fn to look at primfuncs There's also the question now of what + // the function metadatas are and how they are used if we can do something else to replicate the + // behavior of the function metadatas that might be good (ie annotating functions or something). Map workspace_sizes; Map io_sizes; Map constant_sizes; @@ -843,19 +848,20 @@ void UpdateFunctionMetadata(Function relay_func, backend::FunctionInfo fi = backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, tir_primfuncs, relay_primfuncs); + VLOG(1) << "FunctionInfo: " << prim_fn_var.value()->name_hint << " = " << PrettyPrint(fi); + // The primitive function name here corresponds to the string we will use to generate // this Relay function at the low level. 
function_metadata.Set(prim_fn_var.value()->name_hint, fi); } -IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, - const String& module_name, std::function process_fn) { +IRModule LowerTE(const IRModule& module, TargetMap targets, const String& module_name, + std::function process_fn) { DLOG(INFO) << "lowering module:\n" << PrettyPrint(module); TECompiler compiler; - auto updated_module = - LowerTensorExpr(targets, device_context_map, module_name, compiler, process_fn)(module); + auto updated_module = LowerTensorExpr(targets, module_name, compiler, process_fn)(module); backend::UpdateAutoSchedulerOpWeights(compiler); @@ -900,11 +906,11 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, const String& module_name, +Pass LowerTEPass(TargetMap targets, const String& module_name, std::function process_fn) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { - return LowerTE(module, targets, device_context_map, module_name, process_fn); + return LowerTE(module, targets, module_name, process_fn); }; return tvm::transform::Sequential({tvm::relay::transform::RelayToTIRTargetHook(), diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index d5135e6301c4..248fd40f98eb 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -173,7 +173,6 @@ Map GetPerTargetModules(IRModule mod); * * \param module The IRModule. * \param targets The mapping for devices to targets. - * \param device_map An analysis result mapping each sub-expression to a device. * \param memory_plan The memory plan used during lowering * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process @@ -181,9 +180,8 @@ Map GetPerTargetModules(IRModule mod); * \return The lowered module, see above. 
*/ IRModule LowerTE( - const IRModule& module, TargetMap targets, DeviceMap device_map, - backend::StaticMemoryPlan memory_plan, const String& module_name, - ProcessFn process_fn = [](Function f) {}); + const IRModule& module, TargetMap targets, backend::StaticMemoryPlan memory_plan, + const String& module_name, ProcessFn process_fn = [](Function f) {}); /*! \brief Pass to lower an IRModule's primitive functions to TIR. * @@ -192,14 +190,13 @@ IRModule LowerTE( * with their target. * * \param targets The mapping for devices to targets. - * \param device_context_map An analysis result mapping each sub-expression to a device. * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower * \returns The pass which lowers primative functions to TIR */ -transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, - const String& module_name, std::function process_fn); +transform::Pass LowerTEPass(TargetMap targets, const String& module_name, + std::function process_fn); } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 67c7558889fb..02caf56c66e6 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -43,6 +43,25 @@ StorageInfo::StorageInfo(std::vector storage_ids, std::vector([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = ref.as(); + p->stream << "StorageInfoNode(\n" + << " storage_ids=["; + for (auto id : node->storage_ids) { + p->stream << id << ", "; + } + p->stream << "],\n device_types=["; + for (auto device_type : node->device_types) { + p->stream << device_type << ", "; + } + p->stream << "],\n storage_size_in_bytes=["; + for (auto bytes : node->storage_sizes_in_bytes) { + p->stream << bytes << ", "; + } + p->stream << "])"; + }); + TVM_REGISTER_GLOBAL("relay.ir.StorageInfo") .set_body_typed([](const Array& sids, const Array& dev_types, 
const Array& sizes_in_bytes) { @@ -97,6 +116,7 @@ TVM_REGISTER_GLOBAL("relay.ir.StaticMemoryPlan") return StaticMemoryPlan(expr_to_storage_info); }); +// TODO(mbs): Cf GetMemorySizeBytes in aot_executor_codegen.cc int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { if (expr_type->IsInstance()) { auto tuple_type = Downcast(expr_type); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index f8ff20ece561..6d59b858927c 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -189,6 +189,8 @@ struct ConstantUpdater : public ExprVisitor { */ inline void UpdateConstants(Function func, std::unordered_map* params) { + VLOG_CONTEXT << "UpdateConstants"; + VLOG(1) << "updating constants for:" << std::endl << PrettyPrint(func); auto codegen = func->GetAttr(attr::kCompiler); ICHECK(codegen.defined()) << "No external codegen is set"; std::string codegen_name = codegen.value(); @@ -211,6 +213,9 @@ inline void UpdateConstants(Function func, (*params)[const_name] = it.second; } } + for (const auto& pair : *params) { + VLOG(1) << "Constants: " << pair.first << " = " << PrettyPrint(pair.second); + } } /*! 
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 723a0ea6ee7e..36cd0c7f406d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -45,10 +45,12 @@ #include #include "../../../target/source/codegen_source_base.h" +#include "../../op/annotation/annotation.h" #include "../../op/op_common.h" +#include "../../transforms/device_aware_visitors.h" #include "../../transforms/pass_utils.h" #include "../utils.h" -#include "compiler.h" +#include "./compiler.h" namespace tvm { namespace relay { @@ -247,15 +249,14 @@ int GetFallbackDevice() { return fallback_dev->value; } -class VMFunctionCompiler : ExprFunctor { +class VMFunctionCompiler : DeviceAwareExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host, - ExprDeviceMap expr_device_map) - : last_register_(0), + VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) + : DeviceAwareExprFunctor(context->module), + last_register_(0), registers_num_(0), context_(context), - target_host_(target_host), - expr_device_map_(std::move(expr_device_map)) { + target_host_(target_host) { CheckAndUpdateHostConsistency(&targets, &target_host); for (const auto& it : targets) { targets_[it.first->value] = it.second; @@ -264,44 +265,48 @@ class VMFunctionCompiler : ExprFunctor { } VMFunction Compile(const GlobalVar& var, const Function& func) { - size_t i = 0; - // We then assign register num to the free variables - for (auto param : func->params) { - auto arg_register = NewRegister(); - ICHECK_EQ(i, arg_register); - var_register_map_.insert({param, arg_register}); - params_.push_back(param->name_hint()); - ++i; - } - + std::vector params_device_type; if (IsClosure(func)) { + // After lifting we'll have functions of the form: + // fn(closure args) { fn(lifted function args) { body } } + // But we want the closure's function to be: + // fn(closure args, lifter function args) { body } + // Do 
that flattening on-the-fly here. Function inner_func = Downcast(func->body); - for (auto param : inner_func->params) { - auto arg_register = NewRegister(); - ICHECK_EQ(i, arg_register); - var_register_map_.insert({param, arg_register}); - params_.push_back(param->name_hint()); - ++i; + std::vector params; + std::vector param_device_types; + params.reserve(func->params.size() + inner_func->params.size()); + param_device_types.reserve(func->params.size() + inner_func->params.size()); + for (size_t i = 0; i < func->params.size(); ++i) { + params.emplace_back(func->params[i]); + params_device_type.push_back(GetFunctionParamDeviceType(func.get(), i)); + } + for (size_t i = 0; i < inner_func->params.size(); ++i) { + params.emplace_back(inner_func->params[i]); + params_device_type.push_back(GetFunctionParamDeviceType(inner_func.get(), i)); + } + std::vector type_params; + type_params.reserve(func->type_params.size() + inner_func->type_params.size()); + for (const auto& tyvar : func->type_params) { + type_params.push_back(tyvar); + } + for (const auto& tyvar : inner_func->type_params) { + type_params.push_back(tyvar); } - this->VisitExpr(inner_func->body); + Function flattened_func = Function(params, inner_func->body, inner_func->ret_type, + type_params, func->attrs, func->span); + VisitExpr(MaybeFunctionOnDevice(flattened_func, params_device_type, + GetFunctionResultDeviceType(inner_func.get()))); } else { - this->VisitExpr(func->body); - } - instructions_.push_back(Instruction::Ret(last_register_)); - - std::vector params_device_type; - for (const auto& it : func->params) { - if (!expr_device_map_.empty()) { - ICHECK_GT(expr_device_map_.count(it), 0U); - params_device_type.push_back(expr_device_map_[it].device_type); - } else { - ICHECK_EQ(targets_.size(), 1U); - params_device_type.push_back((targets_.begin())->first); + params_device_type.reserve(func->params.size()); + for (size_t i = 0; i < func->params.size(); ++i) { + 
params_device_type.push_back(GetFunctionParamDeviceType(func.get(), i)); } + VisitExpr(func); } - return VMFunction(var->name_hint, params_, instructions_, registers_num_, params_device_type); } + /*! \brief Attrs objects for each op. */ std::map> op_attrs; @@ -312,7 +317,7 @@ class VMFunctionCompiler : ExprFunctor { size_t NewRegister() { return registers_num_++; } inline void Emit(const Instruction& instr) { - DLOG(INFO) << "VMCompiler::Emit: instr=" << instr; + VLOG(1) << "VMCompiler::Emit: instr=" << instr; ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op; switch (instr.op) { case Opcode::AllocADT: @@ -342,29 +347,26 @@ class VMFunctionCompiler : ExprFunctor { instructions_.push_back(instr); } - void VisitExpr_(const ConstantNode* const_node) { + using DeviceAwareExprFunctor::VisitExpr_; + + void VisitExpr_(const ConstantNode* const_node) final { // Check the shape is valid NDArray data = const_node->data; size_t konst_idx = context_->constants.size(); - if (expr_device_map_.empty()) { - context_->const_device_type.push_back(targets_.begin()->first); - } else { - auto con = GetRef(const_node); - ICHECK_GT(expr_device_map_.count(con), 0U); - context_->const_device_type.push_back(expr_device_map_[con].device_type); - } + auto con = GetRef(const_node); + context_->const_device_type.push_back(GetInScopeDeviceType(con)); context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); } - void VisitExpr_(const VarNode* var_node) { + void VisitExpr_(const VarNode* var_node) final { auto var = GetRef(var_node); auto reg_it = this->var_register_map_.find(var); ICHECK(reg_it != this->var_register_map_.end()); last_register_ = reg_it->second; } - void VisitExpr_(const TupleNode* tuple_node) { + void VisitExpr_(const TupleNode* tuple_node) final { auto tuple = GetRef(tuple_node); std::vector fields_registers; @@ -377,35 +379,28 @@ class VMFunctionCompiler : ExprFunctor { Emit(Instruction::AllocADT(0, 
tuple->fields.size(), fields_registers, NewRegister())); } - void VisitExpr_(const MatchNode* match_node) { + void VisitExpr_(const MatchNode* match_node) final { auto match = GetRef(match_node); this->VisitExpr(match->data); CompileMatch(match); } - void VisitExpr_(const LetNode* l) final { - Expr let_binding = GetRef(l); - const LetNode* let; - while ((let = let_binding.as())) { - ICHECK(!let->value.as()) - << "invariant violated, inner functions should not exist (did you set opt_level = 2?)"; - VisitExpr(let->value); - var_register_map_.insert({let->var, this->last_register_}); - let_binding = let->body; - } - - VisitExpr(let_binding); + void PreVisitLetBinding_(const Var& var, const Expr& value) final { + ICHECK(!value.as()) + << "invariant violated, inner functions should not exist (did you set opt_level = 2?)"; + VisitExpr(value); + var_register_map_.emplace(var, this->last_register_); } - void VisitExpr_(const TupleGetItemNode* get_node) { + void VisitExpr_(const TupleGetItemNode* get_node) final { auto get = GetRef(get_node); this->VisitExpr(get->tuple); auto tuple_register = last_register_; Emit(Instruction::GetField(tuple_register, get->index, NewRegister())); } - void VisitExpr_(const GlobalVarNode* gvar) { + void VisitExpr_(const GlobalVarNode* gvar) final { auto var = GetRef(gvar); auto func = context_->module->Lookup(var); auto it = context_->global_map.find(var); @@ -414,7 +409,7 @@ class VMFunctionCompiler : ExprFunctor { Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister())); } - void VisitExpr_(const IfNode* if_node) { + void VisitExpr_(const IfNode* if_node) final { this->VisitExpr(if_node->cond); size_t test_register = last_register_; @@ -501,8 +496,9 @@ class VMFunctionCompiler : ExprFunctor { void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) { std::vector argument_registers; - ICHECK(func->GetAttr(attr::kPrimitive, 0) != 0) - << "internal error: invoke_tvm_op requires the first argument to be a 
relay::Function"; + ICHECK(func->HasNonzeroAttr(attr::kPrimitive)) + << "internal error: invoke_tvm_op requires the first argument to be a primitive " + "relay::Function"; auto input_tuple = inputs.as(); ICHECK(input_tuple) << "internal error: invoke_tvm_op inputs must be a tuple," @@ -526,31 +522,21 @@ class VMFunctionCompiler : ExprFunctor { Target target; + // Which target should execute the function? if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); } else { - // Next generate the invoke instruction. - if (expr_device_map_.empty()) { - // homogeneous execution. - ICHECK_EQ(targets_.size(), 1U); - const auto& it = targets_.begin(); - target = (*it).second; + int dev_type = GetInScopeDeviceType(func); + if (targets_.count(dev_type) == 0) { + target = CreateDefaultTarget(dev_type); } else { - ICHECK_GT(expr_device_map_.count(func), 0U) - << "Found not annotated expression, please make sure " - "context analysis has been executed"; - int dev_type = expr_device_map_[func].device_type; - if (targets_.count(dev_type) == 0) { - target = CreateDefaultTarget(dev_type); - } else { - target = targets_[expr_device_map_[func].device_type]; - } + target = targets_[dev_type]; } } CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; - auto cfunc = context_->compiler->Lower(key, mangle_fn); + auto cfunc = context_->compiler->Lower(key, mangle_fn); // <<<< one-func-at-a-time lowering auto op_index = -1; if (func->GetAttr(attr::kCompiler).defined()) { @@ -576,7 +562,7 @@ class VMFunctionCompiler : ExprFunctor { argument_registers)); } - void VisitExpr_(const CallNode* call_node) { + void DeviceAwareVisitExpr_(const CallNode* call_node) final { Expr op = call_node->op; // First we handle the case in which we are using an opaque @@ -646,20 +632,8 @@ class VMFunctionCompiler : ExprFunctor { ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs"; auto dtype = alloc_attrs->dtype; - Index device_type; - if 
(expr_device_map_.empty()) { - // TODO(zhiics) There is bug if all expressions are annotated with the device - // that is different the first one in the target list. - auto& kv = *(targets_.begin()); - device_type = kv.first; - } else { - ICHECK_GT(expr_device_map_.count(GetRef(call_node)), 0U) - << " The alloc_storage node is not annotated"; - device_type = expr_device_map_[GetRef(call_node)].device_type; - } - - Emit(Instruction::AllocStorage(size_register, alignment, dtype, device_type, - NewRegister())); + Emit(Instruction::AllocStorage(size_register, alignment, dtype, + alloc_attrs->device_type, NewRegister())); }) .Match("vm.shape_func", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -728,8 +702,8 @@ class VMFunctionCompiler : ExprFunctor { auto global = GetRef(global_node); auto it = context_->global_map.find(global); ICHECK(it != context_->global_map.end()); - DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint - << " with func_index=" << it->second; + VLOG(1) << "VisitExpr_: generating invoke for " << global->name_hint + << " with func_index=" << it->second; // TODO(tvm-team): // Think about mixed call into global that is not a relay::Function @@ -764,12 +738,32 @@ class VMFunctionCompiler : ExprFunctor { } } - void VisitExpr_(const FunctionNode* func_node) { - if (!func_node->HasNonzeroAttr(attr::kPrimitive)) { - LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl - << "Program: " << AsText(GetRef(func_node), false) << std::endl - << "AST: " << GetRef(func_node); + void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { + if (function_nesting() > 1) { + ICHECK(func_node->HasNonzeroAttr(attr::kPrimitive)) + << "local functions should have been removed by lambda lifting:" << std::endl + << "Program: " << AsText(GetRef(func_node), false) << std::endl + << "AST: " << GetRef(func_node); + return; + } + + // We're processing a top-level function which has possibly been 
rejigged to capture + // both closure and function arguments. Those functions retain their 'Closure' attribute, + // but we can just process them like any other function here. + + // Assign a register num to each parameter. + size_t i = 0; + for (auto param : func_node->params) { + auto arg_register = NewRegister(); + ICHECK_EQ(i, arg_register); + var_register_map_.insert({param, arg_register}); + params_.push_back(param->name_hint()); + ++i; } + + VisitExpr(func_node->body); + + instructions_.push_back(Instruction::Ret(last_register_)); } /*! @@ -862,8 +856,6 @@ class VMFunctionCompiler : ExprFunctor { std::unordered_map targets_; /*! \brief Host target. */ Target target_host_; - /*! \brief Map from Relay expr to device type. */ - ExprDeviceMap expr_device_map_; }; PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { @@ -930,15 +922,11 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe // the global state. exec_->functions.resize(context_.module->functions.size()); - // Collect the annotated device information. - // This indicates which device each Relay expr should be executed on. 
- ExprDeviceMap expr_device_map = AnalyzeContext(); - for (auto named_func : context_.module->functions) { auto gvar = named_func.first; if (auto* n = named_func.second.as()) { auto func = GetRef(n); - VMFunctionCompiler func_compiler(&context_, targets_, target_host_, expr_device_map); + VMFunctionCompiler func_compiler(&context_, targets_, target_host_); auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); @@ -954,7 +942,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe #if USE_RELAY_DEBUG for (auto vm_func : exec_->functions) { - DLOG(INFO) << vm_func << "-------------"; + VLOG(1) << vm_func << "-------------"; } #endif // USE_RELAY_DEBUG @@ -1043,13 +1031,18 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, Array pass_seqs = relay::backend::GetPassPrefix(targets, true); - if (targets_.size() > 1) { - // Handle heterogeneous compilation. - int fallback_dev = GetFallbackDevice(); - pass_seqs.push_back(transform::RewriteAnnotatedOps(fallback_dev)); + // TODO(mbs): Reconcile with relay/backend/build_module.cc + DLDeviceType default_device_type; + if (targets_arg.size() == 1) { + default_device_type = + static_cast(static_cast((*targets_arg.begin()).first->value)); + } else { + default_device_type = static_cast(GetFallbackDevice()); } + pass_seqs.push_back(PlanDevices(default_device_type)); pass_seqs.push_back(transform::FuseOps()); + // Do layout rewrite for auto-scheduler. 
transform::PassContext pass_ctx = PassContext::Current(); if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) { @@ -1147,25 +1140,6 @@ void VMCompiler::Codegen() { exec_->SetLib(lib); } -ExprDeviceMap VMCompiler::AnalyzeContext() const { - Device default_device; - ExprDeviceMap expr_device_map; - if (targets_.size() > 1) { - int fallback_dev = GetFallbackDevice(); - default_device.device_type = static_cast(fallback_dev); - default_device.device_id = 0; - expr_device_map = ContextAnalysis(context_.module, default_device); - } else { - const auto& tgt = targets_.begin(); - default_device.device_type = static_cast((*tgt).first->value); - if (default_device.device_type != kDLCPU) { - default_device.device_id = 0; - expr_device_map = ContextAnalysis(context_.module, default_device); - } - } - return expr_device_map; -} - runtime::Module CreateVMCompiler() { auto exec = make_object(); return runtime::Module(exec); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index a05c52ced07f..af3c5bccbeff 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -63,7 +63,6 @@ using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; using TargetsMap = Map; -using ExprDeviceMap = std::unordered_map; struct VMCompilerContext { // The module context for the compilation @@ -108,8 +107,8 @@ class VMCompiler : public runtime::ModuleNode { * \brief Lower the functions in a Module * * \param mod Relay Module - * \param targets For heterogeneous compilation, it is a dictionary indicating context - * to target mapping. For homogeneous compilation, it is a build target. + * \param targets For heterogeneous compilation, it is a dictionary indicating device type + * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target, if target is device. 
*/ void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); @@ -122,8 +121,8 @@ class VMCompiler : public runtime::ModuleNode { * \brief Perform a series of optimizations on the input IR module. * * \param mod The input IRModule. - * \param targets For heterogeneous compilation, it is a dictionary indicating context - * to target mapping. For homogeneous compilation, it is a build target. + * \param targets For heterogeneous compilation, it is a dictionary indicating device type + * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target. * * \return The optimized IRModule. @@ -136,9 +135,6 @@ class VMCompiler : public runtime::ModuleNode { */ void PopulateGlobalMap(); - /*! \brief Analyze the device context of each expression. */ - ExprDeviceMap AnalyzeContext() const; - protected: /*! \brief Target devices. */ TargetsMap targets_; diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index c768a2c300ec..d9a2b8b91fa3 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -26,13 +26,15 @@ #include #include #include -#include #include #include #include #include +#include "../../op/annotation/annotation.h" +#include "../../transforms/device_aware_visitors.h" + using namespace tvm::runtime; namespace tvm { @@ -44,7 +46,7 @@ inline std::string GenerateName(const Function& func) { return std::string("lifted_name") + std::to_string(hash); } -bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } +bool IsClosure(const Function& func) { return func->HasNonzeroAttr(attr::kClosure); } Function MarkClosure(Function func) { return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1)); @@ -56,39 +58,29 @@ Function MarkClosure(Function func) { * We will lift a function out into a global which takes the set of the free * vars and then return the new created function. 
*/ -class LambdaLifter : public ExprMutator { +class LambdaLifter : public transform::DeviceAwareExprMutator { public: - explicit LambdaLifter(const IRModule& module) : module_(module) {} - - Expr VisitExpr_(const LetNode* let_node) final { - auto pre_visit = [this](const LetNode* op) { - bool is_lambda = false; - if (auto func = op->value.as()) { - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - is_lambda = true; - this->letrec_.push_back(op->var); - } - } - Expr value = this->VisitExpr(op->value); + explicit LambdaLifter(const IRModule& module) + : transform::DeviceAwareExprMutator(module), module_(module) {} - if (is_lambda) { - this->letrec_.pop_back(); + std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { + bool is_lambda = false; + if (const auto* func_node = value.as()) { + if (!func_node->HasNonzeroAttr(attr::kPrimitive)) { + is_lambda = true; + this->letrec_.push_back(var); } - }; - auto post_visit = [this](const LetNode* op) { - // Rely on the Memoizer to cache pre-visit values - Expr value = this->VisitExpr(op->value); - // Visit body and cache the op - Expr body = this->VisitExpr(op->body); - auto expr = GetRef(op); - this->memo_[expr] = Let(op->var, value, body); - }; - ExpandANormalForm(let_node, pre_visit, post_visit); - return memo_[GetRef(let_node)]; + } + Expr new_value = this->VisitExpr(value); + + if (is_lambda) { + this->letrec_.pop_back(); + } + return {var, new_value}; } - Expr VisitExpr_(const CallNode* call_node) final { - auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + Expr DeviceAwareVisitExpr_(const CallNode* call_node) final { + auto call = Downcast(DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node)); if (auto var_node = call_node->op.as()) { auto var = GetRef(var_node); if (!letrec_.empty() && var == letrec_.back()) { @@ -100,20 +92,27 @@ class LambdaLifter : public ExprMutator { return std::move(call); } - Expr VisitExpr_(const FunctionNode* func_node) final { + Expr DeviceAwareVisitExpr_(const 
FunctionNode* func_node) final { auto func = GetRef(func_node); - // We should not transform primitive functions. if (func->HasNonzeroAttr(attr::kPrimitive)) { + // We should not transform primitive functions. return std::move(func); } + if (function_nesting() == 1) { + // We don't need to lift global functions. + return Function(func_node->params, VisitExpr(func_node->body), func_node->ret_type, + func_node->type_params, func_node->attrs, func_node->span); + } + auto name = GenerateName(func); auto global = GlobalVar(name); auto free_vars = FreeVars(func); auto free_type_vars = FreeTypeVars(func, module_); Array captured_vars; + std::vector captured_var_device_types; bool recursive = false; for (const auto& var : free_vars) { if (!letrec_.empty() && var == letrec_.back()) { @@ -121,8 +120,10 @@ class LambdaLifter : public ExprMutator { continue; } captured_vars.push_back(var); + captured_var_device_types.push_back(GetInScopeDeviceType(var)); } + // Freshen all the captured vars. Array typed_captured_vars; Map rebinding_map; for (auto free_var : captured_vars) { @@ -131,6 +132,8 @@ class LambdaLifter : public ExprMutator { rebinding_map.Set(free_var, var); } + DLDeviceType result_device_type = GetInScopeDeviceType(func_node->body); + if (recursive) { if (!captured_vars.empty()) { Array fvs; @@ -143,7 +146,7 @@ class LambdaLifter : public ExprMutator { } } - auto body = Downcast(ExprMutator::VisitExpr_(func_node)); + auto body = Downcast(DeviceAwareExprMutator::DeviceAwareVisitExpr_(func_node)); // When performing this optimization there are two cases. // @@ -168,8 +171,9 @@ class LambdaLifter : public ExprMutator { // The "inner" function should be used to generate the // code for the closure. 
Function lifted_func; - if (captured_vars.size() == 0 && free_type_vars.size() == 0) { - lifted_func = Function(body->params, body->body, body->ret_type, body->type_params); + if (captured_vars.empty() && free_type_vars.empty()) { + lifted_func = Function(body->params, body->body, body->ret_type, body->type_params, + body->attrs, body->span); } else { // When a closure is locally bound in a program, we have its full type information // avalible to us. @@ -183,13 +187,16 @@ class LambdaLifter : public ExprMutator { // bind to go from unannotated free variables -> annotated free variables and then // construct the "closure" function with fully annotated arguments, no longer relying // on type inference. - auto before = Downcast(body)->params.size(); + size_t before_arity = body->params.size(); auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type, func->type_params, func->attrs, func->span); - auto after = Downcast(rebound_body)->params.size(); - CHECK_EQ(before, after); + size_t after_arity = rebound_body->params.size(); + CHECK_EQ(before_arity, after_arity); + lifted_func = + Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), + free_type_vars, /*attrs=*/{}, func->span); lifted_func = - Function(typed_captured_vars, rebound_body, func->func_type_annotation(), free_type_vars); + MaybeFunctionOnDevice(lifted_func, captured_var_device_types, result_device_type); lifted_func = MarkClosure(lifted_func); } @@ -206,7 +213,7 @@ class LambdaLifter : public ExprMutator { module_->Add(global, lifted_func); } - if (captured_vars.size() == 0) { + if (captured_vars.empty()) { return std::move(global); } else { // If we need to allocate a closure, @@ -226,9 +233,7 @@ class LambdaLifter : public ExprMutator { if (auto* n = pair.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - 
func->attrs); - module_->Add(pair.first, func, true); + module_->Add(pair.first, Downcast(Mutate(func)), /*update=*/true); } } return module_; diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 5984a208efe0..9a2297a75962 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -31,6 +31,8 @@ #include +#include "../op/annotation/annotation.h" + namespace tvm { namespace relay { MixedModeVisitor::MixedModeVisitor(int visit_limit) { @@ -527,15 +529,19 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array new_params; - for (Var param : func->params) { - if (!args_map.count(param)) { - new_params.push_back(param); + std::vector new_param_device_types; + for (size_t i = 0; i < func->params.size(); ++i) { + if (!args_map.count(func->params[i])) { + new_params.push_back(func->params[i]); + new_param_device_types.push_back(GetFunctionParamDeviceType(func, i)); } } if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { return expr; } - auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); + auto ret = + Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); + ret = MaybeFunctionOnDevice(ret, new_param_device_types, GetFunctionResultDeviceType(func)); std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); @@ -543,9 +549,19 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { for (const auto& v : FreeVars(ret)) { if (set.count(v) == 0) { new_params.push_back(v); + if (GetFunctionResultDeviceType(func) != kInvalidDeviceType) { + // TODO(mbs): The function has been annotated with a device, which means we are supposed + // to be preserving device annotations on every transformation. However there's no + // such context for the free vars in args_map. 
+ LOG(WARNING) << "introduced free var '" << PrettyPrint(v) + << "' into function body but no device is known for it"; + } + new_param_device_types.push_back(kInvalidDeviceType); } } - ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); + ret = + Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); + ret = MaybeFunctionOnDevice(ret, new_param_device_types, GetFunctionResultDeviceType(func)); ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 284f8b88ee0d..beadf4a67ddc 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -71,11 +71,10 @@ Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { // The device can be recovered from the binding site of the global or local variable. return expr; } - if (const auto* function_node = expr.as()) { - if (function_node->HasNonzeroAttr(attr::kPrimitive)) { - // Primitive functions are device polymorphic, matching our interpretation for OpNode above. - return expr; - } + if (expr->IsInstance()) { + // If a primitive function then it is device polymorphic. Otherwise the device is captured + // by the function's attributes. + return expr; } return OnDevice(expr, device_type, is_fixed); } diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h index 35f8b6bf50b6..b6dff8813fd4 100644 --- a/src/relay/op/annotation/annotation.h +++ b/src/relay/op/annotation/annotation.h @@ -45,14 +45,18 @@ const Op& OnDeviceOp(); Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); /*! - * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. 
However - * returns \p expr directly if: + * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed if the + * device for \p expr cannot otherwise be recovered by the lexical scoping convention. This means + * we will NOT wrap if: * - \p device_type is \p kInvalidDeviceType, which signals there are no device annotations * already in play. * - \p expr is an operator or primitive function literal. These are device polymorphic. - * - \p expr is a global or local var. These already have an implied device. - * - \p expr is a constructor. There should probably be device polymorphic but are in an - * in-between state at the moment. + * - \p expr is a non-primitive function literal. The device is captured by the + * "result_device_type" attribute on the function itself. + * - \p expr is a global var. The device is on the function attributes the global is bound to. + * - \p expr is a local var. The device is tracked by the device aware visitors for us. + * - \p expr is a constructor. These should eventually be device polymorphic but are currently + * in an in-between state at the moment. */ Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc deleted file mode 100644 index 7457457e4c5c..000000000000 --- a/src/relay/transforms/device_annotation.cc +++ /dev/null @@ -1,581 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file device_annotation.cc - * \brief Passes to rewrite annotated program and retrieve the device allocation - * of expression. - * - * The following passes are performed: - * 1. Validate the unnecessary and redundant annotation. - * 2. Rewrite the annotated program and insert data copy operators. - * 3. Collect the device allocation of each expression. - */ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace relay { - -namespace { - -bool IsOnDeviceNode(const ExprNode* node) { - if (!node->IsInstance()) return false; - const auto* call_node = static_cast(node); - return call_node->attrs.as(); -} - -bool IsDeviceCopyNode(const ExprNode* node) { - if (!node->IsInstance()) return false; - const auto* call_node = static_cast(node); - - if (call_node->attrs.as()) { - return true; - } - - auto tir_call_attrs = call_node->attrs.as(); - if (tir_call_attrs) { - auto metadata = tir_call_attrs->metadata; - return metadata.count("source_device") == 1 && metadata.count("dst_device") == 1; - } - - return false; -} - -} // namespace - -class ValidateAnnotation : private ExprVisitor { - public: - static std::unordered_map Validate(const Expr& expr) { - ValidateAnnotation valid; - valid(expr); - return valid.annotation_map_; - } - - private: - void VisitExpr_(const CallNode* call_node) final { - ExprVisitor::VisitExpr_(call_node); - if (IsOnDeviceNode(call_node)) { - int device_type = GetDeviceId(call_node); - if (annotation_map_.count(call_node)) { - 
ICHECK_EQ(annotation_map_.at(call_node), device_type) - << "An expression node can only be annotated to one device."; - } else { - annotation_map_.insert({call_node, GetDeviceId(call_node)}); - } - - ICHECK_EQ(call_node->args.size(), 1U); - const auto* node = call_node->args[0].operator->(); - if (annotation_map_.count(node)) { - ICHECK_EQ(annotation_map_.at(node), device_type) - << "An expression node can only be annotated to one device."; - } else { - annotation_map_.insert({node, GetDeviceId(call_node)}); - } - } - } - - void VisitExpr_(const TupleGetItemNode* get_elem) final { - ExprVisitor::VisitExpr_(get_elem); - const auto* tn = get_elem->tuple.operator->(); - if (annotation_map_.count(tn)) { - annotation_map_.insert({get_elem, annotation_map_.at(tn)}); - } - } - - /* - * \brief Get the device type of the annotation node. - * \param call_node The on_device annotation call node. - * \return The device type. - */ - int GetDeviceId(const CallNode* call_node) { - ICHECK(IsOnDeviceNode(call_node)) << "The input call node must be on_device node."; - const OnDeviceAttrs* on_device_attr = call_node->attrs.as(); - return on_device_attr->device_type; - } - - std::unordered_map annotation_map_; -}; - -// Replace the use of an expression with the output of a `copy_device` operator -// if the `on_device` operator takes the annotated expr as an input. -// -// This actually replaces annotation ops with device copy ops and connects any -// two dependent expressions with a `device_copy` op when needed. Note that the -// device type of a `device_copy` op is identical to that of the destination op -// since it is where the data should be copied to. 
-class RewriteAnnotation : public ExprMutator { - public: - Expr Rewrite(const Expr& expr, int fallback_device) { - fallback_device_ = fallback_device; - annotation_map_ = ValidateAnnotation::Validate(expr); - return this->VisitExpr(expr); - } - - Expr VisitExpr_(const LetNode* op) final { - Expr value = GetDeviceCopyExpr(op->value, op); - Expr body = GetDeviceCopyExpr(op->body, op); - - if (value.same_as(op->value) && body.same_as(op->body)) { - return ExprMutator::VisitExpr_(op); - } else { - Expr new_let = Let(op->var, value, body); - UpdateAnnotationMap(op, new_let.operator->()); - return this->VisitExpr(new_let); - } - } - - Expr VisitExpr_(const TupleNode* op) { - Array fields; - bool annotated = false; - for (const auto& field : op->fields) { - annotated |= NeedDeviceCopy(field.operator->(), op); - fields.push_back(GetDeviceCopyExpr(field, op)); - } - - if (annotated) { - Expr new_tuple = Tuple(fields); - UpdateAnnotationMap(op, new_tuple.operator->()); - return this->VisitExpr(new_tuple); - } else { - return ExprMutator::VisitExpr_(op); - } - } - - Expr VisitExpr_(const TupleGetItemNode* op) final { - Expr tuple = op->tuple; - if (NeedDeviceCopy(tuple.operator->(), op)) { - Expr new_expr = TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); - UpdateAnnotationMap(op, new_expr.operator->()); - return this->VisitExpr(new_expr); - } else { - return ExprMutator::VisitExpr_(op); - } - } - - Expr VisitExpr_(const IfNode* if_node) final { - Expr cond = GetDeviceCopyExpr(if_node->cond, if_node); - Expr true_br = GetDeviceCopyExpr(if_node->true_branch, if_node); - Expr false_br = GetDeviceCopyExpr(if_node->false_branch, if_node); - - if (if_node->cond.same_as(cond) && if_node->true_branch.same_as(true_br) && - if_node->false_branch.same_as(false_br)) { - return ExprMutator::VisitExpr_(if_node); - } else { - Expr new_if = If(cond, true_br, false_br); - UpdateAnnotationMap(if_node, new_if.operator->()); - return this->VisitExpr(new_if); - } - } - - Expr 
VisitExpr_(const CallNode* call_node) final { - if (IsOnDeviceNode(call_node)) { - return this->VisitExpr(call_node->args[0]); - } - - if (IsDeviceCopyNode(call_node)) { - return ExprMutator::VisitExpr_(call_node); - } - - Array new_args; - bool annotated = false; - for (const auto& arg : call_node->args) { - annotated |= NeedDeviceCopy(arg.operator->(), call_node); - new_args.push_back(GetDeviceCopyExpr(arg, call_node)); - } - - if (annotated) { - Call new_call = Call(call_node->op, new_args, call_node->attrs, call_node->type_args); - - UpdateAnnotationMap(call_node, new_call.operator->()); - return this->VisitExpr(new_call); - } else { - return ExprMutator::VisitExpr_(call_node); - } - } - - private: - void UpdateAnnotationMap(const ExprNode* old_node, const ExprNode* new_node) { - const auto it = annotation_map_.find(old_node); - if (it == annotation_map_.end()) { - annotation_map_.insert({new_node, fallback_device_}); - } else { - annotation_map_.insert({new_node, it->second}); - } - this->memo_[GetRef(old_node)] = GetRef(new_node); - } - - Expr GetDeviceCopyExpr(const Expr& src, const ExprNode* dst) { - const auto* src_node = src.operator->(); - if (!NeedDeviceCopy(src_node, dst)) return src; - - const auto sit = annotation_map_.find(src_node); - if (sit == annotation_map_.end()) { - const auto dit = annotation_map_.find(dst); - ICHECK(dit != annotation_map_.end()) - << "Device copy op is not required when both src and dst ops are not " - "annotated."; - return CreateDeviceCopy(src, fallback_device_, dit->second); - } else { - const auto dit = annotation_map_.find(dst); - int dst_dev_type = dit == annotation_map_.end() ? fallback_device_ : dit->second; - return CreateDeviceCopy(src, sit->second, dst_dev_type); - } - } - - // Check if a device copy op is need between two ops. 
- bool NeedDeviceCopy(const ExprNode* src, const ExprNode* dst) { - if (annotation_map_.count(src)) { - int src_dev_type = annotation_map_.at(src); - if (annotation_map_.count(dst)) { - return src_dev_type != annotation_map_.at(dst); - } else { - return src_dev_type != fallback_device_; - } - } else { - if (annotation_map_.count(dst)) { - // Though data copy op could be inserted whenever the `src` and `dst` - // ops are annotated to different devices, it leads to high overhead. - // - // Here we need across device data transferring only when `src` is a - // CallNode or FunctionNode and the `dst` is annotated with any device - // id other than fallback_device_. - if (src->IsInstance() || src->IsInstance()) { - return annotation_map_.at(dst) != fallback_device_; - } else { - // There shouldn't be any copy nodes between var/constant and another - // expression. - return !(src->IsInstance() || src->IsInstance()); - } - } else { - return false; - } - } - } - - /* - * \brief Create an operator to copy data from the source device to the - * destination device. - * \param src The source expression that produces data to be copied. - * \param src_dev_type The device type where the data is copied from. - * \param dst_dev_type The device type where the data is copied to. - * \return The created call node. - */ - Call CreateDeviceCopy(const Expr& src, int src_dev_type, int dst_dev_type) { - auto attrs = make_object(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; - static const Op& op = Op::Get("device_copy"); - Call device_copy = Call(op, {src}, Attrs(attrs), {}); - annotation_map_.insert({device_copy.operator->(), dst_dev_type}); - return device_copy; - } - - std::unordered_map annotation_map_; - int fallback_device_; -}; - -// Get all annotation expressions. 
-class AnnotatationVisitor : private ExprVisitor { - public: - static Map GetAnnotations(const Expr& expr) { - AnnotatationVisitor visitor; - visitor(expr); - return visitor.annotations_; - } - - private: - void VisitExpr_(const CallNode* call_node) { - if (IsOnDeviceNode(call_node)) { - const auto* attr = call_node->attrs.as(); - annotations_.Set(GetRef(call_node), attr->device_type); - } - ExprVisitor::VisitExpr_(call_node); - } - Map annotations_; -}; - -/* - * \brief Return device allocation map based on the post order traversed graph. - * For the following program: - * .. code-block:: python - * x = relay.var("x") - * y = relay.var("y") - * add = relay.add(x, y) - * sqrt = relay.sqrt(add) - * log = relay.log(add) - * subtract = relay.subtract(sqrt, log) - * exp = relay.exp(subtract) - * - * Suppose we have annotated add, sqrt, and log with device 1, 2, and 3, - * respectively. The fallback/default device is 4. After Rewriting the - * program, we can have the following graph, where each copy op has both - * source and destination device type denoting which device the data should be - * copied from and to. - * - * x y - * \ / - * add/1 - * / \ - * copy1 copy2 - * | | - * sqrt/2 log/3 - * | | - * copy3 copy4 - * \ / - * subtract - * | - * exp - * - * To Get the device mapping of each expression, we need to propagate the - * device information from the copy ops. This can be done in two passes. - * -Pass 1: Propagating the source device type to ops in a bottom-up way to the - * ancestors until encountering another copy op. For example, this way - * provides add, x, and y device types from the copy operator, `copy1`. - * -Pass 2: Propagating the destination device type of "the last" copy op to the - * remain nodes. For instance, this offers `subtract` and `exp` the - * same device type as `copy3`. 
- */ - -class DeviceInfo { - public: - static Map GetDeviceMap(const Expr& expr) { - DeviceInfo device_info; - device_info.post_visitor_ = PostDfsOrderVisitor(); - device_info.post_visitor_.Visit(expr); - if (device_info.post_visitor_.num_device_copy_ops_ > 0) { - device_info.PropagateDeviceId(); - return device_info.device_map_; - } else { - return Map(); - } - } - - private: - class PostDfsOrderVisitor : private ExprVisitor { - public: - void Visit(const Expr& expr) { - if (const auto* fn = expr.as()) { - for (const auto& param : fn->params) { - this->VisitExpr(param); - } - this->VisitExpr(fn->body); - } else { - this->VisitExpr(expr); - } - } - - private: - // Post order traversal. - void VisitExpr_(const FunctionNode* fn) final { - // TODO(zhiics) Skip annotation of function node for now. - } - - void VisitExpr_(const ConstantNode* cn) final { device_tag_[cn] = dev_type_; } - - void VisitExpr_(const CallNode* call) final { - // Skip annotation nodes. - if (!IsOnDeviceNode(call)) { - if (const auto* node = GetDeviceCopyNode(call)) { - ICHECK(node->IsInstance()); - const auto* call_node = static_cast(node); - auto attrs = call_node->attrs.as(); - - if (attrs) { - num_device_copy_ops_++; - dev_type_ = attrs->src_dev_type; - for (auto& arg : call->args) { - Visit(arg); - // restore the type for remaining arguments - dev_type_ = attrs->src_dev_type; - } - device_tag_[call] = attrs->dst_dev_type; - // update the out_dev_type_, which should be the dst_dev_type of last copy - out_dev_type_ = attrs->dst_dev_type; - } else { - auto attrs = call_node->attrs.as(); - CHECK(attrs) << "must be non-null"; - num_device_copy_ops_++; - dev_type_ = Downcast(attrs->metadata["source_device"]); - for (auto& arg : call->args) { - Visit(arg); - // restore the type for remaining arguments - dev_type_ = Downcast(attrs->metadata["source_device"]); - } - device_tag_[call] = Downcast(attrs->metadata["dst_device"]); - // update the out_dev_type_, which should be the dst_dev_type of last 
copy - out_dev_type_ = Downcast(attrs->metadata["dst_device"]); - } - } else { - for (auto& arg : call->args) { - int cur_dev_type = dev_type_; - Visit(arg); - // restore the type for remaining arguments - dev_type_ = cur_dev_type; - } - device_tag_[call] = dev_type_; - } - } - } - - void VisitExpr_(const TupleNode* tn) final { - ExprVisitor::VisitExpr_(tn); - // TODO(zhiics) Skip annotation of tuple node for now. - } - - void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); } - - void VisitExpr_(const VarNode* vn) final { device_tag_[vn] = dev_type_; } - - void VisitExpr_(const LetNode* ln) final { - ExprVisitor::VisitExpr_(ln); - device_tag_[ln] = dev_type_; - } - - void VisitExpr_(const IfNode* in) final { - ExprVisitor::VisitExpr_(in); - device_tag_[in] = dev_type_; - } - - int num_device_copy_ops_{0}; - int dev_type_ = -1; - int out_dev_type_ = -1; - std::unordered_map device_tag_; - friend DeviceInfo; - }; - - /* - * \brief Returns a device copy node based on the current expr node. It - * returns a device copy node either the current expr node is a device copy - * node or the current expr node is a function node whose body is a device - * copy node (i.e. the fused function of a device copy call node). 
- */ - static const ExprNode* GetDeviceCopyNode(const ExprNode* node) { - if (IsDeviceCopyNode(node)) { - return node; - } else if (node->IsInstance()) { - const auto* call_node = static_cast(node); - if (const auto* fn = call_node->op.as()) { - const ExprNode* body = fn->body.operator->(); - if (IsDeviceCopyNode(body)) { - return body; - } - } - } - return nullptr; - } - - void PropagateDeviceId() { - int out_dev_type = post_visitor_.out_dev_type_; - for (auto& it : post_visitor_.device_tag_) { - if (it.second != -1) { - device_map_.Set(GetRef(it.first), it.second); - } else { - device_map_.Set(GetRef(it.first), out_dev_type); - } - } - } - - PostDfsOrderVisitor post_visitor_; - Map device_map_; -}; - -Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { - RewriteAnnotation rewrote = RewriteAnnotation(); - Expr new_expr = rewrote.Rewrite(expr, fallback_device); - - // Remove OnDevice operators. Note that these operators are only present at the - // leaves after annotation. Therefore, we can simply reconstruct the - // Function/Expr by removing them directly. 
- if (const FunctionNode* fn = new_expr.as()) { - auto params = fn->params; - auto body = fn->body; - std::vector new_body; - if (const TupleNode* tuple = body.as()) { - for (const auto& field : tuple->fields) { - if (!IsOnDeviceNode(field.operator->())) { - new_body.push_back(field); - } - } - ICHECK_GT(new_body.size(), 0U); - if (new_body.size() == 1) { - return Function(params, new_body[0], Type(nullptr), fn->type_params, fn->attrs); - } else if (tuple->fields.size() == new_body.size()) { - return new_expr; - } else { - Tuple tuple_body = Tuple(new_body); - return Function(params, tuple_body, Type(nullptr), fn->type_params, fn->attrs); - } - } else { - return new_expr; - } - } else if (const TupleNode* tuple = new_expr.as()) { - std::vector new_fields; - for (const auto& field : tuple->fields) { - if (!IsOnDeviceNode(field.operator->())) { - new_fields.push_back(field); - } - } - ICHECK_GT(new_fields.size(), 0U); - if (tuple->fields.size() == new_fields.size()) { - return new_fields.size() == 1 ? new_fields[0] : new_expr; - } else { - return new_fields.size() == 1 ? 
new_fields[0] : Tuple(new_fields); - } - } else { - return new_expr; - } -} - -Map CollectDeviceInfo(const Expr& expr) { return DeviceInfo::GetDeviceMap(expr); } - -Map CollectDeviceAnnotationOps(const Expr& expr) { - return AnnotatationVisitor::GetAnnotations(expr); -} - -TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo").set_body_typed(CollectDeviceInfo); - -TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps") - .set_body_typed(CollectDeviceAnnotationOps); - -namespace transform { - -Pass RewriteAnnotatedOps(int fallback_device) { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); - }; - return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"}); -} - -TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation").set_body_typed(RewriteAnnotatedOps); - -} // namespace transform - -} // namespace relay -} // namespace tvm diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index 204bce53207b..28aeab60539c 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -19,7 +19,7 @@ /*! * \file src/relay/transforms/device_aware_visitors.cc - * \brief Visitors which track the device for the current Relay expression and Relay Vars. + * \brief Visitors which track the device for the current Relay expression. */ #include "./device_aware_visitors.h" @@ -28,30 +28,51 @@ namespace tvm { namespace relay { namespace transform { -// TODO(mbs): We'd probably have less tendious code duplication if we redefined the memoizing -// mutator on top of the generic Functor. +// TODO(mbs): This machinery can be used a) on expressions/modules which have not had +// device planning run, and b) on expressions for which we've not kept track of their +// containing module. 
For now we'll handle b) by being as forgiving as possible when recovering +// the device for an expression, and we'll support a) the same way. But better would be +// to ICHECK fail when, eg, a variable is not in scope or the lexical device stack is empty. + +LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) { + if (maybe_mod) { + for (const auto& pair : maybe_mod.value()->functions) { + if (const auto* function_node = pair.second.as()) { + DLDeviceType device_type = GetFunctionResultDeviceType(function_node); + if (device_type != kInvalidDeviceType) { + global_var_device_types_.emplace(pair.first, device_type); + } + } + } + } +} DLDeviceType LexicalOnDeviceMixin::GetInScopeDeviceType(const Expr& expr) const { auto props = GetOnDeviceProps(expr); if (props.body.defined() && props.is_fixed) { - // Look through any fixed "on_device" annotations. return props.device_type; - } - if (expr->IsInstance()) { + } else if (const auto* var_node = expr.as()) { // Lookup variable binding. - auto itr = var_device_types_.find(Downcast(expr)); - if (itr == var_device_types_.end()) { - return kInvalidDeviceType; - } else { + auto itr = var_device_types_.find(GetRef(var_node)); + if (itr != var_device_types_.end()) { return itr->second; } - } - // Otherwise use the currently in-scope device type. - if (expr_device_types_.empty()) { - return kInvalidDeviceType; + // else: fallthrough to unknown + } else if (const auto* global_var_node = expr.as()) { + // Lookup global variable. + auto itr = global_var_device_types_.find(GetRef(global_var_node)); + if (itr != global_var_device_types_.end()) { + return itr->second; + } + // else: fallthrough to unknown } else { - return expr_device_types_.back(); + if (!expr_device_types_.empty()) { + // Use the currently in-scope device type.
+ return expr_device_types_.back(); + } + // else: fallthrough to unknown } + return kInvalidDeviceType; } void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } @@ -91,6 +112,9 @@ void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { var_device_types_.erase(itr); } +// TODO(mbs): We'd probably have less tedious code duplication if we redefined the memoizing +// mutator on top of the generic Functor. + void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { if (function_node->HasNonzeroAttr(attr::kPrimitive)) { // No tracking inside primitive functions. diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h index 8611f87efa06..3f4c5c24481e 100644 --- a/src/relay/transforms/device_aware_visitors.h +++ b/src/relay/transforms/device_aware_visitors.h @@ -42,18 +42,19 @@ namespace transform { /*! * \brief Helper class for expression transformers which need to keep track of the device - * holding the results of expressions and bound variables. This is recovered from the - * "on_device" function attributes and fixed "on_device" CallNodes added by the PlanDevices - * pass. + * holding the results of expressions. This is recovered from function attributes and "on_device" + * CallNodes added by the PlanDevices pass. * - * \sa \p DeviceAwareExpr{Visitor,Mutator}. + * \sa \p DeviceAwareExpr{Functor,Visitor,Mutator}. */ class LexicalOnDeviceMixin { protected: + explicit LexicalOnDeviceMixin(const Optional& maybe_mod); + /*! * \brief Returns the device type on which the result of \p expr should/will be stored, assuming - * Push/Pop DeviceType/BoundVar have been correctly called. Returns \p kInvalidDeviceType if - * stack is empty and no bound vars have device types. + * Push/Pop DeviceType/BoundVar have been correctly called. May return \p kInvalidDeviceType if + * the device planning pass has not been run. 
*/ DLDeviceType GetInScopeDeviceType(const Expr& expr) const; @@ -64,7 +65,7 @@ class LexicalOnDeviceMixin { void ExitFunctionBody(); /*! \brief Push a device type onto the lexical device stack. Ignore if \p kInvalidDeviceType. */ - void PushDeviceType(const DLDeviceType device_type); + void PushDeviceType(DLDeviceType device_type); /*! \brief Pop a device type from the lexical device stack. Ignore if stack is empty. */ void PopDeviceType(); @@ -92,16 +93,28 @@ class LexicalOnDeviceMixin { /*! * \brief The stack of lexically enclosing "on_device" devices types, from outermost to innermost. - * When visiting an expression other than a variable we can assume the expression result is - * to be stored on device_type_.back(). + * When visiting an expression other than a variable we can assume the expression's result is to + * be stored on device_type_.back(). */ std::vector expr_device_types_; + /*! - * \brief A map from in-scope variable to their device types. We may assume the variable is only - * ever bound to a value stored on this device at runtime. + * \brief A map from in-scope local variables to their device types. We may assume the variable is + * only ever bound to a value stored on this device at runtime. + * + * Note: We're playing it safe and keying by object refs here just in case the Relay expression + * being rewritten has no module or other global to keep it alive. */ std::unordered_map var_device_types_; + + /*! + * \brief A map from global variables to their device types, ie the "result_device_type" of the + * function they are bound to in the module we are working on. We calculate this explicitly so + * that we don't need to hold on to any module, which is often in the process of being rewritten.
+ */ + std::unordered_map + global_var_device_types_; }; template @@ -119,6 +132,9 @@ class DeviceAwareExprFunctor : public ExprFunctor; public: + explicit DeviceAwareExprFunctor(const Optional& maybe_mod) + : LexicalOnDeviceMixin(maybe_mod) {} + void VisitExpr_(const FunctionNode* function_node) { if (function_node->HasNonzeroAttr(attr::kPrimitive)) { // No tracking inside primitive functions. @@ -229,6 +245,9 @@ class DeviceAwareExprFunctor : public ExprFunctor& maybe_mod) + : LexicalOnDeviceMixin(maybe_mod) {} + using ExprVisitor::VisitExpr_; void VisitExpr_(const FunctionNode* function_node) final; @@ -272,6 +291,9 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { /*! \brief ExprMutator which tracks devices. */ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { public: + explicit DeviceAwareExprMutator(const Optional& maybe_mod) + : LexicalOnDeviceMixin(maybe_mod) {} + Expr VisitExpr_(const FunctionNode* function_node) final; Expr VisitExpr_(const LetNode* let_node) final; Expr VisitExpr_(const CallNode* call_node) final; diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 35bf406263e4..dc61e79226b6 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -672,6 +672,7 @@ class DeviceDefaulter : public ExprVisitor { std::unique_ptr Default() { VLOG_CONTEXT << "DeviceDefaulter"; + VLOG(0) << "using default device type " << default_device_type_; for (const auto& pair : mod_->functions) { VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; VisitExpr(pair.second); diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 7b3f2da716aa..ca9a286ff6a7 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -242,6 +242,10 @@ class ForwardPrep : private MixedModeVisitor { message_[key] = message; } } + + // We 
intended the following overrides on implementations from ExprVisitor. + using MixedModeVisitor::VisitExpr_; + // Visitor pattern override. void VisitExpr_(const TupleGetItemNode* op) final { MixedModeVisitor::VisitExpr_(op); } diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index f1f7a95e33e8..960f56957ebb 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -31,8 +31,9 @@ #include #include "../../support/arena.h" -#include "pass_utils.h" -#include "pattern_utils.h" +#include "../op/annotation/annotation.h" +#include "./pass_utils.h" +#include "./pattern_utils.h" namespace tvm { namespace relay { @@ -1028,7 +1029,7 @@ Pass FuseOps(int fuse_opt_level) { auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps)); return Downcast(FuseOps(f, opt_level, max_fuse_depth.value(), m)); }; - return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"}); + return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps); diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index 202275626d5d..3adff6e3099a 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -293,7 +293,7 @@ struct ReverseAD : ExprMutator { return Call(bpv, {}); }); Expr nbp = Function({}, nbp_body, TupleType::Empty(), {}); - ll->Push(RefWrite(bp, transform::ToANormalForm(nbp))); + ll->Push(RefWrite(bp, transform::ToANormalForm(mod, nbp))); // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that. 
return ret; }); diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index c75f18f6831c..56875f6c16a1 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -65,7 +65,7 @@ class LetList { */ Var Push(Var pv, Expr expr) { ICHECK(!used_); - ICHECK(WellFormed(expr)); + ICHECK(WellFormed(expr)) << "expression:" << std::endl << PrettyPrint(expr); lets_.emplace_back(std::make_pair(pv, expr)); return pv; } diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 31d3b2c8991a..71917c31ec00 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -47,6 +47,7 @@ #include "../op/memory/device_copy.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" +#include "./device_aware_visitors.h" #include "./let_list.h" #include "./pass_utils.h" #include "./pattern_utils.h" @@ -56,9 +57,6 @@ using namespace tvm::runtime; namespace tvm { namespace relay { -using AnalysisResultMap = - std::unordered_map; - inline Constant MakeConstant(const std::vector& value) { return MakeConstantTensor(DataType::Int(64), {static_cast(value.size())}, value); } @@ -86,29 +84,18 @@ bool IsReshapeOnly(const Expr& expr) { return false; } -class DialectRewriter : public ExprMutator { +class DialectRewriter : public transform::DeviceAwareExprMutator { public: - DialectRewriter(const Target& target_host, const AnalysisResultMap& context_analysis_map) - : target_host_(target_host), context_analysis_map_(context_analysis_map) {} - - // Get the device of an expression. 
- Device GetDevice(const Expr& expr) const { - auto it = context_analysis_map_.find(expr); - CHECK(it != context_analysis_map_.end()) << "Cannot find expr in the context analysis map:\n" - << AsText(expr, false); - return it->second; - } + DialectRewriter(IRModule mod, const Target& target_host) + : transform::DeviceAwareExprMutator(std::move(mod)), target_host_(target_host) {} - Function Rewrite(const Function& expr) { - auto ret = ExprMutator::Mutate(expr); - return Downcast(ret); - } + Function Rewrite(const Function& expr) { return Downcast(Mutate(expr)); } Expr VisitExpr_(const TupleNode* tn) final { LetList& scope = scopes_.back(); Array new_fields; for (auto field : tn->fields) { - auto new_field = ExprMutator::Mutate(field); + auto new_field = Mutate(field); if (new_field->IsInstance()) { Var const_var("const", Type(nullptr)); new_field = scope.Push(const_var, new_field); @@ -118,32 +105,38 @@ class DialectRewriter : public ExprMutator { return Tuple(new_fields); } - Expr VisitExpr_(const LetNode* ln) final { - scopes_.emplace_back(); + void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); } - const LetNode* let = ln; - Expr body; - while (let) { - auto new_value = ExprMutator::Mutate(let->value); - scopes_.back().Push(let->var, new_value); - body = let->body; - let = body.as(); - } + std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { + Expr new_value = Mutate(value); + scopes_.back().Push(var, new_value); + // Since we always need a let block on which to bind sub-expressions the rewritten bindings + // are tracked in the current scopes. But return the rewritten binding anyway. + return {var, new_value}; + } - CHECK(body.defined()); - auto new_body = ExprMutator::Mutate(body); + Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node) final { + // The current scope has captured all the rewritten let-binding, as well as any additional + // bindings we needed to add. 
All we need is the rewritten body. + Expr new_body = post_let_node->body; + while (const auto* inner_let_node = new_body.as()) { + new_body = inner_let_node->body; + } auto ret = scopes_.back().Get(new_body); scopes_.pop_back(); return ret; } - Expr VisitExpr_(const CallNode* cn) final { + Expr DeviceAwareVisitExpr_(const CallNode* cn) final { + Call call = GetRef(cn); + DLDeviceType device_type = GetInScopeDeviceType(call); if (IsPrimitive(cn)) { // Because we are in ANF we do not need to visit the arguments. + // TODO(mbs): But does so anyway... LetList& scope = scopes_.back(); std::vector new_args; for (const auto& it : cn->args) { - new_args.push_back(ExprMutator::Mutate(it)); + new_args.push_back(Mutate(it)); } Tuple ins(new_args); @@ -171,31 +164,36 @@ class DialectRewriter : public ExprMutator { return DeviceCopy(new_args[0], copy_attr->src_dev_type, copy_attr->dst_dev_type); } else if (IsDynamic(ret_type)) { Function func = Downcast(cn->op); - return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type); + // TODO(mbs): Device id is always zero. + Device device{device_type, /*device_id=*/0}; + return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type, device); } else { // Handle the static case Array outs; for (size_t i = 0; i < out_types.size(); ++i) { - Device dev = GetDevice(GetRef(cn)); - auto out = MakeStaticAllocation(&scope, out_types[i], dev, std::to_string(i)); + DLDeviceType device_type = GetInScopeDeviceType(GetRef(cn)); + // TODO(mbs): Device id is always zero. + Device device{device_type, /*device_id=*/0}; + auto out = MakeStaticAllocation(&scope, out_types[i], device, std::to_string(i)); outs.push_back(out); } Tuple output(outs); + // TODO(mbs): Capture device in attributes.
Expr invoke = InvokeTVMOp(cn->op, ins, output); - scope.Push(invoke); + scope.Push(OnDevice(invoke, device_type, /*is_fixed=*/true)); return ToTupleType(ret_type, std::vector(output->fields.begin(), output->fields.end())); } } else { - return ExprMutator::VisitExpr_(cn); + return transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(cn); } } private: // Insert a device copy node. Expr DeviceCopy(const Expr& inp, int src_dev, int dst_dev) { - return ExprMutator::Mutate(relay::DeviceCopy(inp, static_cast(src_dev), - static_cast(dst_dev))); + return Mutate(relay::DeviceCopy(inp, static_cast(src_dev), + static_cast(dst_dev))); } // Check if a call invokes a primitive function. @@ -212,9 +210,9 @@ class DialectRewriter : public ExprMutator { if (const auto* fn = expr.as()) { auto body = fn->body; const CallNode* call = body.as(); - return call && call->op == Op::Get("device_copy"); + return call && call->op == device_copy_op_; } else if (const CallNode* cn = expr.as()) { - return cn->op == Op::Get("device_copy"); + return cn->op == device_copy_op_; } else { return false; } @@ -257,16 +255,20 @@ class DialectRewriter : public ExprMutator { CHECK(imm) << "expect static int shape"; int_shape.push_back(imm->value); } - Expr shape = MakeConstant(int_shape); - Expr size = ComputeStorage(type); + Expr shape = OnDevice(MakeConstant(int_shape), cpu_device_.device_type, /*is_fixed=*/true); + Expr size = OnDevice(ComputeStorage(type), cpu_device_.device_type, /*is_fixed=*/true); + // Alignment is directly captured in the instruction rather than calculated, so we + // don't want to wrap it with an "on_device". Expr alignment = ComputeAlignment(type->dtype); // Run type inference later to get the correct type. 
Var var("storage_" + name_hint, Type(nullptr)); - Expr value = AllocStorage(size, alignment, dev, type->dtype); + Expr value = OnDevice(AllocStorage(size, alignment, dev, type->dtype), dev.device_type, + /*is_fixed=*/true); auto sto = scope->Push(var, value); // TODO(@jroesch): There is a bug with typing based on the constant shape. - auto tensor = AllocTensor(sto, shape, type->dtype, type->shape); + auto tensor = OnDevice(AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape), + dev.device_type, /*is_fixed=*/true); Var tensor_var("tensor_" + name_hint, Type(nullptr)); return scope->Push(tensor_var, tensor); } @@ -284,7 +286,6 @@ class DialectRewriter : public ExprMutator { Array is_inputs; int input_pos = 0; - Device cpu_dev = default_device_; CHECK_EQ(new_args.size(), input_states.size()); for (size_t i = 0; i < new_args.size(); ++i) { Expr arg = new_args[i]; @@ -296,27 +297,28 @@ class DialectRewriter : public ExprMutator { } int state = input_states[i]->value; // Pass Shapes - if (state == 2) { + if (state == tec::kNeedInputShape) { std::vector exprs = FromTupleType(ty, arg); for (size_t j = 0; j < exprs.size(); ++j) { - Expr sh_of = ExprMutator::Mutate(ShapeOf(exprs[j])); + Expr sh_of = Mutate(ShapeOf(exprs[j])); // already accounts for device Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr)); shape_func_ins.push_back(scope->Push(in_shape_var, sh_of)); input_pos++; } is_inputs.push_back(0); - } else if (state == 1) { - auto new_arg = ExprMutator::Mutate(arg); - auto dev = GetDevice(arg); - if (dev.device_type != cpu_dev.device_type) { - new_arg = DeviceCopy(new_arg, dev.device_type, cpu_dev.device_type); + } else if (state == tec::kNeedInputData) { + auto new_arg = Mutate(arg); // already accounts for device + DLDeviceType device_type = GetInScopeDeviceType(arg); + if (device_type != cpu_device_.device_type) { + new_arg = OnDevice(DeviceCopy(new_arg, device_type, cpu_device_.device_type), + cpu_device_.device_type, 
/*is_fixed=*/true); } Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr)); shape_func_ins.push_back(scope->Push(in_shape_var, new_arg)); input_pos++; is_inputs.push_back(1); } else { - // TODO(@jroesch): handle 3rd case + // TODO(@jroesch): handle kNeedBoth LOG(FATAL) << "unsupported shape function input state"; } } @@ -327,12 +329,14 @@ class DialectRewriter : public ExprMutator { auto tt = TensorType(out->shape, out->dtype); // Put shape func on CPU. This also ensures that everything between // shape_of and shape_func are on CPU. - auto alloc = MakeStaticAllocation(scope, tt, cpu_dev, std::to_string(i)); + auto alloc = OnDevice(MakeStaticAllocation(scope, tt, cpu_device_, std::to_string(i)), + cpu_device_.device_type, /*is_fixed=*/true); Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr)); alloc = scope->Push(shape_func_out_var, alloc); out_shapes.push_back(alloc); } - auto shape_call = ShapeFunc(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs); + auto shape_call = OnDevice(ShapeFunc(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs), + cpu_device_.device_type, /*is_fixed=*/true); Var shape_func_var("shape_func", Type(nullptr)); scope->Push(shape_func_var, shape_call); return out_shapes; @@ -341,18 +345,20 @@ class DialectRewriter : public ExprMutator { // Generate the code for invoking a TVM op with a dynamic shape. 
Expr DynamicInvoke(LetList* scope, const Function& func, const Tuple& ins, const std::vector& new_args, const std::vector& out_types, - const Type& ret_type) { + const Type& ret_type, Device dev) { auto out_shapes = EmitShapeFunc(scope, func, new_args); std::vector storages; - auto func_dev = GetDevice(func); CHECK_EQ(out_shapes.size(), out_types.size()); for (size_t i = 0; i < out_shapes.size(); ++i) { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; - auto size = ComputeStorageInRelay(out_shape, out_type); + auto size = OnDevice(ComputeStorageInRelay(out_shape, out_type), cpu_device_.device_type, + /*is_fixed=*/true); + // Alignment is directly captured in the instruction so don't wrap in "on_device". auto alignment = ComputeAlignment(out_type->dtype); Var sto_var("storage_" + std::to_string(i), Type(nullptr)); - auto val = AllocStorage(size, alignment, func_dev, out_type->dtype); + auto val = OnDevice(AllocStorage(size, alignment, dev, out_type->dtype), dev.device_type, + /*is_fixed=*/true); storages.push_back(scope->Push(sto_var, val)); } @@ -361,13 +367,14 @@ class DialectRewriter : public ExprMutator { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; auto storage = storages[i]; - auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape); + auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape), + dev.device_type, /*is_fixed=*/true); Var out_var("out_" + std::to_string(i), Type(nullptr)); outs.push_back(scope->Push(out_var, alloc)); } Tuple tuple_outs(outs); - auto invoke = InvokeTVMOp(func, ins, tuple_outs); + auto invoke = OnDevice(InvokeTVMOp(func, ins, tuple_outs), dev.device_type, /*is_fixed=*/true); scope->Push(invoke); return ToTupleType(ret_type, std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end())); @@ -393,12 +400,13 @@ class DialectRewriter : public ExprMutator { } private: + const Op& device_copy_op_ = Op::Get("device_copy"); + Target target_host_; - 
AnalysisResultMap context_analysis_map_; std::vector scopes_; runtime::DataType compute_dtype_ = runtime::DataType::Int(64); - Device default_device_{kDLCPU, 0}; + Device cpu_device_{kDLCPU, 0}; }; namespace transform { @@ -413,27 +421,11 @@ Pass ManifestAlloc(Target target_host, Map targets) { mod->ImportFromStd("core.rly"); mod = relay::transform::InferType()(mod); - Device fallback_dev; - if (targets.size() > 1) { - auto pass_ctx = PassContext::Current(); - Optional opt_fallback_dev_type = - pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); - auto fallback_dev_type = opt_fallback_dev_type.value(); - CHECK_GT(fallback_dev_type->value, 0U); - fallback_dev.device_type = static_cast(fallback_dev_type->value); - fallback_dev.device_id = 0; - } else { - const auto& it = targets.begin(); - fallback_dev.device_type = static_cast((*it).first->value); - fallback_dev.device_id = 0; - } - auto ca = ContextAnalysis(mod, fallback_dev); - auto glob_funcs = mod->functions; for (const auto& it : glob_funcs) { if (auto* func_node = it.second.as()) { auto func = GetRef(func_node); - auto rewriter = DialectRewriter(target_host, ca); + auto rewriter = DialectRewriter(mod, target_host); auto updated_func = rewriter.Rewrite(func); mod->Update(it.first, updated_func); diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index bb2f268a23d7..5638804b4aa2 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -36,7 +36,8 @@ #include #include "../analysis/dependency_graph.h" -#include "let_list.h" +#include "../op/annotation/annotation.h" +#include "./let_list.h" namespace tvm { namespace relay { @@ -118,8 +119,11 @@ inline Expr TransformF(const std::function& func, const Expr& * if so, the compute cost of the expression is bounded so it can be copy without graph mode. 
*/ inline bool IsAtomic(const Expr& e) { - return e.as() || e.as() || e.as() || e.as() || - e.as(); // Constant is always by reference. + auto props = GetOnDeviceProps(e); + Expr true_expr = props.body.defined() ? props.body : e; + return true_expr.as() || true_expr.as() || true_expr.as() || + true_expr.as() || + true_expr.as(); // Constant is always by reference. } /*! @@ -222,57 +226,10 @@ std::pair CalcScope(const DependencyGraph& dg); */ Scope LCA(Scope lhs, Scope rhs); -/* Special care is needed to handle local recursion. - * Fill additionally take a (possibly null) Var argument, - * If it is not null, Fill is required to bind the transformed result to that var. - */ -class Fill : ExprFunctor { - public: - static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope); - - // For basic block normal form, bind expressions only if the original expression's - // scope should be lifted - static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, - NodeScopeMap* node_scope, ExprSet* lifted); - - private: - const DependencyGraph& dg_; - NodeScopeMap* node_scope_ = nullptr; - std::unordered_map memo; - // a set of Expressions to include for let bindings. If set to nullptr - // all Exprs will be pushed to the let list. - ExprSet* include_set_ = nullptr; - - Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set) - : dg_(dg), node_scope_(node_scope), include_set_(include_set) {} - - Scope GetScope(const Expr& e); - Scope GetSubScope(const Expr& e, size_t i); - - Expr VisitExpr(const Expr& e, const Var& v) final; - Expr VisitExpr(const Expr& e); - - Expr Atomic(const Expr& e, const Var& v); - // Bind expression `now` to var `v` if the original expression is in the include set, or if - // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly. 
- Expr Compound(const Expr& orig, const Expr& now, const Var& v); - - Expr VisitExpr_(const CallNode* c, const Var& v) final; - Expr VisitExpr_(const TupleNode* t, const Var& v) final; - Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final; - Expr VisitExpr_(const RefCreateNode* r, const Var& v) final; - Expr VisitExpr_(const RefReadNode* r, const Var& v) final; - Expr VisitExpr_(const RefWriteNode* r, const Var& v) final; - Expr VisitExpr_(const IfNode* i, const Var& v) final; - Expr VisitExpr_(const FunctionNode* f, const Var& v) final; - Expr VisitExpr_(const LetNode* l, const Var& v) final; - Expr VisitExpr_(const ConstantNode* c, const Var& v) final; - Expr VisitExpr_(const VarNode* vn, const Var& v) final; - Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final; - Expr VisitExpr_(const OpNode* op, const Var& v) final; - Expr VisitExpr_(const ConstructorNode* c, const Var& v) final; - Expr VisitExpr_(const MatchNode* m, const Var& v) final; -}; +// For basic block normal form. 
+Expr ToBasicBlockNormalFormAux(const Optional& maybe_mod, const Expr& e); + +// ToANormalForm for expressions and as a Pass are declared in transform.h } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 920ac153b63d..692ef3c9f557 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -696,10 +696,6 @@ static inline Expr BroadCastTo(Expr data, Array shape) { return MakeBroadCastTo(data, CheckConstantShapeArrayInteger(shape)); } -Expr StopFusion(Expr data); - -Expr CastHint(Expr data, DataType dtype); - } // namespace relay } // namespace tvm #endif // TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_ diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index 70d37d822d71..eb647ce5e2a5 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -23,7 +23,8 @@ #include #include -#include "pattern_utils.h" +#include "../op/annotation/annotation.h" +#include "./pattern_utils.h" namespace tvm { namespace relay { diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 91e8d90c1232..1dc45d38518e 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -30,8 +30,10 @@ #include "../../support/arena.h" #include "../analysis/dependency_graph.h" -#include "let_list.h" -#include "pass_utils.h" +#include "../op/annotation/annotation.h" +#include "./device_aware_visitors.h" +#include "./let_list.h" +#include "./pass_utils.h" namespace tvm { namespace relay { @@ -94,192 +96,319 @@ std::pair CalcScope(const DependencyGraph& dg) { return std::make_pair(expr_scope, lifted_exprs); } -Expr Fill::ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) { - Fill fi(dg, node_scope, nullptr); - return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e)); -} +namespace { -// For basic block normal form, 
bind expressions only if the original expression's scope -// should be lifted -Expr Fill::ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, - NodeScopeMap* node_scope, ExprSet* lifted) { - Fill fi(dg, node_scope, lifted); - auto var = fi.VisitExpr(e); - return fi.GetScope(e)->let_list->Get(var); -} +/* Special care is needed to handle local recursion. + * Fill additionally take a (possibly null) Var argument, + * If it is not null, Fill is required to bind the transformed result to that var. + * + * ToANormalForm and PlanDevices + * ----------------------------- + * If PlanDevices has run this transform must respect the lexical scoping rules for the residual + * "on_device" calls. Eg: + * \code + * on_device(add(subtract(x, y), add(y, z)), device_type=2, is_fixed=true) + * ==> + * let %x0 = on_device(subtract(x, y), device_type=2, is_fixed=true) + * let %x1 = on_device(add(y, z), device_type=2, is_fixed=true) + * let %x2 = on_device(add(%x0, %x1), device_type=2, is_fixed=true) + * %x2 + * \endcode + * + * In addition to conversion to ANF this pass is also handling hoisting implicitly shared + * sub-expressions to the inner-most scope common to all their uses: + * \code + * on_device( + * if y { + * on_device(%0, device_type=2, is_fixed=true) + * } else { + * on_device(subtract(%0, b), device_type=2, is_fixed=true) + * }, + * device_type=1, is_fixed=true) + * (where %0 = add(a, b)) + * ==> + * let %x0 = on_device(add(a, b), device_type=2, is_fixed=true); + * on_device( + * if y { + * on_device(%x0, device_type=2, is_fixed=true) + * } else { + * let %x1 = on_device(subtract(%x0, b), device_type=2, is_fixed=true); + * %x1 + * }, + * device_type=1, is_fixed=true) + * \endcode + * Though the PlanDevices has already avoided inserting "on_device" calls where they are redundant + * due to lexical scope, it's fiddly to do the same in this pass since the notion of 'scope' is + * now determined by the scope map. 
So we'll just insert them mechanically on every let-binding. + * + * TODO(mbs): Rewrite to derive from DeviceAwareExprMutator and not track device types + * explicitly. It's easy to get rid of the need for the extra var argument on VisitExpr by shifting + * the recursion a '1/2 step' to return a possibly compound expression who's inner expressions are + * all atomic. However the use of the scope map is currently subtle enough I want to leave it + * alone for now. + */ +class Fill : ExprFunctor, private transform::LexicalOnDeviceMixin { + public: + static Expr ToANormalForm(const Optional& maybe_mod, const Expr& e, + const DependencyGraph& dg, NodeScopeMap* node_scope) { + Fill fi(maybe_mod, dg, node_scope, nullptr); + return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e)); + } -Scope Fill::GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); } + // For basic block normal form, bind expressions only if the original expression's scope + // should be lifted + static Expr ToBasicBlockNormalForm(const Optional& maybe_mod, const Expr& e, + const DependencyGraph& dg, NodeScopeMap* node_scope, + ExprSet* lifted) { + Fill fi(maybe_mod, dg, node_scope, lifted); + return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e)); + } -Scope Fill::GetSubScope(const Expr& e, size_t i) { - DependencyGraph::Node* n = dg_.expr_node.at(e); - auto h = n->children.head; - while (i != 0) { + private: + Fill(const Optional& maybe_mod, const DependencyGraph& dg, NodeScopeMap* node_scope, + ExprSet* include_set) + : transform::LexicalOnDeviceMixin(maybe_mod), + dg_(dg), + node_scope_(node_scope), + include_set_(include_set) {} + + Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); } + + Scope GetSubScope(const Expr& e, size_t i) { + DependencyGraph::Node* n = dg_.expr_node.at(e); + auto h = n->children.head; + while (i != 0) { + ICHECK(h); + --i; + h = h->next; + } ICHECK(h); - --i; - h = h->next; + return node_scope_->at(h->value); } - ICHECK(h); - return 
node_scope_->at(h->value); -} -Expr Fill::VisitExpr(const Expr& e, const Var& v) { - if (memo.count(e) == 0) { - memo.insert({e, ExprFunctor::VisitExpr(e, v)}); - } else if (v.defined()) { - GetScope(e)->let_list->Push(v, memo.at(e)); - } - auto ret = memo.at(e); - // if no include_set is specified, every expression should be atomic. - if (include_set_ == nullptr) ICHECK(IsAtomic(ret)); - return ret; -} + Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } -Expr Fill::VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } + Expr VisitExpr(const Expr& e, const Var& v) final { + if (memo.count(e) == 0) { + memo.insert({e, ExprFunctor::VisitExpr(e, v)}); + } else if (v.defined()) { + GetScope(e)->let_list->Push(v, memo.at(e)); + } + auto ret = memo.at(e); + // if no include_set is specified, every expression should be atomic. + // TODO(mbs): Note that Constants must be let-bound even though they are considered 'atomic' + // by this test. + if (include_set_ == nullptr && function_nesting() > 0) { + ICHECK(IsAtomic(ret)) << "expression:" << std::endl << PrettyPrint(ret); + } + return ret; + } -Expr Fill::Atomic(const Expr& e, const Var& v) { - return v.defined() ? GetScope(e)->let_list->Push(v, e) : e; -} + Expr Atomic(const Expr& e, const Var& v) { + Expr annotated_expr = MaybeOnDevice(e, GetInScopeDeviceType(e), /*is_fixed=*/true); + return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr; + } -// Bind expression `now` to var `v` if the original expression is in the include set, or if -// v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly -Expr Fill::Compound(const Expr& orig, const Expr& now, const Var& v) { - Var var = v.defined() ? 
v : Var(String("x"), Type()); - bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); - if (!v.defined() && not_included) { - return now; - } else { - return GetScope(orig)->let_list->Push(var, now); + // Bind expression `now` to var `v` if the original expression is in the include set, or if + // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly + Expr Compound(const Expr& orig, const Expr& now, const Var& v) { + Expr annotated_expr = MaybeOnDevice(now, GetInScopeDeviceType(orig), /*is_fixed=*/true); + Var var = v.defined() ? v : Var(String("x"), Type()); + bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); + if (!v.defined() && not_included) { + return annotated_expr; + } else { + return GetScope(orig)->let_list->Push(var, annotated_expr); + } } -} -Expr Fill::VisitExpr_(const CallNode* c, const Var& v) { - Expr e = GetRef(c); - std::vector args; - for (const auto& a : c->args) { - args.push_back(VisitExpr(a)); + Expr VisitExpr_(const CallNode* c, const Var& v) final { + auto props = GetOnDeviceProps(c); + if (props.body.defined() && props.is_fixed) { + // Keep track of expression device type for lexically enclosing sub-expressions. + PushDeviceType(props.device_type); + Expr body = VisitExpr(props.body, v); + // We are done with this sub-expression. + PopDeviceType(); + // Preserve the "on_device" annotations. 
+ return OnDevice(body, props.device_type, props.is_fixed); + } + + Expr e = GetRef(c); + std::vector args; + for (const auto& a : c->args) { + args.push_back(VisitExpr(a)); + } + return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); } - return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); -} -Expr Fill::VisitExpr_(const TupleNode* t, const Var& v) { - Expr e = GetRef(t); - std::vector fields; - for (const auto& a : t->fields) { - fields.push_back(VisitExpr(a)); + Expr VisitExpr_(const TupleNode* t, const Var& v) final { + Expr e = GetRef(t); + std::vector fields; + for (const auto& a : t->fields) { + fields.push_back(VisitExpr(a)); + } + return Compound(e, Tuple(fields), v); } - return Compound(e, Tuple(fields), v); -} -Expr Fill::VisitExpr_(const TupleGetItemNode* t, const Var& v) { - Expr e = GetRef(t); - return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v); -} + Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { + Expr e = GetRef(t); + return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v); + } -Expr Fill::VisitExpr_(const RefCreateNode* r, const Var& v) { - Expr e = GetRef(r); - return Compound(e, RefCreate(VisitExpr(r->value)), v); -} + Expr VisitExpr_(const RefCreateNode* r, const Var& v) final { + Expr e = GetRef(r); + return Compound(e, RefCreate(VisitExpr(r->value)), v); + } -Expr Fill::VisitExpr_(const RefReadNode* r, const Var& v) { - Expr e = GetRef(r); - return Compound(e, RefRead(VisitExpr(r->ref)), v); -} + Expr VisitExpr_(const RefReadNode* r, const Var& v) final { + Expr e = GetRef(r); + return Compound(e, RefRead(VisitExpr(r->ref)), v); + } -Expr Fill::VisitExpr_(const RefWriteNode* r, const Var& v) { - Expr e = GetRef(r); - return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v); -} + Expr VisitExpr_(const RefWriteNode* r, const Var& v) final { + Expr e = GetRef(r); + return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v); + } 
-Expr Fill::VisitExpr_(const IfNode* i, const Var& v) { - Expr e = GetRef(i); - Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch)), - GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch))); - return Compound(e, ret, v); -} + Expr VisitExpr_(const IfNode* i, const Var& v) final { + Expr e = GetRef(i); + Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch)), + GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch))); + return Compound(e, ret, v); + } -Expr Fill::VisitExpr_(const FunctionNode* f, const Var& v) { - Expr e = GetRef(f); - Expr ret; - if (f->HasNonzeroAttr(attr::kPrimitive)) { - ret = e; - } else { - ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, - f->type_params, f->attrs); + Expr VisitExpr_(const FunctionNode* f, const Var& v) final { + Expr e = GetRef(f); + Expr ret; + if (f->HasNonzeroAttr(attr::kPrimitive)) { + ret = e; + } else { + // Keep track of expression and bound variable device types for lexically enclosing + // sub-expressions. + PushDeviceType(GetFunctionResultDeviceType(f)); + for (size_t i = 0; i < f->params.size(); ++i) { + PushBoundVar(f->params[i], GetFunctionParamDeviceType(f, i)); + } + EnterFunctionBody(); + ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, + f->type_params, f->attrs); + // We are done with this function. + ExitFunctionBody(); + for (size_t i = 0; i < f->params.size(); ++i) { + PopBoundVar(f->params[i]); + } + PopDeviceType(); + } + if (function_nesting() == 0) { + ICHECK(!v.defined()); + // This is a global function which can be bound directly in the module. + return ret; + } else { + // This is a local function which must be let-bound. 
+ return Compound(e, ret, v); + } } - return Compound(e, ret, v); -} -Expr Fill::VisitExpr_(const LetNode* l, const Var& v) { - Expr e = GetRef(l); - VisitExpr(l->value, l->var); - Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body)); - return Compound(e, ret, v); -} + Expr VisitExpr_(const LetNode* l, const Var& v) final { + Expr e = GetRef(l); + // Keep track of bound variable device types for lexically enclosing sub-expressions. + PushBoundVar(l->var, GetInScopeDeviceType(l->value)); + VisitExpr(l->value, l->var); + Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body)); + // We are done with these sub-expressions. + PopBoundVar(l->var); + return Compound(e, ret, v); + } -Expr Fill::VisitExpr_(const ConstantNode* c, const Var& v) { - Expr e = GetRef(c); - return Compound(e, e, v); -} + Expr VisitExpr_(const ConstantNode* c, const Var& v) final { + Expr e = GetRef(c); + return Compound(e, e, v); + } -Expr Fill::VisitExpr_(const VarNode* vn, const Var& v) { - Expr e = GetRef(vn); - return Atomic(e, v); -} + Expr VisitExpr_(const VarNode* vn, const Var& v) final { + Expr e = GetRef(vn); + return Atomic(e, v); + } -Expr Fill::VisitExpr_(const GlobalVarNode* gvn, const Var& v) { - GlobalVar gv = GetRef(gvn); - return Atomic(gv, v); -} + Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { + GlobalVar gv = GetRef(gvn); + return Atomic(gv, v); + } -Expr Fill::VisitExpr_(const OpNode* op, const Var& v) { - Expr e = GetRef(op); - return Atomic(e, v); -} + Expr VisitExpr_(const OpNode* op, const Var& v) final { + Expr e = GetRef(op); + return Atomic(e, v); + } -Expr Fill::VisitExpr_(const ConstructorNode* c, const Var& v) { - Expr e = GetRef(c); - return Atomic(e, v); -} + Expr VisitExpr_(const ConstructorNode* c, const Var& v) final { + Expr e = GetRef(c); + return Atomic(e, v); + } -Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) { - Expr e = GetRef(m); - Expr data = VisitExpr(m->data); - std::vector clauses; - for (const Clause& c 
: m->clauses) { - clauses.push_back( - Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->let_list->Get(VisitExpr(c->rhs)))); + Expr VisitExpr_(const MatchNode* m, const Var& v) final { + Expr e = GetRef(m); + Expr data = VisitExpr(m->data); + std::vector clauses; + for (const Clause& c : m->clauses) { + clauses.emplace_back(c->lhs, + GetSubScope(e, 1 + clauses.size())->let_list->Get(VisitExpr(c->rhs))); + } + return Compound(e, Match(data, clauses, m->complete), v); } - return Compound(e, Match(data, clauses, m->complete), v); -} -IRModule ToANormalForm(const IRModule& m) { - DLOG(INFO) << "ToANF:" << std::endl << m; + const DependencyGraph& dg_; + NodeScopeMap* node_scope_ = nullptr; + std::unordered_map memo; + // a set of Expressions to include for let bindings. If set to nullptr + // all Exprs will be pushed to the let list. + ExprSet* include_set_ = nullptr; +}; +IRModule ModuleToANormalForm(const IRModule& mod) { tvm::Map updates; - auto funcs = m->functions; + auto funcs = mod->functions; for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0); if (const auto* n = it.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; + Function func = GetRef(n); + Function ret = Downcast(transform::ToANormalForm(mod, func)); + ICHECK_EQ(FreeVars(ret).size(), 0) << "rewritten:" << std::endl + << PrettyPrint(ret) << std::endl + << "should not have free vars: " << FreeVars(ret); + VLOG(1) << "rewritten:" << std::endl + << PrettyPrint(func) << std::endl + << "to ANF:" << std::endl + << PrettyPrint(ret); + updates.Set(it.first, ret); } - Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second); - ICHECK_EQ(FreeVars(ret).size(), 0) - << AsText(ret) << "should not has free vars: " << FreeVars(ret); - updates.Set(it.first, Downcast(ret)); } for (auto pair : updates) { - m->Add(pair.first, pair.second, true); + mod->Add(pair.first, pair.second, true); } - DLOG(INFO) << "ToANF: transformed" << std::endl << m; + 
return mod; +} + +} // namespace - return m; +Expr ToBasicBlockNormalFormAux(const Optional& maybe_mod, const Expr& e) { + // calculate all the dependency between nodes. + support::Arena arena; + DependencyGraph dg = DependencyGraph::Create(&arena, e); + /* The scope of the whole expr is global. + * The scope of any subexpr, is the lowest common ancestor of all incoming edge. + * We also record the set of expressions whose scope is lifted. + */ + std::pair scopes = CalcScope(dg); + return Fill::ToBasicBlockNormalForm(maybe_mod, e, dg, &scopes.first, &scopes.second); } namespace transform { -Expr ToANormalForm(const Expr& e) { +Expr ToANormalForm(const Optional& maybe_mod, const Expr& e) { /* When you lift a lambda, what is inside is also being lift. * * So we must determine the scope of the lambda before determining the scope of it's body. @@ -302,12 +431,12 @@ Expr ToANormalForm(const Expr& e) { * We do an additional pass to fill all the LetList and we are done. */ std::pair scopes = CalcScope(dg); - return Fill::ToANormalForm(e, dg, &scopes.first); + return Fill::ToANormalForm(maybe_mod, e, dg, &scopes.first); } Pass ToANormalForm() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); }; + [=](IRModule m, PassContext pc) { return ModuleToANormalForm(m); }; return CreateModulePass(pass_func, 1, "ToANormalForm", {}); } @@ -316,7 +445,7 @@ TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() { }); TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr").set_body_typed([](const Expr& e) { - return ToANormalForm(e); + return ToANormalForm(Optional(), e); }); } // namespace transform diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index 8e952d60b8b7..826006b0e603 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -30,48 +30,36 @@ #include 
"../../support/arena.h" #include "../analysis/dependency_graph.h" -#include "let_list.h" -#include "pass_utils.h" +#include "./pass_utils.h" namespace tvm { namespace relay { -Expr ToBasicBlockNormalFormAux(const Expr& e) { - // calculate all the dependency between nodes. - support::Arena arena; - DependencyGraph dg = DependencyGraph::Create(&arena, e); - /* The scope of the whole expr is global. - * The scope of any subexpr, is the lowest common ancestor of all incoming edge. - * We also record the set of expressions whose scope is lifted. - */ - std::pair scopes = CalcScope(dg); - return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second); -} - IRModule ToBasicBlockNormalForm(const IRModule& mod) { - DLOG(INFO) << "ToBBlock:" << std::endl << mod; - // Create a new module by shallow copy. - IRModule mod_ = mod->ShallowCopy(); + IRModule new_mod = mod->ShallowCopy(); tvm::Map updates; - auto funcs = mod_->functions; + auto funcs = new_mod->functions; for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables"; if (const auto* n = it.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; + Function func = GetRef(n); + Function ret = Downcast(ToBasicBlockNormalFormAux(mod, func)); + VLOG(1) << "rewritten:" << std::endl + << PrettyPrint(func) << std::endl + << "to BasicBlockANF:" << std::endl + << PrettyPrint(ret); + updates.Set(it.first, Downcast(ret)); } - Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second); - updates.Set(it.first, Downcast(ret)); } for (auto pair : updates) { - mod_->Add(pair.first, pair.second, true); + new_mod->Add(pair.first, pair.second, true); } - DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod_; - - return mod_; + return new_mod; } bool BasicBlockNormalFormCheck(const Expr& e) { diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index ebdf1fed2fab..5ca6d86b1d52 100644 --- 
a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -486,8 +486,9 @@ class TypeInferencer : private ExprFunctor, if (type_args.size() > fn_ty_node->type_params.size()) { this->EmitFatal(Diagnostic::Error(call->span) << "Incorrect number of type args in " << call->span << ": " - << "Expected " << fn_ty_node->type_params.size() << "but got " - << type_args.size()); + << "Expected " << fn_ty_node->type_params.size() << " but got " + << type_args.size() << " for call:\n" + << PrettyPrint(GetRef(call))); } for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) { type_args.push_back(IncompleteType(TypeKind::kType)); diff --git a/src/runtime/vm/serialize_utils.h b/src/runtime/vm/serialize_utils.h index b4a10806caaf..cbcdb1bdfa16 100644 --- a/src/runtime/vm/serialize_utils.h +++ b/src/runtime/vm/serialize_utils.h @@ -59,13 +59,13 @@ struct VMFunctionSerializer { /*! \brief The parameters of the VMFunction. */ std::vector params; /*! \brief The device type of each parameter of the VMFunction. 
*/ - std::vector params_device_type; + std::vector params_device_type; VMFunctionSerializer() = default; VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, const std::vector& params, - const std::vector& params_device_type) + const std::vector& params_device_type) : name(name), register_file_size(register_file_size), num_instructions(num_instructions), diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 4df013baa2fb..c7a1baa1430d 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -233,7 +233,7 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { << "The number of provided parameters doesn't match the number of assigned devices"; std::vector func_args(param_names.size()); for (int i = offset; i < args.size(); ++i) { - Index device_type = vm_func.params_device_type[i - offset]; + DLDeviceType device_type = vm_func.params_device_type[i - offset]; Device dev = GetDevice(device_type); if (args[i].type_code() == kTVMDLTensorHandle) { @@ -664,7 +664,7 @@ void VirtualMachine::RunLoop() { NDArray shape_tensor = Downcast(CopyTo(shape_obj, cpu_dev)); const DLTensor* dl_tensor = shape_tensor.operator->(); ICHECK_EQ(dl_tensor->dtype.code, 0u); - ICHECK_EQ(dl_tensor->dtype.bits, 64); + ICHECK_EQ(dl_tensor->dtype.bits, 64u); int64_t* dims = reinterpret_cast(dl_tensor->data); int64_t ndim = shape_tensor->shape[0]; std::vector shape(dims, dims + ndim); diff --git a/tests/cpp/relay/relay/transforms/device_domains_test.cc b/tests/cpp/relay/transforms/device_domains_test.cc similarity index 100% rename from tests/cpp/relay/relay/transforms/device_domains_test.cc rename to tests/cpp/relay/transforms/device_domains_test.cc diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py deleted file mode 100644 index c33bd5792242..000000000000 --- a/tests/python/relay/test_pass_annotation.py +++ /dev/null @@ -1,663 +0,0 @@ -# Licensed to the Apache 
Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Unit tests for heterogeneous compilation and execution.""" -import json -import numpy as np - -import tvm -from tvm import relay -from tvm.contrib import graph_executor -from tvm.relay.expr_functor import ExprMutator -from tvm.relay import transform -from tvm.ir.instrument import pass_instrument -import tvm.testing - - -@tvm.instrument.pass_instrument -class Trace: - def run_before_pass(self, module, pass_info): - if pass_info.name == "ManifestAlloc": - pass # import pdb; pdb.set_trace() - - def run_after_pass(self, module, pass_info): - if pass_info.name == "ManifestAlloc": - pass # import pdb; pdb.set_trace() - - -def check_graph_executor( - target, ref_res, device, func, params, config, opt_level, expected_index=None -): - with tvm.transform.PassContext(opt_level=opt_level, config=config): - graph_executor_factory = relay.build(func, target, params=params) - - contexts = [tvm.cpu(0), tvm.device(device)] - graph_json = json.loads(graph_executor_factory.graph_json) - if "device_index" in graph_json["attrs"]: - device_index = graph_json["attrs"]["device_index"][1] - assert device_index == expected_index - mod = graph_executor.GraphModule(graph_executor_factory["default"](*contexts)) - mod.run() - res = 
mod.get_output(0).numpy() - tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) - - -def check_vm_runtime(target, ref_res, device, func, params, config, opt_level, expected_index=None): - with tvm.transform.PassContext(opt_level=opt_level, instruments=[Trace()], config=config): - mod = tvm.IRModule() - mod["main"] = func - exe = relay.vm.compile(mod, target) - dev = [tvm.cpu(0), tvm.device(device)] - vm = tvm.runtime.vm.VirtualMachine(exe, dev) - res = vm.invoke("main", **params) - tvm.testing.assert_allclose(res.numpy(), ref_res, rtol=1e-5, atol=1e-5) - - -def run_opt_pass(expr, passes): - passes = passes if isinstance(passes, list) else [passes] - mod = tvm.IRModule.from_expr(expr) - seq = tvm.transform.Sequential(passes) - with tvm.transform.PassContext(opt_level=3): - mod = seq(mod) - return mod["main"] - - -def test_redundant_annotation(): - dev1 = tvm.device(1) - dev2 = tvm.device(2) - x = relay.var("x", shape=(3,)) - y = relay.var("y", shape=(3,)) - z = relay.var("z", shape=(3,)) - - def annotated(): - add = relay.add(x, y) - _add1 = relay.annotation.on_device(add, dev2) - _add2 = relay.annotation.on_device(add, dev2) - sub1 = relay.subtract(_add1, z) - sub2 = relay.subtract(_add2, z) - - func = relay.Function([x, y, z], relay.Tuple([sub1, sub2])) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - add = relay.add(x, y) - copy_add_sub1 = relay.device_copy(add, dev2, dev1) - sub1 = relay.subtract(copy_add_sub1, z) - copy_add_sub2 = relay.device_copy(add, dev2, dev1) - sub2 = relay.subtract(copy_add_sub2, z) - func = relay.Function([x, y, z], relay.Tuple([sub1, sub2])) - return func - - annotated_func = annotated() - expected_func = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(annotated_func, expected_func) - - -def test_annotate_expr(): - dev1 = tvm.device(1) - dev2 = tvm.device(2) - x = relay.var("x", shape=(3,)) - y = relay.var("y", shape=(3,)) - z = 
relay.var("z", shape=(3,)) - - def annotated(): - add = relay.add(x, y) - _add = relay.annotation.on_device(add, dev1) - sub = relay.subtract(_add, z) - _sub = relay.annotation.on_device(sub, dev2) - expr = run_opt_pass(_sub, transform.RewriteAnnotatedOps(dev1.device_type)) - return expr - - def expected(): - add = relay.add(x, y) - copy_add_sub = relay.device_copy(add, dev1, dev2) - sub = relay.subtract(copy_add_sub, z) - return sub - - annotated_expr = annotated() - expected_expr = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(annotated_expr, expected_expr) - - -def test_annotate_all(): - dev1 = tvm.device(1) - dev2 = tvm.device(2) - x = relay.var("x", shape=(3,)) - y = relay.var("y", shape=(3,)) - z = relay.var("z", shape=(3,)) - - def annotated(): - add = relay.add(x, y) - _add = relay.annotation.on_device(add, dev2) - sub = relay.subtract(_add, z) - _sub = relay.annotation.on_device(sub, dev2) - - func = relay.Function([x, y, z], _sub) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - add = relay.add(x, y) - sub = relay.subtract(add, z) - func = relay.Function([x, y, z], sub) - return func - - annotated_func = annotated() - expected_func = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(annotated_func, expected_func) - - -def test_annotate_none(): - dev1 = tvm.device(1) - dev2 = tvm.device(2) - x = relay.var("x", shape=(3,)) - y = relay.var("y", shape=(3,)) - z = relay.var("z", shape=(3,)) - - def annotated(): - add = relay.add(x, y) - sub = relay.subtract(add, z) - func = relay.Function([x, y, z], sub) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - add = relay.add(x, y) - sub = relay.subtract(add, z) - func = relay.Function([x, y, z], sub) - return func - - annotated_func = annotated() - expected_func = run_opt_pass(expected(), transform.InferType()) - assert 
tvm.ir.structural_equal(annotated_func, expected_func) - - -def check_annotated_graph(annotated_func, expected_func): - annotated_func = run_opt_pass(annotated_func, transform.InferType()) - expected_func = run_opt_pass(expected_func, transform.InferType()) - assert tvm.ir.structural_equal(annotated_func, expected_func) - - -def test_conv_network(): - r"""The network is as following: - data1 data2 - | | - conv2d conv2d - \ / - add - | - conv2d - """ - batch_size = 1 - dshape = (batch_size, 64, 56, 56) - weight = relay.var("weight", shape=(64, 64, 3, 3)) - data1 = relay.var("data1", shape=dshape) - data2 = relay.var("data2", shape=dshape) - dev1 = tvm.device(1) - dev2 = tvm.device(2) - - def original(): - conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - add = relay.add(conv2d_1, conv2d_2) - conv2d_3 = relay.nn.conv2d(add, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - - func = relay.Function([data1, data2, weight], conv2d_3) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(tvm.device(3).device_type)) - return func - - def annotated(): - conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - _conv2d_1 = relay.annotation.on_device(conv2d_1, dev2) - conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - _conv2d_2 = relay.annotation.on_device(conv2d_2, dev2) - add = relay.add(_conv2d_1, _conv2d_2) - _add = relay.annotation.on_device(add, dev1) - conv2d_3 = relay.nn.conv2d(_add, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - _conv2d_3 = relay.annotation.on_device(conv2d_3, dev2) - - func = relay.Function([data1, data2, weight], _conv2d_3) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(tvm.device(3).device_type)) - return func - - class ScheduleConv2d(ExprMutator): - def __init__(self, device): - self.device = device - 
super().__init__() - - def visit_call(self, expr): - visit = super().visit_call(expr) - if expr.op == tvm.relay.op.get("nn.conv2d"): - return relay.annotation.on_device(visit, self.device) - else: - return visit - - def annotate_with_visitor(func): - sched = ScheduleConv2d(dev2) - func = sched.visit(func) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - device_copy1 = relay.device_copy(conv2d_1, dev2, dev1) - conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - device_copy2 = relay.device_copy(conv2d_2, dev2, dev1) - add = relay.add(device_copy1, device_copy2) - device_copy3 = relay.device_copy(add, dev1, dev2) - conv2d_3 = relay.nn.conv2d( - device_copy3, weight, channels=64, kernel_size=(3, 3), padding=(1, 1) - ) - - func = relay.Function([data1, data2, weight], conv2d_3) - return func - - def check_storage_and_device_types(): - func = annotated() - func = run_opt_pass(func, [transform.RewriteAnnotatedOps(3), transform.FuseOps(2)]) - smap = relay.backend._backend.GraphPlanMemory(func) - storage_ids = [] - device_types = [] - for _, storage_info in smap.expr_to_storage_info.items(): - - for sid in storage_info.storage_ids: - storage_ids.append(sid.value) - - for did in storage_info.device_types: - device_types.append(did.value) - - assert len(storage_ids) == 10 - assert len(set(storage_ids)) == 8 - assert len(set(device_types)) == 2 - assert set(device_types) == {1, 2} - - def test_manual_annotation(): - annotated_func = annotated() - expected_func = expected() - check_annotated_graph(annotated_func, expected_func) - check_storage_and_device_types() - - def test_visitor_annotation(): - annotated_func = annotate_with_visitor(original()) - expected_func = expected() - check_annotated_graph(annotated_func, expected_func) - - test_manual_annotation() - 
test_visitor_annotation() - - -def test_propogation(): - R""" The network and device type is as following: - x 1 - | - log 1 - / \ - log2 log10 2 - \ / - add 2 - | - tan 1 - """ - dev1 = tvm.device(1) - dev2 = tvm.device(2) - - expected_dev_type = {"log": dev1, "log2": dev2, "log10": dev2, "add": dev2, "tan": dev1} - - x = relay.var("x", shape=(3,)) - - def annotated(): - log = relay.log(x) - _log = relay.annotation.on_device(log, expected_dev_type["log"]) - log2 = relay.log2(_log) - _log2 = relay.annotation.on_device(log2, expected_dev_type["log2"]) - log10 = relay.log10(_log) - _log10 = relay.annotation.on_device(log10, expected_dev_type["log10"]) - add = relay.add(_log2, _log10) - _add = relay.annotation.on_device(add, expected_dev_type["add"]) - tan = relay.tan(_add) - _tan = relay.annotation.on_device(tan, expected_dev_type["tan"]) - - func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - log = relay.log(x) - _log_left = relay.device_copy(log, dev1, dev2) - _log_right = relay.device_copy(log, dev1, dev2) - log2 = relay.log2(_log_left) - log10 = relay.log10(_log_right) - add = relay.add(log2, log10) - _add = relay.device_copy(add, dev2, dev1) - tan = relay.tan(_add) - - func = run_opt_pass(tan, transform.InferType()) - return func - - annotated_expr = annotated() - expected_expr = expected() - assert tvm.ir.structural_equal(annotated_expr, expected_expr) - - smap = relay.backend._backend.GraphPlanMemory(annotated_expr) - for expr, storage_info in smap.expr_to_storage_info.items(): - # x is dev1 as output is dev1 - if isinstance(expr, tvm.relay.expr.Var): - assert storage_info.device_types[0] == dev1.device_type - else: - # device_copy op should be its dst_dev_type - if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs): - assert storage_info.device_types[0] == expr.attrs.dst_dev_type - else: - assert storage_info.device_types[0] == expected_dev_type[expr.op.name].device_type - - -def 
run_fusible_network(dev, tgt): - R""" The network is as following: - x y - \ / - add - / \ - sqrt log - \ / - subtract - | - exp - """ - x = relay.var("x", shape=(1, 10)) - y = relay.var("y", shape=(10, 10)) - x_data = np.random.rand(1, 10).astype("float32") - y_data = np.random.rand(10, 10).astype("float32") - tmp_add = x_data + y_data - tmp_sqrt = np.sqrt(tmp_add) - tmp_log = np.log(tmp_add) - tmp_sub = np.subtract(tmp_sqrt, tmp_log) - ref_res = np.exp(tmp_sub) - params = {"x": x_data, "y": y_data} - - def get_func(): - add = relay.add(x, y) - sqrt = relay.sqrt(add) - log = relay.log(add) - subtract = relay.subtract(sqrt, log) - exp = relay.exp(subtract) - - func = relay.Function([x, y], exp) - return func - - def test_fuse_log_add(device, tgt): - """Only log and add are fused.""" - fallback_device = tvm.device("cpu") - target = {"cpu": "llvm", device: tgt} - cpu_dev = fallback_device - dev_dev = tvm.device(device) - - def annotated(): - add = relay.add(x, y) - sqrt = relay.sqrt(add) - _sqrt = relay.annotation.on_device(sqrt, dev_dev) - log = relay.log(add) - subtract = relay.subtract(_sqrt, log) - exp = relay.exp(subtract) - _exp = relay.annotation.on_device(exp, dev_dev) - - func = relay.Function([x, y], _exp) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(cpu_dev.device_type)) - return func - - def expected(): - add = relay.add(x, y) - copy_add_sqrt = relay.device_copy(add, cpu_dev, dev_dev) - sqrt = relay.sqrt(copy_add_sqrt) - log = relay.log(add) - copy_sqrt_subtract = relay.device_copy(sqrt, dev_dev, cpu_dev) - subtract = relay.subtract(copy_sqrt_subtract, log) - copy_sub_exp = relay.device_copy(subtract, cpu_dev, dev_dev) - exp = relay.exp(copy_sub_exp) - - func = relay.Function([x, y], exp) - return func - - annotated_func = annotated() - expected_func = expected() - dev = tvm.device(device, 0) - dev_idx = dev.device_type - expected_index = [1, 1, 1, dev_idx, dev_idx, 1, 1, dev_idx, dev_idx] - check_annotated_graph(annotated_func, 
expected_func) - opt_level = 1 - config = {"relay.fallback_device_type": fallback_device.device_type} - check_graph_executor( - target, ref_res, device, annotated_func, params, config, opt_level, expected_index - ) - opt_level = 2 - check_vm_runtime( - target, ref_res, device, annotated_func, params, config, opt_level, expected_index - ) - - def test_fuse_all(device, tgt): - """Fuse all operators.""" - fallback_device = tvm.device("cpu") - target = {"cpu": "llvm", device: tgt} - cpu_dev = fallback_device - dev_dev = tvm.device(device) - - def annotated(): - add = relay.add(x, y) - _add = relay.annotation.on_device(add, dev_dev) - sqrt = relay.sqrt(_add) - _sqrt = relay.annotation.on_device(sqrt, dev_dev) - log = relay.log(_add) - _log = relay.annotation.on_device(log, dev_dev) - subtract = relay.subtract(_sqrt, _log) - _subtract = relay.annotation.on_device(subtract, dev_dev) - exp = relay.exp(_subtract) - _exp = relay.annotation.on_device(exp, dev_dev) - - func = relay.Function([x, y], _exp) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(cpu_dev.device_type)) - return func - - annotated_func = annotated() - expected_func = get_func() - check_annotated_graph(annotated_func, expected_func) - opt_level = 1 - config = {"relay.fallback_device_type": fallback_device.device_type} - check_graph_executor(target, ref_res, device, annotated_func, params, config, opt_level) - opt_level = 2 - check_vm_runtime(target, ref_res, device, annotated_func, params, config, opt_level) - - def test_fallback_exp(device, tgt): - fallback_device = tvm.device("cpu") - target = {"cpu": "llvm", device: tgt} - cpu_dev = fallback_device - dev_dev = tvm.device(device) - - def annotated(): - add = relay.add(x, y) - sqrt = relay.sqrt(add) - log = relay.log(add) - subtract = relay.subtract(sqrt, log) - exp = relay.exp(subtract) - _exp = relay.annotation.on_device(exp, cpu_dev) - - func = relay.Function([x, y], _exp) - func = run_opt_pass(func, 
transform.RewriteAnnotatedOps(dev_dev.device_type)) - return func - - def expected(): - add = relay.add(x, y) - sqrt = relay.sqrt(add) - log = relay.log(add) - subtract = relay.subtract(sqrt, log) - copy_sub_exp = relay.device_copy(subtract, dev_dev, cpu_dev) - exp = relay.exp(copy_sub_exp) - - func = relay.Function([x, y], exp) - return func - - annotated_func = annotated() - expected_func = expected() - dev = tvm.device(device, 0) - dev_idx = dev.device_type - expected_index = [dev_idx, dev_idx, dev_idx, 1, 1] - opt_level = 1 - config = {"relay.fallback_device_type": fallback_device.device_type} - check_annotated_graph(annotated_func, expected_func) - check_graph_executor( - target, ref_res, device, annotated_func, params, config, opt_level, expected_index - ) - opt_level = 2 - check_vm_runtime( - target, ref_res, device, annotated_func, params, config, opt_level, expected_index - ) - - def test_fallback_all_operators(device, tgt): - target = {device: tgt, "cpu": "llvm"} - annotated_func = get_func() - expected_func = get_func() - check_annotated_graph(annotated_func, expected_func) - opt_level = 2 - check_graph_executor(target, ref_res, device, annotated_func, params, {}, opt_level) - check_vm_runtime(target, ref_res, device, annotated_func, params, {}, opt_level) - - test_fuse_log_add(dev, tgt) - test_fuse_all(dev, tgt) - test_fallback_exp(dev, tgt) - test_fallback_all_operators(dev, tgt) - - -def run_unpropagatable_graph(dev, tgt): - r"""The network is as following: - a b c d - \ / \ / - add mul - \ / - subtract - """ - - a = relay.var("a", shape=(10, 10)) - b = relay.var("b", shape=(10, 10)) - c = relay.var("c", shape=(10, 10)) - d = relay.var("d", shape=(10, 10)) - a_data = np.random.rand(10, 10).astype("float32") - b_data = np.random.rand(10, 10).astype("float32") - c_data = np.random.rand(10, 10).astype("float32") - d_data = np.random.rand(10, 10).astype("float32") - tmp_add = a_data + b_data - tmp_mul = np.multiply(c_data, d_data) - ref_res = 
np.subtract(tmp_add, tmp_mul) - - fallback_device = tvm.device("cpu") - target = {"cpu": "llvm", dev: tgt} - cpu_dev = fallback_device - dev_dev = tvm.device(dev) - - def annotated(): - add = relay.add(a, b) - _add = relay.annotation.on_device(add, dev_dev) - mul = relay.multiply(c, d) - _mul = relay.annotation.on_device(mul, cpu_dev) - sub = relay.subtract(_add, _mul) - _sub = relay.annotation.on_device(sub, dev_dev) - func = relay.Function([a, b, c, d], _sub) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev_dev.device_type)) - return func - - def expected(): - add = relay.add(a, b) - mul = relay.multiply(c, d) - copy_mul_sub = relay.device_copy(mul, cpu_dev, dev_dev) - sub = relay.subtract(add, copy_mul_sub) - func = relay.Function([a, b, c, d], sub) - return func - - annotated_func = annotated() - expected_func = expected() - expected_index = [2, 2, 2, 1, 1, 1, 2, 2] - check_annotated_graph(annotated_func, expected_func) - params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data} - opt_level = 0 - config = {"relay.fallback_device_type": fallback_device.device_type} - - check_graph_executor( - target, ref_res, dev, annotated_func, params, config, opt_level, expected_index - ) - - opt_level = 2 - check_vm_runtime(target, ref_res, dev, annotated_func, params, config, opt_level) - - -@tvm.testing.requires_opencl -def test_check_run_opencl(): - dev = "opencl" - tgt = "opencl" - run_fusible_network(dev, tgt) - run_unpropagatable_graph(dev, tgt) - - -@tvm.testing.requires_opencl -def test_check_run_opencl_intel(): - dev = "opencl" - tgt = str(tvm.target.intel_graphics()) - run_fusible_network(dev, tgt) - run_unpropagatable_graph(dev, tgt) - - -@tvm.testing.requires_cuda -def test_check_run_cuda(): - dev = "cuda" - tgt = "cuda" - run_fusible_network(dev, tgt) - run_unpropagatable_graph(dev, tgt) - - -@tvm.testing.requires_cuda -def test_tuple_get_item(): - dev = "cuda" - cpu_dev = tvm.cpu(0) - gpu_dev = tvm.device(dev) - - def expected(): - x = 
relay.var("x", relay.ty.TensorType((3, 3, 4), "float32")) - split = relay.op.split(x, 3) - elem0 = relay.device_copy(split[0], gpu_dev, cpu_dev) - elem1 = relay.device_copy(split[1], gpu_dev, cpu_dev) - sub = elem0 - elem1 - func = relay.Function(relay.analysis.free_vars(sub), sub) - return func - - def annotated(): - x = relay.var("x", relay.ty.TensorType((3, 3, 4), "float32")) - split = relay.op.split(x, 3) - split = split.astuple() - split = relay.annotation.on_device(split, gpu_dev) - split = relay.TupleWrapper(split, 3) - sub = split[0] - split[1] - func = relay.Function(relay.analysis.free_vars(sub), sub) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(cpu_dev.device_type)) - return func - - annotated_func = annotated() - expected_func = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(annotated_func, expected_func) - - -if __name__ == "__main__": - test_redundant_annotation() - test_annotate_expr() - test_annotate_all() - test_annotate_none() - test_conv_network() - test_tuple_get_item() diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py deleted file mode 100644 index fe19c479292f..000000000000 --- a/tests/python/relay/test_pass_context_analysis.py +++ /dev/null @@ -1,205 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks - -import numpy as np -import pytest - -import tvm -from tvm import relay -from tvm.relay import expr as _expr -from tvm.relay.analysis import context_analysis - - -def test_device_copy(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - x = relay.var("x", shape=(2, 3)) - copy = relay.op.device_copy(x, tvm.cpu(), tvm.cuda()) - out = copy + relay.const(np.random.rand(2, 3)) - glb_var = relay.GlobalVar("main") - mod[glb_var] = relay.Function([x], out) - ca = context_analysis(mod, tvm.cpu()) - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - for expr, dev in ca.items(): - if isinstance(expr, _expr.Call): - assert dev[0].value == gpu_dev - elif isinstance(expr, _expr.Var): - assert dev[0].value == cpu_dev - elif isinstance(expr, _expr.Constant): - assert dev[0].value == gpu_dev - - -def test_shape_func(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - data_shape = (relay.Any(),) - x = relay.var("x", shape=data_shape) - y = relay.op.vm.shape_of(x) - z = relay.nn.relu(y) - p0 = relay.var("p0", shape=data_shape) - fn = relay.Function([p0], z) - out = relay.var("out", shape=(1,), dtype="int64") - ins = relay.Tuple([y]) - outs = relay.Tuple([out]) - is_inputs = [False] - shape_func = relay.op.vm.shape_func(fn, ins, outs, is_inputs) - mod["main"] = relay.Function([x, out], shape_func) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - # The output of shape func should be on cpu. 
- assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev - # shape func is the body and it should be on cpu - assert main.body in ca and ca[main.body][0].value == cpu_dev - - -def test_vm_shape_of(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - data_shape = (relay.Any(),) - x = relay.var("x", shape=data_shape) - y = relay.op.vm.shape_of(x) - mod["main"] = relay.Function([x], y) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - assert main.body in ca and ca[main.body][0].value == cpu_dev - - -def test_alloc_storage(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - mod.import_from_std("core.rly") - size = relay.Var("size", relay.scalar_type("int64")) - alignment = relay.Var("alignment", relay.scalar_type("int64")) - # allocate a chunk on of memory on gpu. 
- sto = relay.op.memory.alloc_storage(size, alignment, tvm.cuda()) - mod["main"] = relay.Function([size, alignment], sto) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - body = main.body - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - # Inputs are unified with alloc storage inputs which are on cpu - assert main.params[0] in ca and ca[main.params[0]][0].value == cpu_dev - assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev - - assert isinstance(body, relay.Call) and len(body.args) == 2 - # size of alloc_storage is on cpu - assert body.args[0] in ca and ca[body.args[0]][0].value == cpu_dev - # alignment of alloc_storage is on cpu - assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev - # alloc_storage is on gpu as specified - assert body in ca and ca[body][0].value == gpu_dev - - -def test_alloc_tensor(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - mod.import_from_std("core.rly") - sto_type = relay.TypeCall(mod.get_global_type_var("Storage"), []) - sto = relay.Var("x", sto_type) - sh = relay.const(np.array([3, 2]), dtype="int64") - at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh) - mod["main"] = relay.Function([sto], at) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - body = main.body - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - # Input of the function falls back to the default device gpu - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - - assert isinstance(body, relay.Call) and len(body.args) == 3 - # storage of alloc_tensor falls back to the default device gpu - assert body.args[0] in ca and ca[body.args[0]][0].value == gpu_dev - # shape of alloc_tensor is on cpu - assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev - # alloc_tensor keeps the same device context as storage which is is on gpu - assert body in ca and 
ca[body][0].value == gpu_dev - - -def test_vm_reshape_tensor(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - x = relay.var("x", shape=(2, 8), dtype="float32") - shape = relay.const([-1, 4, 2], dtype="int64") - y = relay.op.vm.reshape_tensor(x, shape, [2, 4, 2]) - mod = tvm.IRModule() - mod["main"] = relay.Function([x], y) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - body = main.body - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - # Input of the function falls back to the default device gpu - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - - # dats of reshape_tensor falls back to the default device gpu - assert body.args[0] in ca and ca[body.args[0]][0].value == gpu_dev - # shape of reshape_tensor is on cpu - assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev - # reshape_tensor sits on the same device as the data - assert body in ca and ca[body][0].value == gpu_dev - - -def test_dynamic_input(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - data_shape = (relay.Any(), relay.Any()) - x0 = relay.var("x0", shape=data_shape) - x1 = relay.var("x1", shape=data_shape) - mod["main"] = relay.Function([x0, x1], x0 + x1) - - compiler = relay.vm.VMCompiler() - mod, _ = compiler.optimize(mod, target="cuda") - ca = context_analysis(mod, tvm.cpu()) - main = mod["main"] - - gpu_dev = tvm.cuda().device_type - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - assert main.params[1] in ca and ca[main.params[1]][0].value == gpu_dev - assert main.body in ca and ca[main.body][0].value == gpu_dev - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 2252d8a235c9..e3218ab1a829 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ 
b/tests/python/relay/test_pass_plan_devices.py @@ -80,9 +80,8 @@ def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, a # Idempotence rewrite_and_assert(expected_mod, expected_mod) # The VM can compile and possibly even run the module - # TODO(mbs): Disabled until VM supports new device planning. - # if not (reference_func is None) and not (args is None): - # eval_and_assert(in_mod, reference_func, args) + if not (reference_func is None) and not (args is None): + eval_and_assert(in_mod, reference_func, args) def test_plain(): @@ -380,10 +379,9 @@ def expected(): def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], param_device_types=[1, 1, 2, 2], result_device_type=2) { - %0 = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + let %f = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { add(%x, %y) }; - let %f = on_device(%0, device_type=1, is_fixed=True); %1 = %f(%a, %b); %2 = on_device(%1, device_type=1, is_fixed=True); %3 = device_copy(%2, src_dev_type=1, dst_dev_type=2); @@ -564,10 +562,9 @@ def expected(): #[version = "0.0.5"] def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], param_device_types=[2, 1], result_device_type=1) { - %0 = fn (%y: Tensor[(?), float32], param_device_types=[2], result_device_type=2) { + let %p = fn (%y: Tensor[(?), float32], param_device_types=[2], result_device_type=2) { nn.relu(%y) }; - let %p = on_device(%0, device_type=2, is_fixed=True); %1 = vm.shape_of(%x, dtype="int64"); %2 = (%1,); %3 = (%s,); @@ -950,17 +947,17 @@ def ref(x): def test_propogation(): r""" The network and devices are as follows: - x <--- CPU + x <--- CPU | - log <--- CPU + negative <--- CPU / \ - log2 log10 <--- GPU + negative negative <--- GPU \ / - add <--- GPU + add <--- GPU | - tan <--- CPU + negative <--- CPU | - <--- CPU + <--- CPU """ def input(): @@ -968,16 +965,16 @@ def input(): """ #[version = "0.0.5"] def @main(%x: 
Tensor[(5, 7), float32]) { - %0 = log(%x); + %0 = negative(%x); %1 = on_device(%0, device_type=1); - %2 = log2(%1); + %2 = negative(%1); %3 = on_device(%0, device_type=1); - %4 = log10(%3); + %4 = negative(%3); %5 = on_device(%2, device_type=2); %6 = on_device(%4, device_type=2); %7 = add(%5, %6); %8 = on_device(%7, device_type=2); - %9 = tan(%8); + %9 = negative(%8); on_device(%9, device_type=1) } """ @@ -988,24 +985,24 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=1) { - %0 = log(%x); + %0 = negative(%x); %1 = on_device(%0, device_type=1, is_fixed=True); %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); %3 = on_device(%0, device_type=1, is_fixed=True); %4 = device_copy(%3, src_dev_type=1, dst_dev_type=2); - %5 = log2(%2); - %6 = log10(%4); + %5 = negative(%2); + %6 = negative(%4); %7 = add(%5, %6); %8 = on_device(%7, device_type=2, is_fixed=True); %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); - tan(%9) + negative(%9) } """ ) def ref(x): - y = np.log(x) - return np.tan(np.add(np.log2(y), np.log10(y))) + y = np.negative(x) + return np.negative(np.add(np.negative(y), np.negative(y))) exercise(input(), expected(), ref, rands((5, 7), 1)) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 725d2765477f..d39390dab7b7 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -959,4 +959,6 @@ def test_benchmark_end_to_end_rpc(): if __name__ == "__main__": + import sys + sys.exit(pytest.main(sys.argv))