From 3ec296d895403e8ebddd7b55405d4fcdae1df465 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 19 Aug 2020 14:03:05 +0000 Subject: [PATCH 1/2] commit save fix --- include/tvm/ir/module.h | 7 +++++ python/tvm/relay/transform/memory_alloc.py | 11 ++++++-- src/ir/module.cc | 12 ++++++++- src/relay/backend/vm/compiler.cc | 31 +++++++++++++++++++++- src/relay/transforms/fold_constant.cc | 16 +++++++++++ src/relay/transforms/partial_eval.cc | 6 +++-- 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 7af84b687f5f..9a6ea19506d7 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -181,6 +181,13 @@ class IRModuleNode : public Object { */ TVM_DLL BaseFunc Lookup(const GlobalVar& var) const; + /*! + * \brief Check if a global function exist by its variable. + * \param var The global var to lookup. + * \returns Wether the function named by the variable argument exist. + */ + TVM_DLL bool Exist(const GlobalVar& var) const; + /*! * \brief Look up a global function by its string name * \param name The name of the function. diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index ae7db3384214..1105ad7a804c 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -40,18 +40,25 @@ def is_primitive(call): class CheckReshapeOnly(ExprVisitor): - """A pass to check if the fused op contains only reshape ops.""" + """ + A pass to check if the fused op contains only reshape ops. + TODO(@Jared) this is capturing the case where there is no any ops at all. + I had put a quick hack and require that it must have >= 1 reshape, but this must be masking some bigger problem. + Please fix - if you want to collapse reshape on reshape, perhaps you should do fusion as such. + """ def __init__(self): super().__init__() self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"), op.get("dyn.reshape")] self.reshape_only = True + self.has_reshape = False def visit_call(self, call): if not self.reshape_only: return if call.op not in self._reshape_ops: self.reshape_only = False + self.has_reshape = True for arg in call.args: self.visit(arg) @@ -60,7 +67,7 @@ def is_reshape_only(func): """Check if the primitive function contains only reshape ops.""" check = CheckReshapeOnly() check.visit(func) - return check.reshape_only + return check.reshape_only and check.has_reshape class ManifestAllocPass(ExprMutator): diff --git a/src/ir/module.cc b/src/ir/module.cc index bcab39aabf32..91328d8ea4ba 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -218,7 +218,12 @@ void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { } void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { - this->functions.Set(var, func); + auto* fn = func.as(); + if (fn) { + this->functions.Set(var, Downcast(tvm::relay::DeDup(GetRef(fn)))); + } else { + this->functions.Set(var, func); + } auto it = global_var_map_.find(var->name_hint); if (it != global_var_map_.end()) { @@ -284,6 +289,11 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { return (*it).second; } +bool IRModuleNode::Exist(const GlobalVar& var) const { + auto it = functions.find(var); + return it != functions.end(); +} + BaseFunc IRModuleNode::Lookup(const String& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 33854f783d45..9e59ece1fd17 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -923,6 +923,31 @@ transform::Sequential MemoryOpt(tvm::Target host_target) { return transform::Sequential(pass_seqs); } +Pass CheckPrimeFunc() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + struct CheckPrimeFuncVisitor : ExprVisitor { + bool inside_primitive = false; + void VisitExpr_(const ConstantNode* op) override { + CHECK_EQ(inside_primitive, false); + } + void VisitExpr_(const FunctionNode* op) override { + if (op->HasNonzeroAttr(attr::kPrimitive)) { + CHECK_EQ(inside_primitive, false); + inside_primitive = true; + VisitExpr(op->body); + inside_primitive = false; + } else { + VisitExpr(op->body); + } + } + } vis; + vis(f); + return f; + }; + return CreateFunctionPass(pass_func, 1, "CheckPrimeFunc", {}); +} + IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets, const Target& target_host) { Array pass_seqs; @@ -956,14 +981,19 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe } *rv = false; }); + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); pass_seqs.push_back(transform::SimplifyExpr()); pass_seqs.push_back(transform::InlinePrimitives()); + pass_seqs.push_back(transform::CombineParallelConv2D(3)); pass_seqs.push_back(transform::CombineParallelDense(3)); pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); + pass_seqs.push_back(transform::FoldConstant()); + //pass_seqs.push_back(tvm::transform::PrintIR()); + //pass_seqs.push_back(CheckPrimeFunc()); pass_seqs.push_back(transform::FoldScaleAxis()); pass_seqs.push_back(transform::CanonicalizeCast()); pass_seqs.push_back(transform::CanonicalizeOps()); @@ -1004,7 +1034,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe } void VMCompiler::PopulateGlobalMap() { - // First we populate global map. size_t global_index = 0; for (auto named_func : context_.module->functions) { auto gvar = named_func.first; diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 0ecbfea8c905..7a1941b161a1 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -104,7 +104,23 @@ class ConstantFolder : public ExprMutator { } } + bool inside_primitive = false; + Expr VisitExpr_(const FunctionNode* op) final { + if (op->HasNonzeroAttr(attr::kPrimitive)) { + CHECK_EQ(inside_primitive, false); + inside_primitive = true; + auto ret = ExprMutator::VisitExpr_(op); + inside_primitive = false; + return ret; + } else { + return ExprMutator::VisitExpr_(op); + } + } + Expr VisitExpr_(const CallNode* call) final { + if (inside_primitive) { + return GetRef(call); + } static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); std::unordered_set skip_list{"zeros_like", "ones_like", "full_like", "full"}; diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index e07dbea59bd1..d3e3498c3bc3 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -91,6 +91,7 @@ */ #include #include +#include #include #include #include @@ -776,8 +777,9 @@ class PartialEvaluator : public ExprFunctor Func VisitFuncStatic(const Function& func, const Expr& var) { CHECK(IsAtomic(var)); + // todo: figure out primitive semantic if (func->HasNonzeroAttr(attr::kPrimitive)) { - return ConstEvaluateFunc(func); + // return ConstEvaluateFunc(func); } std::vector > free_vars; for (const auto& v : FreeVars(func)) { @@ -1200,7 +1202,7 @@ namespace transform { Pass PartialEval() { runtime::TypedPackedFunc pass_func = [=](IRModule m, PassContext pc) { return relay::PartialEval(m); }; - return CreateModulePass(pass_func, 1, "PartialEval", {}); + return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); } TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval); From 86eb174ec1baf71cfa086c6253fd8f450a816ded Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Thu, 27 Aug 2020 20:25:51 +0000 Subject: [PATCH 2/2] drop changes --- include/tvm/ir/module.h | 7 ------- src/ir/module.cc | 12 +----------- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 9a6ea19506d7..7af84b687f5f 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -181,13 +181,6 @@ class IRModuleNode : public Object { */ TVM_DLL BaseFunc Lookup(const GlobalVar& var) const; - /*! - * \brief Check if a global function exist by its variable. - * \param var The global var to lookup. - * \returns Wether the function named by the variable argument exist. - */ - TVM_DLL bool Exist(const GlobalVar& var) const; - /*! * \brief Look up a global function by its string name * \param name The name of the function. diff --git a/src/ir/module.cc b/src/ir/module.cc index 91328d8ea4ba..bcab39aabf32 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -218,12 +218,7 @@ void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { } void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { - auto* fn = func.as(); - if (fn) { - this->functions.Set(var, Downcast(tvm::relay::DeDup(GetRef(fn)))); - } else { - this->functions.Set(var, func); - } + this->functions.Set(var, func); auto it = global_var_map_.find(var->name_hint); if (it != global_var_map_.end()) { @@ -289,11 +284,6 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { return (*it).second; } -bool IRModuleNode::Exist(const GlobalVar& var) const { - auto it = functions.find(var); - return it != functions.end(); -} - BaseFunc IRModuleNode::Lookup(const String& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id);