From e90636fe4e21f0964ebbe0d8c0560ecb61c09fa8 Mon Sep 17 00:00:00 2001
From: Marisa Kirisame
Date: Wed, 19 Aug 2020 14:03:05 +0000
Subject: [PATCH] commit save fix

---
 include/tvm/ir/module.h                    |   7 ++
 python/tvm/relay/transform/memory_alloc.py |  11 ++-
 src/ir/module.cc                           |  12 ++-
 src/relay/backend/graph_plan_memory.cc     |   4 +-
 src/relay/backend/vm/compiler.cc           |  31 +++++-
 src/relay/transforms/fold_constant.cc      |  16 +++
 src/relay/transforms/gradient.cc           | 109 +++++++++++++++++----
 src/relay/transforms/partial_eval.cc       |   6 +-
 tests/python/relay/test_pass_gradient.py   |  32 +++++-
 9 files changed, 198 insertions(+), 30 deletions(-)

diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 7af84b687f5fd..9a6ea19506d73 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -181,6 +181,13 @@ class IRModuleNode : public Object {
    */
   TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;

+  /*!
+   * \brief Check if a global function exists by its variable.
+   * \param var The global var to look up.
+   * \returns Whether the function named by the variable argument exists.
+   */
+  TVM_DLL bool Exist(const GlobalVar& var) const;
+
   /*!
    * \brief Look up a global function by its string name
    * \param name The name of the function.
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
index ae7db33842142..1105ad7a804cc 100644
--- a/python/tvm/relay/transform/memory_alloc.py
+++ b/python/tvm/relay/transform/memory_alloc.py
@@ -40,18 +40,25 @@ def is_primitive(call):


 class CheckReshapeOnly(ExprVisitor):
-    """A pass to check if the fused op contains only reshape ops."""
+    """
+    A pass to check if the fused op contains only reshape ops.
+    TODO(@Jared): this also catches the case where there are no ops at all.
+    As a quick hack we require at least one reshape, but that is probably masking a bigger problem.
+    Please fix - if you want to collapse reshape on reshape, perhaps fusion should handle it.
+ """ def __init__(self): super().__init__() self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"), op.get("dyn.reshape")] self.reshape_only = True + self.has_reshape = False def visit_call(self, call): if not self.reshape_only: return if call.op not in self._reshape_ops: self.reshape_only = False + self.has_reshape = True for arg in call.args: self.visit(arg) @@ -60,7 +67,7 @@ def is_reshape_only(func): """Check if the primitive function contains only reshape ops.""" check = CheckReshapeOnly() check.visit(func) - return check.reshape_only + return check.reshape_only and check.has_reshape class ManifestAllocPass(ExprMutator): diff --git a/src/ir/module.cc b/src/ir/module.cc index bcab39aabf32e..91328d8ea4baa 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -218,7 +218,12 @@ void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { } void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { - this->functions.Set(var, func); + auto* fn = func.as(); + if (fn) { + this->functions.Set(var, Downcast(tvm::relay::DeDup(GetRef(fn)))); + } else { + this->functions.Set(var, func); + } auto it = global_var_map_.find(var->name_hint); if (it != global_var_map_.end()) { @@ -284,6 +289,11 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { return (*it).second; } +bool IRModuleNode::Exist(const GlobalVar& var) const { + auto it = functions.find(var); + return it != functions.end(); +} + BaseFunc IRModuleNode::Lookup(const String& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 820e17f8a4987..6ba1ce777f4fb 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -18,7 +18,7 @@ */ /*! - * \file relay/backend/graph_mem_alloca.cc + * \file relay/backend/graph_plan_memory.cc * \brief Memory index assignment pass for executing * the program in the graph runtime. */ @@ -68,7 +68,7 @@ class StorageAllocaBaseVisitor : public ExprVisitor { } void VisitExpr_(const FunctionNode* op) final { - // do not recursive into sub function. + // do not recurse into sub function. 
   }

   void VisitExpr_(const GlobalVarNode* op) final {
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 33854f783d453..9e59ece1fd170 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -923,6 +923,31 @@ transform::Sequential MemoryOpt(tvm::Target host_target) {
   return transform::Sequential(pass_seqs);
 }

+Pass CheckPrimeFunc() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        struct CheckPrimeFuncVisitor : ExprVisitor {
+          bool inside_primitive = false;
+          void VisitExpr_(const ConstantNode* op) override {
+            CHECK_EQ(inside_primitive, false);
+          }
+          void VisitExpr_(const FunctionNode* op) override {
+            if (op->HasNonzeroAttr(attr::kPrimitive)) {
+              CHECK_EQ(inside_primitive, false);
+              inside_primitive = true;
+              VisitExpr(op->body);
+              inside_primitive = false;
+            } else {
+              VisitExpr(op->body);
+            }
+          }
+        } vis;
+        vis(f);
+        return f;
+      };
+  return CreateFunctionPass(pass_func, 1, "CheckPrimeFunc", {});
+}
+
 IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
                                     const Target& target_host) {
   Array<Pass> pass_seqs;
@@ -956,14 +981,19 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
     }
     *rv = false;
   });
+
   pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
   pass_seqs.push_back(transform::SimplifyExpr());
   pass_seqs.push_back(transform::InlinePrimitives());
+
   pass_seqs.push_back(transform::CombineParallelConv2D(3));
   pass_seqs.push_back(transform::CombineParallelDense(3));
   pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
+
   pass_seqs.push_back(transform::FoldConstant());
+  // pass_seqs.push_back(tvm::transform::PrintIR());
+  // pass_seqs.push_back(CheckPrimeFunc());
   pass_seqs.push_back(transform::FoldScaleAxis());
   pass_seqs.push_back(transform::CanonicalizeCast());
   pass_seqs.push_back(transform::CanonicalizeOps());
@@ -1004,7 +1034,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
 }

 void VMCompiler::PopulateGlobalMap() {
-  // First we populate global map.
   size_t global_index = 0;
   for (auto named_func : context_.module->functions) {
     auto gvar = named_func.first;
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 0ecbfea8c9054..7a1941b161a16 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -104,7 +104,23 @@ class ConstantFolder : public ExprMutator {
     }
   }

+  bool inside_primitive = false;
+  Expr VisitExpr_(const FunctionNode* op) final {
+    if (op->HasNonzeroAttr(attr::kPrimitive)) {
+      CHECK_EQ(inside_primitive, false);
+      inside_primitive = true;
+      auto ret = ExprMutator::VisitExpr_(op);
+      inside_primitive = false;
+      return ret;
+    } else {
+      return ExprMutator::VisitExpr_(op);
+    }
+  }
+
   Expr VisitExpr_(const CallNode* call) final {
+    if (inside_primitive) {
+      return GetRef<Expr>(call);
+    }
     static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
     std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};

diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc
index 7894c34de55db..002cc37a526ff 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/gradient.cc
@@ -85,7 +85,7 @@ Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
   if (mod.defined() && x) {
     BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
     if (auto* n = base_func.as<FunctionNode>()) {
-      return n->body;
+      return GetRef<Function>(n);
     } else {
       return e;
     }
@@ -338,11 +338,22 @@ Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {

 TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient);

+Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));
+
 struct ReverseADType : TypeMutator {
   Type VisitType_(const TensorTypeNode* ttn) final {
     Type t = GetRef<Type>(ttn);
     return TupleType({t, RelayRefType(t)});
   }
+
+  Type VisitType_(const FuncTypeNode* ftn) final {
+    std::vector<Type> arg_types;
+    for (const auto& t : ftn->arg_types) {
+      arg_types.push_back(VisitType(t));
+    }
+    arg_types.push_back(bpt);
+    return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints);
+  }
 };

 Type ReverseType(const Type& t) { return ReverseADType()(t); }
@@ -438,12 +449,21 @@ Expr BPEmpty() {

 struct ReverseAD : ExprMutator {
   using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;
-
+  using ADGVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;
+  Optional<IRModule> mod;
   Var bp;
   std::shared_ptr<ADVarMap> ad_vars;
+  std::shared_ptr<ADGVarMap> ad_gvars;
   const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");

-  explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : bp(bp), ad_vars(ad_vars) {}
+  explicit ReverseAD(const Optional<IRModule>& mod,
+                     const Var& bp,
+                     const std::shared_ptr<ADVarMap>& ad_vars,
+                     const std::shared_ptr<ADGVarMap>& ad_gvars) :
+    mod(mod),
+    bp(bp),
+    ad_vars(ad_vars),
+    ad_gvars(ad_gvars) { }

   Expr VisitExpr_(const OpNode* op) final {
     LOG(FATAL) << "op should only be inside call";
@@ -481,8 +501,7 @@ struct ReverseAD : ExprMutator {
           Expr nbp = Function({}, LetList::With([&](LetList* ll) {
                        // we need a new ReverseAD visitor to avoid clobbering the bp local var
                        auto dup_bp = ll->Push(BPEmpty());
-                       ReverseAD dup_diff(dup_bp, ad_vars);
-                       auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
+                       auto dup_ad = ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x)));

                        TransferGrads(call->checked_type(), ret, dup_ad, ll);
                        ll->Push(Call(RefRead(dup_bp), {}));
@@ -518,22 +537,32 @@ struct ReverseAD : ExprMutator {
         orig_var->checked_type_ = call->checked_type();
         auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
         auto bpv = ll->Push(RefRead(bp));
-        Expr nbp = Function({}, LetList::With([&](LetList* ll) {
-                     tvm::Array<Expr> rev =
-                         rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
-                     CHECK(args.size() == rev.size());
-                     for (size_t i = 0; i < args.size(); ++i) {
-                       UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
-                     }
-                     return Call(bpv, {});
-                   }),
-                   TupleType::Empty(), {});
+        Expr nbp_body =
+            LetList::With([&](LetList* ll) {
+              tvm::Array<Expr> rev =
+                  rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
+              CHECK(args.size() == rev.size());
+              for (size_t i = 0; i < args.size(); ++i) {
+                UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
+              }
+              return Call(bpv, {});
+            });
+        Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
         ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
         // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
         return ret;
       });
+    } else if (const ConstructorNode* con = call->op.as<ConstructorNode>()) {
+      return ExprMutator::VisitExpr_(call);
+    }
+    else {
+      std::vector<Expr> args;
+      for (const auto& arg : call->args) {
+        args.push_back(VisitExpr(arg));
+      }
+      args.push_back(bp);
+      return Call(VisitExpr(call->op), args);
     }
-    return ExprMutator::VisitExpr_(call);
   }

   Expr VisitExpr_(const ConstantNode* op) final {
@@ -559,6 +588,38 @@ struct ReverseAD : ExprMutator {
     return ad_vars->at(var_ref);
   }

+  Expr VisitExpr_(const GlobalVarNode* op) final {
+    // TODO: encoding the gradient in the name by string concatenation is a brittle hack.
+    // Maybe index the module by a rose tree of strings instead?
+    CHECK(mod.defined());
+    auto orig_gv = GetRef<GlobalVar>(op);
+    if (ad_gvars->count(orig_gv) == 0) {
+      GlobalVar gv(op->name_hint + "_grad");
+      (*ad_gvars)[orig_gv] = gv;
+      Function orig_f = Downcast<Function>(mod.value()->Lookup(GetRef<GlobalVar>(op)));
+      std::vector<Var> params;
+      for (const auto& p : orig_f->params) {
+        params.push_back(Downcast<Var>(VisitExpr(p)));
+      }
+      params.push_back(bp);
+      Expr body = VisitExpr(orig_f->body);
+      Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs);
+      std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
+      mod.value()->Add(gv, f);
+    }
+    return ad_gvars->at(orig_gv);
+  }
+
+  Expr VisitExpr_(const FunctionNode* op) final {
+    std::vector<Var> params;
+    for (const auto& var : op->params) {
+      params.push_back(Downcast<Var>(VisitExpr(var)));
+    }
+    auto new_bp = Var("bp", bpt);
+    params.push_back(new_bp);
+    return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body), VisitType(op->ret_type), op->type_params, op->attrs);
+  }
+
   Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
 };

@@ -604,12 +665,18 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
   }
   CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
   Expr body = LetList::With([&](LetList* ll) {
-    Var bp = ll->Push(BPEmpty());
-    Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
-    std::vector<Expr> args;
+    Var bp = ll->Push(BPEmpty(), bpt);
+    Expr rev = ReverseAD(mod,
+                         bp,
+                         std::make_shared<ReverseAD::ADVarMap>(),
+                         std::make_shared<ReverseAD::ADGVarMap>())(e);
+    std::vector<Expr> normal_args, args;
     for (const auto& p : f->params) {
-      args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p)))));
+      auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
+      normal_args.push_back(x);
+      args.push_back(x);
     }
+    args.push_back(bp);
     auto c = ll->Push(Call(rev, args));
     std::function<void(const Expr&, const Type&)> init_grad;
     init_grad = [&](const Expr& e, const Type& t) {
@@ -626,7 +693,7 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
     init_grad(c, f->body->checked_type());
     ll->Push(Call(RefRead(bp), {}));
     std::vector<Expr> ret;
-    for (const auto& a : args) {
+    for (const auto& a : normal_args) {
       ret.push_back(RefRead(GetField(a, 1)));
     }
     std::function<Expr(const Expr&, const Type&)> get_final_result;
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc
index e07dbea59bd1b..d3e3498c3bc3a 100644
--- a/src/relay/transforms/partial_eval.cc
+++ b/src/relay/transforms/partial_eval.cc
@@ -91,6 +91,7 @@
  */
 #include
 #include
+#include
 #include
 #include
 #include
@@ -776,8 +777,9 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,

   Func VisitFuncStatic(const Function& func, const Expr& var) {
     CHECK(IsAtomic(var));
+    // TODO: figure out the right semantics for primitive functions here
     if (func->HasNonzeroAttr(attr::kPrimitive)) {
-      return ConstEvaluateFunc(func);
+      // return ConstEvaluateFunc(func);
     }
     std::vector<std::pair<Var, PStatic> > free_vars;
     for (const auto& v : FreeVars(func)) {
@@ -1200,7 +1202,7 @@ namespace transform {
 Pass PartialEval() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
-  return CreateModulePass(pass_func, 1, "PartialEval", {});
+  return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
 }

 TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval);
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 296d3e5e9354f..291d780f7b413 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -21,6 +21,7 @@
 import tvm
 from tvm import te
 from tvm import relay
+from tvm.relay import GlobalVar
 from tvm.relay.analysis import free_vars, free_type_vars
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
@@ -29,7 +30,7 @@ import tvm.relay.op as op


-def test_id():
+def test_fo_id():
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
@@ -44,6 +45,21 @@ def test_id():
     tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
     tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))

+def test_id():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    func = relay.Function([x], x)
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+    ex = create_executor()
+    x = rand(dtype, *shape)
+    forward, (grad,) = ex.evaluate(back_func)(x)
+    tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
+

 def test_relu():
     shape = (10, 10)
@@ -341,5 +357,19 @@ def test_no_duplication():
     counts = count_ops(gr)
     assert counts['nn.dense'] == 3, "We expect 3 dense (1 forward, two backward)"

+
+def test_global_function():
+    m = tvm.IRModule()
+    t = relay.TensorType([])
+    x = relay.Var('x', t)
+    d = GlobalVar('double')
+    m[d] = relay.Function([x], x + x)
+    y = relay.Var('y', t)
+    q = GlobalVar('q')
+    m[q] = relay.Function([y], d(d(y)))
+    g = GlobalVar('grad')
+    m[g] = tvm.relay.transform.gradient(q, m)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
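
Reviewer note (not part of the patch): a minimal usage sketch of the new global-function
support in gradient, mirroring test_global_function above. It only inspects the transformed
module rather than executing it. The "double_grad" name assumes the "_grad" suffix chosen in
ReverseAD::VisitExpr_(const GlobalVarNode*); the variable names and prints are illustrative
and not asserted by the patch's own test.

    import tvm
    from tvm import relay
    from tvm.relay import GlobalVar

    m = tvm.IRModule()
    t = relay.TensorType([])                        # scalar tensor type, default float32
    x = relay.Var("x", t)
    double = GlobalVar("double")
    m[double] = relay.Function([x], x + x)          # double(x) = 2 * x

    y = relay.Var("y", t)
    q = GlobalVar("q")
    m[q] = relay.Function([y], double(double(y)))   # q(y) = 4 * y

    g = GlobalVar("grad")
    m[g] = tvm.relay.transform.gradient(q, m)       # higher-order reverse-mode AD over the module

    # The pass is expected to add a generated backward global for `double`
    # (named "double_grad" by the GlobalVarNode visitor in this patch).
    print([gv.name_hint for gv in m.get_global_vars()])
    print(m[g])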