From 8d517c48ddf5d5532d8c53af87a1002fdb56477b Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Thu, 11 Aug 2022 19:55:57 -0500 Subject: [PATCH] [Pass][UX] Statement rewriter for DataflowBlock (#210) - Implements a few APIs to quickly perform statement-level mutation: `add`/`remove_unused`/`remove_all_unused`/`replace_all_uses`. - Implemented `remove_all_unused` to remove dead statements inside `DataflowBlock` cc: @psrivas2 - Address minor issues (unnecessary headers and bad docstrings) in https://github.com/tlc-pack/relax/pull/163 --- include/tvm/relax/analysis.h | 40 ++- include/tvm/relax/binding_rewrite.h | 115 +++++++ include/tvm/relax/utils.h | 36 ++- python/tvm/relax/analysis/analysis.py | 23 +- python/tvm/relax/binding_rewrite.py | 153 ++++++++++ src/relax/analysis/udchain.cc | 49 +-- src/relax/analysis/var2value.cc | 27 ++ src/relax/ir/binding_rewrite.cc | 326 ++++++++++++++++++++ src/relax/ir/dataflow_pattern.cc | 2 - tests/python/relax/test_analysis.py | 121 +++++++- tests/python/relax/test_binding_rewrite.py | 333 +++++++++++++++++++++ 11 files changed, 1196 insertions(+), 29 deletions(-) create mode 100644 include/tvm/relax/binding_rewrite.h create mode 100644 python/tvm/relax/binding_rewrite.py create mode 100644 src/relax/ir/binding_rewrite.cc create mode 100644 tests/python/relax/test_binding_rewrite.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 5c89b80bfc..f25b13a8ff 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -30,6 +30,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -111,26 +113,34 @@ TVM_DLL tvm::Array AllGlobalVars(const Expr& expr); /*! * \brief Analyze var -> value mapping from VarBindings. * - * \param m the IRModule to check. + * \param m The IRModule to check. * \return Var -> Value (Expr) */ -TVM_DLL runtime::Map AnalyzeVar2Value(const IRModule& m); +TVM_DLL Map AnalyzeVar2Value(const IRModule& m); /*! * \brief Analyze var -> value mapping from VarBindings. * - * \param expr the expression to check. + * \param expr The expression to check. * \return Var -> Value (Expr) */ -TVM_DLL runtime::Map AnalyzeVar2Value(const Expr& expr); +TVM_DLL Map AnalyzeVar2Value(const Expr& expr); /*! * \brief Analyze var -> value mapping from VarBindings. * - * \param dfb the dataflow block to check. + * \param dfb The dataflow block to check. * \return Var -> Value (Expr) */ -TVM_DLL runtime::Map AnalyzeVar2Value(const DataflowBlock& dfb); +TVM_DLL Map AnalyzeVar2Value(const DataflowBlock& dfb); + +/*! + * \brief Return a mapping from variable name to its Bindings. + * + * \param fn The function to be analyzed. + * \return A mapping from variable name to its Bindings. + */ +TVM_DLL Map> NameToBinding(const Function& fn); /*! * \brief Get the use-def chain of variables inside a dataflow block. @@ -138,7 +148,23 @@ TVM_DLL runtime::Map AnalyzeVar2Value(const DataflowBlock& dfb); * \param dfb The dataflow block to be analyzed. * \return A map mapping variable definitoins to a set of uses. */ -TVM_DLL runtime::Map> UseDefChain(const DataflowBlock& dfb); +TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); + +/*! + * \brief Get the use-def chain of variables inside a function. + * + * \param fn The function to be analyzed. + * \return A map from variable definitoins to a set of uses and variables needed by return value. + */ +std::pair>, Array> FunctionUseDef(const Function& fn); + +/*! + * \brief Remove unused statements inside DataflowBlocks. + * + * \param fn The function to remove unused statements. + * \return The function that contains no unused statements in DataflowBlock. + */ +TVM_DLL Function RemoveAllUnused(const Function fn); } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h new file mode 100644 index 0000000000..7cee0bf1d5 --- /dev/null +++ b/include/tvm/relax/binding_rewrite.h @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/binding_rewrite.h + * \brief An IR rewriter to easily add/remove/replace bindings (statements). + */ + +#ifndef TVM_RELAX_BINDING_REWRITE_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Statement rewriter for relax.DataflowBlock. */ +class DataflowBlockRewriteNode : public Object { + public: + /*! \brief Replace all uses of old_var with new_var. */ + void ReplaceAllUses(Var old_var, Var new_var); + /*! \brief Insert a Binding statement. */ + void Add(Binding binding); + /*! \brief Insert an expression as VarBinding with variable name. */ + void Add(String var_name, Expr expr, bool is_dfvar = false) { + auto var = is_dfvar ? DataflowVar(var_name, expr->shape(), expr->checked_type()) + : Var(var_name, expr->shape(), expr->checked_type()); + Add(VarBinding(std::move(var), std::move(expr))); + } + /*! \brief Insert an expression as VarBinding with automatic variable name. */ + void Add(Expr expr, bool is_dfvar = false) { + Add(name_table_.GetUniqueName("tmp"), expr, is_dfvar); + } + /*! \brief Remove the definition statement of an unused variable. */ + void RemoveUnused(Var unused, bool allow_undef = false); + /*! \brief Remove the definition statements of all unused variables. */ + void RemoveAllUnused(); + + /*! \brief The rewritten dataflow block. */ + DataflowBlock MutatedDataflowBlock() { return dfb_.value(); } + /*! \brief The rewritten function. */ + Function MutatedFunc() { return root_fn_.value(); } + /*! \brief The rewritten IRModule. */ + IRModule MutateIRModule(IRModule irmod); + + /*! \brief Visit attributes. */ + void VisitAttrs(AttrVisitor* v) { + v->Visit("dfb", &dfb_); + v->Visit("root_fn", &root_fn_); + } + + static constexpr const char* _type_key = "relax.DataflowBlockRewrite"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object); + + protected: + friend class DataflowBlockRewrite; + + Optional dfb_; //!< The rewritten dataflow block. + Optional root_fn_; //!< The rewritten function. + const FunctionNode* original_fn_ptr_; //!< Pointer to the original function. + Map> to_users_; //!< Map from variable to its users. + Array fn_outputs_; //!< Variables required by function outputs. + + private: + NameTable name_table_; //!< Name table for tracking and generating unique names. +}; + +/*! + * \brief A statement rewriter for relax.DataflowBlock. + * \sa DataflowBlockRewriteNode + */ +class DataflowBlockRewrite : public ObjectRef { + public: + TVM_DLL explicit DataflowBlockRewrite(DataflowBlock dfb, Function root_fn); + + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + DataflowBlockRewriteNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockRewrite, ObjectRef, DataflowBlockRewriteNode); +}; + +} // namespace relax +} // namespace tvm + +#define TVM_RELAX_BINDING_REWRITE_H_ +#endif // TVM_RELAX_BINDING_REWRITE_H_ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 9c6beb2697..b851aabfe2 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -24,8 +24,12 @@ #ifndef TVM_RELAX_UTILS_H_ #define TVM_RELAX_UTILS_H_ +#include + #include +#include #include +#include #include namespace tvm { @@ -53,8 +57,38 @@ class NameTable { return unique_prefix; } + NameTable() = default; + + template + explicit NameTable(Iter begin, Iter end, Lambda f) { + // static_assert is more reader-friendly than SFINAE when template specialization is not needed. + static_assert(std::is_convertible::value, + "Lambda f must has a signature of [?](*it) -> string {}"); + for (auto it = begin; it != end; ++it) { + const std::string& name = f(*it); + const size_t idx_last_first_num = std::distance( + std::find_if(name.rbegin(), name.rend(), [](char c) { return !std::isdigit(c); }), + name.rend()); + // name = {O = others}{D = consecutive digits} + // let O -> prefix; + std::string prefix = name.substr(0, idx_last_first_num); + ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) << "Invalid variable name: " << name; + if (0 == alloc_map_.count(prefix)) alloc_map_[prefix] = 0; + if (idx_last_first_num < name.size()) { // has some digits. + // let D's nearest natural number -> idx; + // note: stoul("000123") = 123; + alloc_map_[prefix] = + std::max(alloc_map_[prefix], std::stoi(name.substr(idx_last_first_num))); + } + } + } + + template + explicit NameTable(Iter begin, Iter end) + : NameTable(begin, end, [](const decltype(*begin)& v) { return v; }) {} + private: - std::unordered_map alloc_map_; + std::unordered_map alloc_map_; }; /*! diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 65efee8987..3f6871e818 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -24,7 +24,7 @@ from typing import Dict, List import tvm -from tvm.relax.expr import DataflowBlock, Var, Expr, Function +from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Binding from . import _ffi_api @@ -92,3 +92,24 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]: A mapping from variable definition to its uses. """ return _ffi_api.udchain(dfb) + + +def name_to_binding(func: Function) -> Dict[str, List[Binding]]: + """Return a map from variable name to its bindings.""" + return _ffi_api.name_to_binding(func) + + +def remove_all_unused(func: Function) -> Function: + """Remove all unused variables from the function. + + Parameters + ---------- + func : Function + The input function to be analyzed. + + Returns + ------- + Function + The function with unused variables removed. + """ + return _ffi_api.remove_all_unused(func) diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py new file mode 100644 index 0000000000..3f0c8d5cb3 --- /dev/null +++ b/python/tvm/relax/binding_rewrite.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name +"""Developer API of add/remove/replace bindings in Relax.""" + +from typing import Optional + +import tvm +import tvm._ffi +from tvm.runtime import Object +from . import Binding, DataflowBlock, Expr, Function, Var +from . import _ffi_api + + +@tvm._ffi.register_object("relax.DataflowBlockRewrite") +class DataflowBlockRewrite(Object): + """ + A binding/statement-level dataflow block rewriter. + + Notes + ----- + Due to the immutable and copy-on-write nature of TVM AST nodes, the rewriting is not done in + place. Instead, a new DataflowBlock is created and returned with mutated_dfb. Similarly, its new + root Function is created and returned by mutated_root_fn. To apply this change for an IRModule, + use mutate_irmodule which rewrites the old function that registered in the constructor. + """ + + def __init__(self, dfb: DataflowBlock, root_fn: Function): + """ + Construct a rewriter with the DataflowBlock to rewrite and its root function. + + Parameters + ---------- + dfb : DataflowBlock + The DataflowBlock to rewrite. + root_fn : Function + The root function of the DataflowBlock. + """ + self.func_name = root_fn.__name__ if hasattr(root_fn, "__name__") else None + self.__init_handle_by_constructor__(_ffi_api.DataflowBlockRewrite, dfb, root_fn) + + def replace_all_uses(self, old_var: Var, new_var: Var) -> None: + """ + Replace all uses of old_var with new_var. + + Parameters + ---------- + old_var : Var + The old variable to replace. + new_var : Var + The new variable to replace with. + """ + _ffi_api.dfb_rewrite_replace_all_uses(self, old_var, new_var) + + def add_binding(self, binding: Binding) -> None: + return _ffi_api.dfb_rewrite_add_binding(self, binding) + + def add(self, expr: Expr, name: Optional[str] = None, is_dfvar: bool = False) -> None: + """ + Add a new statement to the DataflowBlock with an automatically generated variable name. + + Parameters + ---------- + expr : Expr + The expression to add. + name : Optional[str], optional + Variable name, by default None + is_dfvar : bool, optional + The variable type, by default False + + Notes + ----- + If the variable name is not given, it will be automatically generated in a form of + "tmp${COUNTER}". The variable type will be DataflowVar if is_dfvar is True, otherwise + it will be Var. Being Var means the variables are output variables of the DataflowBlock. + While being DataflowVar means the variables are internal variables of the DataflowBlock. + """ + _ffi_api.dfb_rewrite_add(self, expr, name, is_dfvar) + + def remove_unused(self, var: Var, allow_undef=False) -> None: + """ + Remove a statement by its variable definition if and only if it is unused. + + Parameters + ---------- + var : Var + The unused variable definition. + allow_undef : bool, optional + Whether to allow var being undefined variable, by default False + + Raises + ------ + TVMError if the variable is used or undefined (allow_undef=False). + """ + _ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef) + + def remove_all_unused(self) -> None: + """ + Remove all unused variables. + + Notes + ----- + This could remove unused variables in other DataflowBlocks as well. + """ + _ffi_api.dfb_rewrite_remove_all_unused(self) + + def mutated_dfb(self) -> DataflowBlock: + """ + Returns the mutated DataflowBlock. + """ + return self.dfb + + def mutated_root_fn(self) -> Function: + """ + Returns the mutated root function. + """ + ret = self.root_fn + if self.func_name: + ret.__name__ = self.func_name + return ret + + def mutate_irmodule(self, irmodule: tvm.IRModule) -> tvm.IRModule: + """ + Return an updated IRModule by replacing the old function with the mutated root function. + + Parameters + ---------- + irmodule : tvm.IRModule + The base IRModule to update. + + Returns + ------- + tvm.IRModule + The updated IRModule. + """ + ret = _ffi_api.dfb_rewrite_mutate_irmodule(self, irmodule) + if hasattr(irmodule, "__name__"): + ret.__name__ = irmodule.__name__ + return ret diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 651d79b953..f3d9b4686b 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -23,19 +23,12 @@ */ #include -#include -#include #include #include -#include -#include #include #include -#include #include -#include -#include #include #include @@ -44,7 +37,8 @@ namespace relax { class UDChain : public relax::ExprVisitor { public: - std::map> def2use; + // nullptr users means it is the output of the function. + std::map> to_users; const VarNode* cur_user_; @@ -56,23 +50,44 @@ class UDChain : public relax::ExprVisitor { cur_user_ = nullptr; } - void VisitExpr_(const VarNode* op) override { - if (nullptr == cur_user_) return; - - def2use[op].insert(cur_user_); - } - void VisitVarDef(const Var& var) override { def2use[var.get()] = {}; } + void VisitExpr_(const VarNode* op) override { to_users[op].insert(cur_user_); } + void VisitVarDef(const Var& var) override { to_users[var.get()] = {}; } + void VisitExpr_(const FunctionNode* op) override { ExprVisitor::VisitExpr_(op); } void VisitExpr_(const DataflowVarNode* op) override { VisitExpr_(static_cast(op)); } }; -runtime::Map> UseDefChain(const DataflowBlock& dfb) { +std::pair>, runtime::Array> FunctionUseDef( + const Function& fn) { + UDChain udchain; + udchain.VisitExpr_(fn.get()); + + Map> user_map; + Array fn_outs; + + for (const auto& kv : udchain.to_users) { + Array uses{}; + uses.reserve(kv.second.size()); + for (const auto& v : kv.second) { + if (nullptr == v && + fn_outs.end() == std::find(fn_outs.begin(), fn_outs.end(), GetRef(kv.first))) { + fn_outs.push_back(GetRef(kv.first)); + } else { + uses.push_back(GetRef(v)); + } + } + user_map.Set(GetRef(kv.first), std::move(uses)); + } + return std::make_pair(std::move(user_map), std::move(fn_outs)); +} + +runtime::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { UDChain udchain; udchain.VisitBindingBlock_(dfb.get()); runtime::Map> ret; - for (const auto& kv : udchain.def2use) { + for (const auto& kv : udchain.to_users) { Array uses{}; uses.reserve(kv.second.size()); for (const auto& v : kv.second) uses.push_back(GetRef(v)); @@ -81,7 +96,7 @@ runtime::Map> UseDefChain(const DataflowBlock& dfb) { return ret; } -TVM_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(UseDefChain); +TVM_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(DataflowBlockUseDef); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc index e1693ec4ed..680a2a7261 100644 --- a/src/relax/analysis/var2value.cc +++ b/src/relax/analysis/var2value.cc @@ -17,6 +17,8 @@ * under the License. */ +#include +#include #include namespace tvm { @@ -58,5 +60,30 @@ TVM_REGISTER_GLOBAL(("relax.analysis.get_var2val")).set_body_typed([](const Func return AnalyzeVar2Value(f); }); +class Name2BindingAnalysis : public relax::ExprVisitor { + public: + // runtime::Map is not suitable for doing in-place update. + // so we use standard container for internal usage. + std::map> name2bindings_; + void VisitBinding_(const VarBindingNode* binding) override { + const auto& vname = binding->var->name_hint(); + name2bindings_[vname].push_back(GetRef(binding)); + } + + void VisitBinding_(const MatchShapeNode* binding) override { + const auto& vname = binding->var->name_hint(); + name2bindings_[vname].push_back(GetRef(binding)); + } +}; + +Map> NameToBinding(const Function& fn) { + Name2BindingAnalysis analysis{}; + analysis.VisitExpr_(fn.get()); + return Map>(std::make_move_iterator(analysis.name2bindings_.begin()), + std::make_move_iterator(analysis.name2bindings_.end())); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.name_to_binding")).set_body_typed(NameToBinding); + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc new file mode 100644 index 0000000000..a6d6b27d81 --- /dev/null +++ b/src/relax/ir/binding_rewrite.cc @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/binding_rewrite.cc + * \brief Implementation of binding rewriters. + */ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(DataflowBlockRewriteNode); +DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) { + auto n = make_object(); + n->dfb_ = dfb; + n->root_fn_ = root_fn; + n->original_fn_ptr_ = root_fn.get(); + auto p = FunctionUseDef(root_fn); + n->to_users_ = std::move(p.first); + n->fn_outputs_ = std::move(p.second); + n->name_table_ = NameTable(n->to_users_.begin(), n->to_users_.end(), + [](const auto& p) { return p.first->name_hint(); }); + + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlockRewrite") + .set_body_typed([](DataflowBlock dfb, Function root_fn) { + return DataflowBlockRewrite(dfb, root_fn); + }); + +void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { + class ReplaceAllUsePass : public ExprMutator { + Var old_var, new_var; + const DataflowBlockNode* const to_catch; + + public: + const DataflowBlockNode* caught = nullptr; + + ReplaceAllUsePass(Var old_var, Var new_var, const DataflowBlockNode* to_catch) + : old_var(old_var), new_var(new_var), to_catch(to_catch) {} + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const VarNode* op) override { + return (op == old_var.get()) ? new_var : GetRef(op); + } + + Expr VisitExpr_(const DataflowVarNode* op) override { + return (op == old_var.get()) ? new_var : GetRef(op); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + BindingBlock res = ExprMutator::VisitBindingBlock_(op); + if (op == to_catch) caught = static_cast(res.get()); + return res; + } + }; + + ICHECK(to_users_.find(old_var) != to_users_.end()) << "Cannot find " << old_var; + ICHECK(to_users_.find(new_var) != to_users_.end()) << "Cannot find " << new_var; + + // replace uses inside the DataflowBlock. + ReplaceAllUsePass replacer(old_var, new_var, dfb_.get()); + root_fn_ = Downcast(replacer.VisitExpr_(root_fn_.get())); + dfb_ = GetRef(replacer.caught); + + // update udchain + // old_var -> old_var users | changed to {} + // new_var -> {?} | changed to old_var users + for (Var user : to_users_[old_var]) { + auto new_var_uses = to_users_[new_var]; + if (new_var_uses.end() == std::find(new_var_uses.begin(), new_var_uses.end(), user)) { + new_var_uses.push_back(user); + } + } + + to_users_.Set(old_var, {}); + + auto it_old_output = std::find(fn_outputs_.begin(), fn_outputs_.end(), old_var); + if (it_old_output != fn_outputs_.end()) { + fn_outputs_.Set(std::distance(fn_outputs_.begin(), it_old_output), new_var); + } +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") + .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) { + rwt->ReplaceAllUses(old_var, new_var); + }); + +class UpdateDFB : public ExprMutator { + private: + DataflowBlock old_dfb, new_dfb; + + public: + UpdateDFB(DataflowBlock old_dfb, DataflowBlock new_dfb) + : old_dfb(std::move(old_dfb)), new_dfb(std::move(new_dfb)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + return old_dfb.get() == op ? new_dfb : old_dfb; + } +}; + +void DataflowBlockRewriteNode::Add(Binding binding) { + auto p = [binding] { + if (auto vb = binding.as()) { + return std::make_pair(vb->var, vb->value); + } else if (auto ms = binding.as()) { + return std::make_pair(ms->var, ms->value); + } + LOG(FATAL) << "Unsupported binding type"; + return std::make_pair(Var{}, Expr{}); + }(); + + Var var = p.first; + Expr val = p.second; + + ICHECK(0 == to_users_.count(var)) << var << " has been defined so cannot be added."; + + // Add this VarBinding statement after the definition of uses. + std::set used_vars = [val] { + class UsedVars : public ExprVisitor { + public: + std::set used_vars; + void VisitExpr_(const VarNode* op) override { used_vars.insert(op); } + void VisitExpr_(const DataflowVarNode* op) override { used_vars.insert(op); } + } uvar{}; + uvar.VisitExpr(val); + return std::move(uvar.used_vars); + }(); + + size_t line_last_req_def = 0; + for (size_t i = 0; i < dfb_.value()->bindings.size(); ++i) { + auto line = dfb_.value()->bindings[i]; + if (auto varbind = line.as()) { + if (used_vars.find(varbind->var.get()) != used_vars.cend()) line_last_req_def = i; + } else if (auto mshape = line.as()) { + if (used_vars.find(mshape->var.get()) != used_vars.cend()) line_last_req_def = i; + } + } + + auto old_dfb = dfb_.value(); + + dfb_ = [old_dfb, binding, line_last_req_def, this] { + auto new_dfb = dfb_.value(); + new_dfb.CopyOnWrite()->bindings.insert(dfb_.value()->bindings.begin() + 1 + line_last_req_def, + binding); + return new_dfb; + }(); + + auto updater = UpdateDFB(old_dfb, dfb_.value()); + root_fn_ = Downcast(updater.VisitExpr_(root_fn_.get())); + + for (const VarNode* v : used_vars) to_users_.Get(GetRef(v)).value().push_back(var); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") + .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }); + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add") + .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) { + if (name.get()) { + rwt->Add(name.value(), expr, is_dfvar); + } else { + rwt->Add(expr, is_dfvar); + } + }); + +class RemoveUnusedVars : public ExprMutator { + public: + std::set unused_vars; + Optional caught_rewrite = NullOpt; + + RemoveUnusedVars(Map> users, Array fn_outputs) + : unused_vars([&] { + std::vector unused; + + // iterative dataflow algorithm. + size_t prev_size; + do { + prev_size = unused.size(); + + for (const auto& kv : users) { + // var -> [users...] + // var is unused iff + // user -> empty + // var is not output var + if (kv.second.empty() && // kv.first is not used by fn outputs. + fn_outputs.end() == std::find(fn_outputs.begin(), fn_outputs.end(), kv.first)) { + unused.push_back(kv.first); + } + } + + for (size_t i = prev_size; i < unused.size(); ++i) { + users.erase(unused[i]); + // remove def site. + for (auto kv : users) { // remove use site. + auto it = std::find(kv.second.begin(), kv.second.end(), unused[i]); + if (it != kv.second.end()) { + kv.second.erase(it); + users.Set(kv.first, std::move(kv.second)); + } + } + } + } while (prev_size != unused.size()); // changed? => continue. + + return std::set(unused.begin(), unused.end()); + }()) {} + + RemoveUnusedVars(std::pair>, Array> users_and_outputs) + : RemoveUnusedVars(std::move(users_and_outputs.first), std::move(users_and_outputs.second)) {} + RemoveUnusedVars(Function fn) : RemoveUnusedVars(FunctionUseDef(fn)) {} + RemoveUnusedVars(std::set unused_vars) : unused_vars(std::move(unused_vars)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { + auto prev_dfb = GetRef(block); + builder_->BeginDataflowBlock(); + for (Binding binding : block->bindings) { + if (const auto* node = binding.as()) { + if (!unused_vars.count(node->var)) VisitBinding_(node); + } else if (const auto* node = binding.as()) { + if (!unused_vars.count(node->var)) VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + auto new_dfb = builder_->EndBlock(); + if (caught_rewrite == prev_dfb) caught_rewrite = Downcast(new_dfb); + return std::move(new_dfb); + } +}; + +void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { + // first need to check if this var is used. + if (0 == to_users_.count(unused)) { // no def. + if (allow_undef) return; + LOG(FATAL) << unused << " undefined. Set allow_undef=True to allow 'removing' undefined var"; + } + + ICHECK(to_users_[unused].empty()) + << unused << " is used by " << to_users_[unused].size() << " vars"; + + auto old_dfb = dfb_.value(); + + RemoveUnusedVars remover({unused}); + dfb_ = Downcast(remover.VisitBindingBlock_(old_dfb.get())); + + auto updater = UpdateDFB(old_dfb, dfb_.value()); + root_fn_ = Downcast(updater.VisitExpr_(root_fn_.get())); + + to_users_.erase(unused); // update use-def chain. +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") + .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { + rwt->RemoveUnused(unused, allow_undef); + }); + +void DataflowBlockRewriteNode::RemoveAllUnused() { + RemoveUnusedVars remover(to_users_, fn_outputs_); + remover.caught_rewrite = dfb_.value(); + + // this could also clean unused variables in other DataflowBlock. + root_fn_ = Downcast(remover.VisitExpr_(root_fn_.get())); + + // DataflowBlock could be None. + dfb_ = remover.caught_rewrite.value(); + + // clean up use-def chain. + for (const auto& unused : remover.unused_vars) to_users_.erase(unused); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") + .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); + +Function RemoveAllUnused(Function fn) { + RemoveUnusedVars remover(fn); + return Downcast(remover.VisitExpr_(fn.get())); +} + +TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); + +IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { + BlockBuilder builder = BlockBuilder::Create(irmod); + + for (auto& p : irmod->functions) { + if (original_fn_ptr_ == p.second.get()) { + builder->UpdateFunction(p.first, root_fn_.value()); + break; + } + } + + return builder->GetContextIRModule(); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") + .set_body_typed([](DataflowBlockRewrite rwt, IRModule irmod) { + return rwt->MutateIRModule(irmod); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 1fdaeed500..69ef767b0f 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -28,8 +28,6 @@ #include #include -#include "tvm/runtime/memory.h" - #define RELAX_PATTERN_PRINTER_DEF(NODE_TYPE, REPR_LAMBDA) \ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) \ .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { \ diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 3c5b891743..601e5bf86b 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -14,12 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations import pytest import tvm from tvm import tir from tvm import relax as rx -from tvm.relax.analysis import udchain +from tvm.relax.analysis import udchain, remove_all_unused, name_to_binding +from tvm.script import relax as R def test_dispatch_var(): @@ -91,5 +94,121 @@ def test_use_def(): assert set(udc[gv0]) == set() +def test_chained_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir(my_sigmoid, (x,), (32, 32), dtype="float32") + unused1 = R.call_tir(my_sigmoid, (unused0,), (32, 32), dtype="float32") + R.output(lv0) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir(my_sigmoid, (x,), (32, 32), dtype="float32") + unused1 = R.call_tir(my_sigmoid, (unused0,), (32, 32), dtype="float32") + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, type_args=(Tensor((32, 32), "float32"))) + return z + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, type_args=(Tensor((32, 32), "float32"))) + return z + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_fake_unused_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, type_args=(Tensor((32, 32), "float32"))) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + # This might bring side effect so cannot be removed. + z = R.call_packed("vm.builtin.copy", lv0, type_args=(Tensor((32, 32), "float32"))) + return lv0 + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_edge_binding_block_fake_unused_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"): + z = R.call_packed("vm.builtin.copy", x, type_args=(Tensor((32, 32), "float32"))) + return x + + optimized = remove_all_unused(IdentityUnused["main"]) + tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"]) + + +def test_name_to_binding_var_shadowing(): + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + lv1 = lv0 + R.output(lv1) + + with R.dataflow(): + lv0 = lv1 # shadowing + lv2 = lv0 + R.output(lv2) + return lv2 + + n2binding = name_to_binding(main) + + assert "lv0" in n2binding + assert "lv1" in n2binding + assert "lv2" in n2binding + + assert len(n2binding["lv0"]) == 2 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py new file mode 100644 index 0000000000..86959cf0a3 --- /dev/null +++ b/tests/python/relax/test_binding_rewrite.py @@ -0,0 +1,333 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +import pytest + +import re + +import tvm +from tvm._ffi.base import TVMError +from tvm.relax.binding_rewrite import DataflowBlockRewrite +from tvm.relax.analysis import name_to_binding +from tvm.relax.expr import DataflowVar, Var +from tvm.script import relax as R + + +@tvm.script.ir_module +class Identity: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + +def assert_immutability(rwt, original_dfb, original_root_fn): + assert rwt.mutated_dfb() != original_dfb + assert rwt.mutated_root_fn() != original_root_fn + assert rwt.mutated_root_fn().body.blocks[0] != original_dfb + assert rwt.mutated_root_fn().body.blocks[0] == rwt.mutated_dfb() + + +def test_null_construct(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + DataflowBlockRewrite(dfb, root_fn) + + +def test_simple_add(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(name="tmp", expr=Identity["main"].params[0], is_dfvar=True) + + assert_immutability(rwt, dfb, root_fn) + + # check "tmp" added + assert "tmp" in name_to_binding(rwt.mutated_root_fn()) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + tmp: Tensor((32, 32), "float32") = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_simple_auto_add_var(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(root_fn.params[0], is_dfvar=False) + + assert isinstance(rwt.mutated_dfb().bindings[-1].var, Var) + + assert_immutability(rwt, dfb, root_fn) + + +def test_simple_auto_add_dfvar(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(root_fn.params[0], is_dfvar=True) + + assert isinstance(rwt.mutated_dfb().bindings[-1].var, DataflowVar) + + # immutatbility + assert_immutability(rwt, dfb, root_fn) + + +def test_simple_remove_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + unused = lv0 + R.output(lv0) + return lv0 + + root_fn = IdentityUnused["main"] + dfb = root_fn.body.blocks[0] + + n2binding = name_to_binding(IdentityUnused["main"]) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(n2binding["unused"][0].var) + + assert_immutability(rwt, dfb, root_fn) + + # check "unused" removed + assert "unused" not in name_to_binding(rwt.mutated_root_fn()) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_remove_unused_undef(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + with pytest.raises(TVMError): + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(Var("whatever")) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(Var("whatever"), allow_undef=True) + + assert root_fn == rwt.mutated_root_fn() + + +def test_simple_rm_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + unused0 = lv0 + unused1 = lv0 + R.output(lv0) + return lv0 + + root_fn = IdentityUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +@tvm.script.ir_module +class DeadDFBlock: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"): + with R.dataflow(): + lv0 = x + R.output(lv0) + return x + + +def test_empty_dfb_after_removal(): + root_fn = DeadDFBlock["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(DeadDFBlock["main"].body.blocks[0].bindings[0].var) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_empty_dfb_after_all_removal(): + dfb = DeadDFBlock["main"].body.blocks[0] + root_fn = DeadDFBlock["main"] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_chained_rm_all_unused(): + @tvm.script.ir_module + class IdentityChainedUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir(my_sigmoid, (x,), (32, 32), dtype="float32") + unused1 = R.call_tir(my_sigmoid, (unused0,), (32, 32), dtype="float32") + R.output(lv0) + return lv0 + + root_fn = IdentityChainedUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_simple_replace_all_uses(): + @tvm.script.ir_module + class Lv0To1: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"): + # lv0 => lv1 + # / \ + # lv2 lv3 + # \ / + # lv4 + with R.dataflow(): + lv0: Tensor((32, 32), "float32") = R.call_tir( + my_relu, (x,), (32, 32), dtype="float32" + ) + lv1: Tensor((32, 32), "float32") = R.call_tir( + my_sigmoid, (x,), (32, 32), dtype="float32" + ) + lv2: Tensor((32, 32), "float32") = R.call_tir( + my_add, (x, lv0), (32, 32), dtype="float32" + ) + lv3: Tensor((32, 32), "float32") = R.call_tir( + my_mul, (x, lv0), (32, 32), dtype="float32" + ) + lv4: Tensor((32, 32), "float32") = R.call_tir( + my_whatever, (lv2, lv3), (32, 32), dtype="float32" + ) + R.output(lv4) + return lv4 + + root_fn = Lv0To1["main"] + dfb = root_fn.body.blocks[0] + + n2binding = name_to_binding(root_fn) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.replace_all_uses(n2binding["lv0"][0].var, n2binding["lv1"][0].var) + rwt.remove_unused(n2binding["lv0"][0].var) + + assert_immutability(rwt, dfb, root_fn) + + n2binding_after = name_to_binding(rwt.mutated_root_fn()) + assert "lv0" not in n2binding_after + + +def test_simple_module_update(): + @tvm.script.ir_module + class Identity: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(name="tmp", expr=root_fn.params[0], is_dfvar=True) + + new_ir = rwt.mutate_irmodule(Identity) + + # immutatbility + assert new_ir != Identity + assert 2 == len(new_ir["main"].body.blocks[0].bindings) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + tmp: Tensor((32, 32), "float32") = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(new_ir, GroundTruth)