Commit

save

fix
MarisaKirisame committed Aug 25, 2020
1 parent 1a26a2e commit e90636f
Showing 9 changed files with 198 additions and 30 deletions.
7 changes: 7 additions & 0 deletions include/tvm/ir/module.h
@@ -181,6 +181,13 @@ class IRModuleNode : public Object {
*/
TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;

/*!
* \brief Check if a global function exists by its variable.
* \param var The global var to look up.
* \returns Whether the function named by the variable argument exists.
*/
TVM_DLL bool Exist(const GlobalVar& var) const;

/*!
* \brief Look up a global function by its string name
* \param name The name of the function.
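
A note for context (not part of this commit): the Python API does not expose Exist() directly, but the same membership question can be asked roughly as below; using get_global_vars() for this purpose is an assumption, not the commit's code.

# Minimal sketch: does the module hold a function for a given GlobalVar?
import tvm
from tvm import relay

x = relay.var("x", shape=(1,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x))   # bound to "main"

main_gv = mod.get_global_var("main")
exists = any(gv.same_as(main_gv) for gv in mod.get_global_vars())
print(exists)  # True: the module holds a function for this GlobalVar
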
11 changes: 9 additions & 2 deletions python/tvm/relay/transform/memory_alloc.py
@@ -40,18 +40,25 @@ def is_primitive(call):


class CheckReshapeOnly(ExprVisitor):
"""A pass to check if the fused op contains only reshape ops."""
"""
A pass to check if the fused op contains only reshape ops.
TODO(@Jared): this also captures the case where there are no ops at all.
I put in a quick hack requiring at least one reshape, but that may be masking a bigger problem.
Please fix: if the goal is to collapse reshape on reshape, perhaps fusion should be made to do that explicitly.
"""
def __init__(self):
super().__init__()
self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"),
op.get("dyn.reshape")]
self.reshape_only = True
self.has_reshape = False

def visit_call(self, call):
if not self.reshape_only:
return
if call.op not in self._reshape_ops:
self.reshape_only = False
self.has_reshape = True
for arg in call.args:
self.visit(arg)

@@ -60,7 +67,7 @@ def is_reshape_only(func):
"""Check if the primitive function contains only reshape ops."""
check = CheckReshapeOnly()
check.visit(func)
return check.reshape_only
return check.reshape_only and check.has_reshape


class ManifestAllocPass(ExprMutator):
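
To make the check above concrete, a sketch follows (not part of this commit; importing the internal helper is an assumption and the path may vary across TVM versions): is_reshape_only accepts a function built purely from reshape ops, and with this change it also requires at least one reshape.

# Contrast a reshape-only function with one that mixes in another op.
from tvm import relay
from tvm.relay.transform.memory_alloc import is_reshape_only  # internal helper

x = relay.var("x", shape=(2, 4), dtype="float32")
reshape_only = relay.Function([x], relay.reshape(x, (4, 2)))
mixed = relay.Function([x], relay.reshape(relay.add(x, x), (4, 2)))

print(is_reshape_only(reshape_only))  # True: only reshape ops, at least one present
print(is_reshape_only(mixed))         # False: contains an add
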
12 changes: 11 additions & 1 deletion src/ir/module.cc
@@ -218,7 +218,12 @@ void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
}

void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
this->functions.Set(var, func);
auto* fn = func.as<tvm::relay::FunctionNode>();
if (fn) {
this->functions.Set(var, Downcast<tvm::relay::Function>(tvm::relay::DeDup(GetRef<tvm::relay::Function>(fn))));
} else {
this->functions.Set(var, func);
}

auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
@@ -284,6 +289,11 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
return (*it).second;
}

bool IRModuleNode::Exist(const GlobalVar& var) const {
auto it = functions.find(var);
return it != functions.end();
}

BaseFunc IRModuleNode::Lookup(const String& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
4 changes: 2 additions & 2 deletions src/relay/backend/graph_plan_memory.cc
@@ -18,7 +18,7 @@
*/

/*!
* \file relay/backend/graph_mem_alloca.cc
* \file relay/backend/graph_plan_memory.cc
* \brief Memory index assignment pass for executing
* the program in the graph runtime.
*/
@@ -68,7 +68,7 @@ class StorageAllocaBaseVisitor : public ExprVisitor {
}

void VisitExpr_(const FunctionNode* op) final {
// do not recursive into sub function.
// do not recurse into sub function.
}

void VisitExpr_(const GlobalVarNode* op) final {
31 changes: 30 additions & 1 deletion src/relay/backend/vm/compiler.cc
@@ -923,6 +923,31 @@ transform::Sequential MemoryOpt(tvm::Target host_target) {
return transform::Sequential(pass_seqs);
}

Pass CheckPrimeFunc() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
struct CheckPrimeFuncVisitor : ExprVisitor {
bool inside_primitive = false;
void VisitExpr_(const ConstantNode* op) override {
CHECK_EQ(inside_primitive, false);
}
void VisitExpr_(const FunctionNode* op) override {
if (op->HasNonzeroAttr(attr::kPrimitive)) {
CHECK_EQ(inside_primitive, false);
inside_primitive = true;
VisitExpr(op->body);
inside_primitive = false;
} else {
VisitExpr(op->body);
}
}
} vis;
vis(f);
return f;
};
return CreateFunctionPass(pass_func, 1, "CheckPrimeFunc", {});
}

IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
const Target& target_host) {
Array<Pass> pass_seqs;
@@ -956,14 +981,19 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
}
*rv = false;
});

pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(transform::InlinePrimitives());


pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));

pass_seqs.push_back(transform::FoldConstant());
//pass_seqs.push_back(tvm::transform::PrintIR());
//pass_seqs.push_back(CheckPrimeFunc());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());
@@ -1004,7 +1034,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
}

void VMCompiler::PopulateGlobalMap() {
// First we populate global map.
size_t global_index = 0;
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
16 changes: 16 additions & 0 deletions src/relay/transforms/fold_constant.cc
@@ -104,7 +104,23 @@ class ConstantFolder : public ExprMutator {
}
}

bool inside_primitive = false;
Expr VisitExpr_(const FunctionNode* op) final {
if (op->HasNonzeroAttr(attr::kPrimitive)) {
CHECK_EQ(inside_primitive, false);
inside_primitive = true;
auto ret = ExprMutator::VisitExpr_(op);
inside_primitive = false;
return ret;
} else {
return ExprMutator::VisitExpr_(op);
}
}

Expr VisitExpr_(const CallNode* call) final {
if (inside_primitive) {
return GetRef<Expr>(call);
}
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");

std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
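
The inside_primitive flag above makes FoldConstant skip everything nested inside a fused primitive function while still folding at the top level. A minimal usage sketch through the public Python pass API (not part of this commit):

# Fold a fully constant subexpression at compile time.
import numpy as np
import tvm
from tvm import relay

c = relay.const(np.ones((2, 2), dtype="float32"))
mod = tvm.IRModule.from_expr(relay.Function([], relay.add(c, c)))
mod = relay.transform.InferType()(mod)
mod = relay.transform.FoldConstant()(mod)
print(mod)  # add(const, const) has been replaced by a single constant
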
109 changes: 88 additions & 21 deletions src/relay/transforms/gradient.cc
@@ -85,7 +85,7 @@ Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
if (mod.defined() && x) {
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
return GetRef<Function>(n);
} else {
return e;
}
@@ -338,11 +338,22 @@ Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {

TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient);

Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));

struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleType({t, RelayRefType(t)});
}

Type VisitType_(const FuncTypeNode* ftn) final {
std::vector<Type> arg_types;
for (const auto& t : ftn->arg_types) {
arg_types.push_back(VisitType(t));
}
arg_types.push_back(bpt);
return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints);
}
};

Type ReverseType(const Type& t) { return ReverseADType()(t); }
@@ -438,12 +449,21 @@ Expr BPEmpty() {

struct ReverseAD : ExprMutator {
using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;

using ADGVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;
Optional<IRModule> mod;
Var bp;
std::shared_ptr<ADVarMap> ad_vars;
std::shared_ptr<ADGVarMap> ad_gvars;
const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");

explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : bp(bp), ad_vars(ad_vars) {}
explicit ReverseAD(const Optional<IRModule>& mod,
const Var& bp,
const std::shared_ptr<ADVarMap>& ad_vars,
const std::shared_ptr<ADGVarMap>& ad_gvars) :
mod(mod),
bp(bp),
ad_vars(ad_vars),
ad_gvars(ad_gvars) { }

Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
@@ -481,8 +501,7 @@ struct ReverseAD : ExprMutator {
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
// we need a new ReverseAD visitor to avoid clobbering the bp local var
auto dup_bp = ll->Push(BPEmpty());
ReverseAD dup_diff(dup_bp, ad_vars);
auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
auto dup_ad = ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x)));

TransferGrads(call->checked_type(), ret, dup_ad, ll);
ll->Push(Call(RefRead(dup_bp), {}));
@@ -518,22 +537,32 @@ struct ReverseAD : ExprMutator {
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefRead(bp));
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev =
rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return Call(bpv, {});
}),
TupleType::Empty(), {});
Expr nbp_body =
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev =
rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return Call(bpv, {});
});
Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
} else if (const ConstructorNode* con = call->op.as<ConstructorNode>()) {
return ExprMutator::VisitExpr_(call);
} else {
std::vector<Expr> args;
for (const auto& arg : call->args) {
args.push_back(VisitExpr(arg));
}
args.push_back(bp);
return Call(VisitExpr(call->op), args);
}
return ExprMutator::VisitExpr_(call);
}

Expr VisitExpr_(const ConstantNode* op) final {
@@ -559,6 +588,38 @@ struct ReverseAD : ExprMutator {
return ad_vars->at(var_ref);
}

Expr VisitExpr_(const GlobalVarNode* op) final {
// todo: concatenating strings to add an attribute seems like a brittle hack.
// maybe index the module by a rose tree of strings?
CHECK(mod.defined());
auto orig_gv = GetRef<GlobalVar>(op);
if (ad_gvars->count(orig_gv) == 0) {
GlobalVar gv(op->name_hint + "_grad");
(*ad_gvars)[orig_gv] = gv;
Function orig_f = Downcast<Function>(mod.value()->Lookup(GetRef<GlobalVar>(op)));
std::vector<Var> params;
for (const auto& p : orig_f->params) {
params.push_back(Downcast<Var>(VisitExpr(p)));
}
params.push_back(bp);
Expr body = VisitExpr(orig_f->body);
Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs);
std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
mod.value()->Add(gv, f);
}
return ad_gvars->at(orig_gv);
}

Expr VisitExpr_(const FunctionNode* op) final {
std::vector<Var> params;
for (const auto& var: op->params) {
params.push_back(Downcast<Var>(VisitExpr(var)));
}
auto new_bp = Var("bp", bpt);
params.push_back(new_bp);
return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body), VisitType(op->ret_type), op->type_params, op->attrs);
}

Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
};

@@ -604,12 +665,18 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
}
CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
std::vector<Expr> args;
Var bp = ll->Push(BPEmpty(), bpt);
Expr rev = ReverseAD(mod,
bp,
std::make_shared<ReverseAD::ADVarMap>(),
std::make_shared<ReverseAD::ADGVarMap>())(e);
std::vector<Expr> normal_args, args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p)))));
auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
normal_args.push_back(x);
args.push_back(x);
}
args.push_back(bp);
auto c = ll->Push(Call(rev, args));
std::function<void(const Expr&, const Type&)> init_grad;
init_grad = [&](const Expr& e, const Type& t) {
Expand All @@ -626,7 +693,7 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
init_grad(c, f->body->checked_type());
ll->Push(Call(RefRead(bp), {}));
std::vector<Expr> ret;
for (const auto& a : args) {
for (const auto& a : normal_args) {
ret.push_back(RefRead(GetField(a, 1)));
}
std::function<Expr(const Expr&, const Type&)> get_final_result;
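
These ReverseAD changes (threading the backprop closure bp through function types and creating "_grad" copies of referenced global functions) are reached through the public gradient transform. A minimal usage sketch, not part of this commit:

# Reverse-mode AD of a simple Relay function.
import tvm
from tvm import relay

x = relay.var("x", shape=(3,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.multiply(x, x)))
mod = relay.transform.InferType()(mod)

# Returns a function computing (original output, tuple of input gradients).
grad_func = relay.transform.gradient(mod["main"], mod=mod)
print(grad_func)
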
6 changes: 4 additions & 2 deletions src/relay/transforms/partial_eval.cc
@@ -91,6 +91,7 @@
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/interpreter.h>
@@ -776,8 +777,9 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>

Func VisitFuncStatic(const Function& func, const Expr& var) {
CHECK(IsAtomic(var));
// todo: figure out primitive semantics
if (func->HasNonzeroAttr(attr::kPrimitive)) {
return ConstEvaluateFunc(func);
// return ConstEvaluateFunc(func);
}
std::vector<std::pair<Var, PStatic> > free_vars;
for (const auto& v : FreeVars(func)) {
@@ -1200,7 +1202,7 @@ namespace transform {
Pass PartialEval() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
return CreateModulePass(pass_func, 1, "PartialEval", {});
return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
}

TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval);
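
The partial evaluator now leaves primitive functions untouched instead of const-evaluating them, and the module pass is registered as "PartialEvaluate" to match its FFI entry. From Python it is invoked as below (sketch only, not part of this commit):

# Run the partial evaluator over a module.
import tvm
from tvm import relay

x = relay.var("x", shape=(2,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, relay.const(0.0))))
mod = relay.transform.InferType()(mod)
mod = relay.transform.PartialEvaluate()(mod)
print(mod)
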
