save
MarisaKirisame committed Aug 22, 2020
1 parent aae096a commit 19a3e65
Showing 10 changed files with 205 additions and 33 deletions.
7 changes: 7 additions & 0 deletions include/tvm/ir/module.h
@@ -181,6 +181,13 @@ class IRModuleNode : public Object {
*/
TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;

/*!
* \brief Check if a global function exists by its variable.
* \param var The global var to lookup.
* \returns Whether the function named by the variable argument exists.
*/
TVM_DLL bool Exist(const GlobalVar& var) const;

/*!
* \brief Look up a global function by its string name
* \param name The name of the function.
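The new Exist() check is only exposed on the C++ IRModuleNode in this commit. As a rough Python-side analogue (the helper name below is hypothetical and relies only on the existing get_global_vars() binding), the same membership test could be written as:

import tvm
from tvm import relay

def global_func_exists(mod, name):
    # Hypothetical helper mirroring IRModuleNode::Exist: does the module
    # define a global function with this name?
    return any(gv.name_hint == name for gv in mod.get_global_vars())

x = relay.var("x", shape=(3,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x))
print(global_func_exists(mod, "main"))       # True
print(global_func_exists(mod, "main_grad"))  # False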
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/vm.py
@@ -80,7 +80,7 @@ def __init__(self):
self._codegen = self.mod["codegen"]
self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
self._gets_params_func = self.mod["get_params"]
self._optimize = self.mod["optimize"]

def set_params(self, params):
11 changes: 9 additions & 2 deletions python/tvm/relay/transform/memory_alloc.py
@@ -40,18 +40,25 @@ def is_primitive(call):


class CheckReshapeOnly(ExprVisitor):
"""A pass to check if the fused op contains only reshape ops."""
"""
A pass to check if the fused op contains only reshape ops.
TODO(@Jared) this is capturing the case where there are no ops at all.
I put in a quick hack requiring at least one reshape, but this must be masking a bigger problem.
Please fix - if you want to collapse reshape on reshape, perhaps fusion should handle that case directly.
"""
def __init__(self):
super().__init__()
self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"),
op.get("dyn.reshape")]
self.reshape_only = True
self.has_reshape = False

def visit_call(self, call):
if not self.reshape_only:
return
if call.op not in self._reshape_ops:
self.reshape_only = False
self.has_reshape = True
for arg in call.args:
self.visit(arg)

@@ -60,7 +67,7 @@ def is_reshape_only(func):
"""Check if the primitive function contains only reshape ops."""
check = CheckReshapeOnly()
check.visit(func)
return check.reshape_only
return check.reshape_only and check.has_reshape


class ManifestAllocPass(ExprMutator):
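With this change is_reshape_only also requires that the fused function contain at least one reshape, so a function with no calls at all no longer qualifies. A minimal sketch of the new behavior, assuming the internal helper stays importable from the path shown in this diff:

from tvm import relay
from tvm.relay.transform.memory_alloc import is_reshape_only  # internal helper from this diff

x = relay.var("x", shape=(2, 2), dtype="float32")
reshape_func = relay.Function([x], relay.reshape(x, (4,)))  # only reshape calls
identity_func = relay.Function([x], x)                      # no calls at all

print(is_reshape_only(reshape_func))   # True: reshape_only and has_reshape
print(is_reshape_only(identity_func))  # False after this change: no reshape was seen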
12 changes: 11 additions & 1 deletion src/ir/module.cc
@@ -218,7 +218,12 @@ void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
}

void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
this->functions.Set(var, func);
auto* fn = func.as<tvm::relay::FunctionNode>();
if (fn) {
this->functions.Set(var, Downcast<tvm::relay::Function>(tvm::relay::DeDup(GetRef<tvm::relay::Function>(fn))));
} else {
this->functions.Set(var, func);
}

auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
@@ -284,6 +289,11 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
return (*it).second;
}

bool IRModuleNode::Exist(const GlobalVar& var) const {
auto it = functions.find(var);
return it != functions.end();
}

BaseFunc IRModuleNode::Lookup(const String& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
4 changes: 2 additions & 2 deletions src/relay/backend/graph_plan_memory.cc
@@ -18,7 +18,7 @@
*/

/*!
* \file relay/backend/graph_mem_alloca.cc
* \file relay/backend/graph_plan_memory.cc
* \brief Memory index assignment pass for executing
* the program in the graph runtime.
*/
@@ -68,7 +68,7 @@ class StorageAllocaBaseVisitor : public ExprVisitor {
}

void VisitExpr_(const FunctionNode* op) final {
// do not recursive into sub function.
// do not recurse into sub function.
}

void VisitExpr_(const GlobalVarNode* op) final {
32 changes: 31 additions & 1 deletion src/relay/backend/vm/compiler.cc
@@ -923,7 +923,33 @@ transform::Sequential MemoryOpt(tvm::Target host_target) {
return transform::Sequential(pass_seqs);
}

Pass CheckPrimeFunc() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
struct CheckPrimeFuncVisitor : ExprVisitor {
bool inside_primitive = false;
void VisitExpr_(const ConstantNode* op) override {
CHECK_EQ(inside_primitive, false);
}
void VisitExpr_(const FunctionNode* op) override {
if (op->HasNonzeroAttr(attr::kPrimitive)) {
CHECK_EQ(inside_primitive, false);
inside_primitive = true;
VisitExpr(op->body);
inside_primitive = false;
} else {
VisitExpr(op->body);
}
}
} vis;
vis(f);
return f;
};
return CreateFunctionPass(pass_func, 1, "CheckPrimeFunc", {});
}

IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {

Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
@@ -955,14 +981,19 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
}
*rv = false;
});

pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(transform::InlinePrimitives());


pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));

pass_seqs.push_back(transform::FoldConstant());
//pass_seqs.push_back(tvm::transform::PrintIR());
//pass_seqs.push_back(CheckPrimeFunc());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());
@@ -1003,7 +1034,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
}

void VMCompiler::PopulateGlobalMap() {
// First we populate global map.
size_t global_index = 0;
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
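The two commented-out entries above are debugging hooks: PrintIR() dumps the module between passes, and CheckPrimeFunc() asserts that no constants appear inside primitive functions. Purely as an illustration of the same kind of hook (not code from this commit), a PrintIR step can be dropped into a Python pass pipeline like this:

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + x))

# A Sequential pipeline with a PrintIR hook, mirroring the commented-out
# debugging entries in VMCompiler::OptimizeModule.
seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.FoldConstant(),
    tvm.transform.PrintIR(),   # dump the module right after FoldConstant
    relay.transform.FuseOps(),
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)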
16 changes: 16 additions & 0 deletions src/relay/transforms/fold_constant.cc
@@ -104,7 +104,23 @@ class ConstantFolder : public ExprMutator {
}
}

bool inside_primitive = false;
Expr VisitExpr_(const FunctionNode* op) final {
if (op->HasNonzeroAttr(attr::kPrimitive)) {
CHECK_EQ(inside_primitive, false);
inside_primitive = true;
auto ret = ExprMutator::VisitExpr_(op);
inside_primitive = false;
return ret;
} else {
return ExprMutator::VisitExpr_(op);
}
}

Expr VisitExpr_(const CallNode* call) final {
if (inside_primitive) {
return GetRef<Expr>(call);
}
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");

std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
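The new inside_primitive flag makes the constant folder skip the bodies of functions marked Primitive. A short sketch of the intended effect, using the conventional "Primitive" attribute (nothing below comes verbatim from this commit):

import tvm
from tvm import relay

x = relay.var("x", shape=(2,), dtype="float32")
body = x + (relay.const(1.0) + relay.const(2.0))
prim = relay.Function([x], body).with_attr("Primitive", 1)

gv = relay.GlobalVar("prim")
mod = tvm.IRModule()
mod[gv] = prim
y = relay.var("y", shape=(2,), dtype="float32")
mod["main"] = relay.Function([y], gv(y))

folded = relay.transform.FoldConstant()(mod)
# With this change, the 1.0 + 2.0 inside the Primitive function's body is
# left unfolded; folding still applies outside primitive functions.
print(folded[gv])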
113 changes: 90 additions & 23 deletions src/relay/transforms/gradient.cc
@@ -81,10 +81,10 @@ Type WithGradientType(const Type& t) {
Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
const auto* x = e.as<GlobalVarNode>();

if (mod.defined() && (x)) {
if (mod.defined() && x) {
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
return GetRef<Function>(n);
} else {
return e;
}
@@ -337,11 +337,22 @@ Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {

TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient);

Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));

struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleType({t, RelayRefType(t)});
}

Type VisitType_(const FuncTypeNode* ftn) final {
std::vector<Type> arg_types;
for (const auto& t : ftn->arg_types) {
arg_types.push_back(VisitType(t));
}
arg_types.push_back(bpt);
return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints);
}
};

Type ReverseType(const Type& t) { return ReverseADType()(t); }
@@ -436,12 +447,21 @@ Expr BPEmpty() {

struct ReverseAD : ExprMutator {
using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;

using ADGVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;
Optional<IRModule> mod;
Var bp;
std::shared_ptr<ADVarMap> ad_vars;
std::shared_ptr<ADGVarMap> ad_gvars;
const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");

explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : bp(bp), ad_vars(ad_vars) {}
explicit ReverseAD(const Optional<IRModule>& mod,
const Var& bp,
const std::shared_ptr<ADVarMap>& ad_vars,
const std::shared_ptr<ADGVarMap>& ad_gvars) :
mod(mod),
bp(bp),
ad_vars(ad_vars),
ad_gvars(ad_gvars) { }

Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
@@ -461,8 +481,7 @@ struct ReverseAD : ExprMutator {
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
// we need a new ReverseAD visitor to avoid clobbering the bp local var
auto dup_bp = ll->Push(BPEmpty());
ReverseAD dup_diff(dup_bp, ad_vars);
auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
auto dup_ad = ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x)));

TransferGrads(call->checked_type(), ret, dup_ad, ll);
ll->Push(Call(RefRead(dup_bp), {}));
@@ -498,21 +517,31 @@ struct ReverseAD : ExprMutator {
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefRead(bp));
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev =
rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return Call(bpv, {});
}),
TupleType::Empty(), {});
Expr nbp_body =
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev =
rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return Call(bpv, {});
});
Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
ll->Push(RefWrite(bp, nbp));
return ret;
});
} else if (const ConstructorNode* con = call->op.as<ConstructorNode>()) {
return ExprMutator::VisitExpr_(call);
}
else {
std::vector<Expr> args;
for (const auto& arg : call->args) {
args.push_back(VisitExpr(arg));
}
args.push_back(bp);
return Call(VisitExpr(call->op), args);
}
return ExprMutator::VisitExpr_(call);
}

Expr VisitExpr_(const ConstantNode* op) final {
@@ -528,14 +557,46 @@ struct ReverseAD : ExprMutator {
Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
if (!ad_vars->count(var_ref)) {
if (ad_vars->count(var_ref) == 0) {
auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
(*ad_vars)[var_ref] = res;
}

return ad_vars->at(var_ref);
}

Expr VisitExpr_(const GlobalVarNode* op) final {
// TODO: deriving the gradient function's name by string concatenation seems like a brittle hack.
// Maybe index the module by a rose tree of strings instead?
CHECK(mod.defined());
auto orig_gv = GetRef<GlobalVar>(op);
if (ad_gvars->count(orig_gv) == 0) {
GlobalVar gv(op->name_hint + "_grad");
(*ad_gvars)[orig_gv] = gv;
Function orig_f = Downcast<Function>(mod.value()->Lookup(GetRef<GlobalVar>(op)));
std::vector<Var> params;
for (const auto& p : orig_f->params) {
params.push_back(Downcast<Var>(VisitExpr(p)));
}
params.push_back(bp);
Expr body = VisitExpr(orig_f->body);
Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs);
std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
mod.value()->Add(gv, f);
}
return ad_gvars->at(orig_gv);
}

Expr VisitExpr_(const FunctionNode* op) final {
std::vector<Var> params;
for (const auto& var: op->params) {
params.push_back(Downcast<Var>(VisitExpr(var)));
}
auto new_bp = Var("bp", bpt);
params.push_back(new_bp);
return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body), VisitType(op->ret_type), op->type_params, op->attrs);
}

Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
};

@@ -577,12 +638,18 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
}
CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
std::vector<Expr> args;
Var bp = ll->Push(BPEmpty(), bpt);
Expr rev = ReverseAD(mod,
bp,
std::make_shared<ReverseAD::ADVarMap>(),
std::make_shared<ReverseAD::ADGVarMap>())(e);
std::vector<Expr> normal_args, args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p)))));
auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
normal_args.push_back(x);
args.push_back(x);
}
args.push_back(bp);
auto c = ll->Push(Call(rev, args));
std::function<void(const Expr&, const Type&)> init_grad;
init_grad = [&](const Expr& e, const Type& t) {
@@ -599,7 +666,7 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
init_grad(c, f->body->checked_type());
ll->Push(Call(RefRead(bp), {}));
std::vector<Expr> ret;
for (const auto& a : args) {
for (const auto& a : normal_args) {
ret.push_back(RefRead(GetField(a, 1)));
}
std::function<Expr(const Expr&, const Type&)> get_final_result;
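Together, the new GlobalVarNode and FunctionNode handlers let the reverse-mode AD walk into calls of other global functions, registering a "<name>_grad" companion in the module via ad_gvars. A rough usage sketch of what that enables (this is a work-in-progress commit, so exact behavior may still change):

import tvm
from tvm import relay

mod = tvm.IRModule()
x = relay.var("x", shape=(3,), dtype="float32")
double = relay.GlobalVar("double")
mod[double] = relay.Function([x], x + x)

y = relay.var("y", shape=(3,), dtype="float32")
mod["main"] = relay.Function([y], double(y))

# Differentiate main: with this change the call to the global function
# "double" is handled by emitting a gradient version of it into mod.
grad_func = relay.transform.gradient(mod["main"], mod=mod)
print(grad_func)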