Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
save

fix
  • Loading branch information
MarisaKirisame committed Aug 27, 2020
1 parent e35b7fc commit 3ec296d
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 6 deletions.
7 changes: 7 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ class IRModuleNode : public Object {
*/
TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;

/*!
* \brief Check if a global function exist by its variable.
* \param var The global var to lookup.
* \returns Wether the function named by the variable argument exist.
*/
TVM_DLL bool Exist(const GlobalVar& var) const;

/*!
* \brief Look up a global function by its string name
* \param name The name of the function.
Expand Down
11 changes: 9 additions & 2 deletions python/tvm/relay/transform/memory_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,25 @@ def is_primitive(call):


class CheckReshapeOnly(ExprVisitor):
"""A pass to check if the fused op contains only reshape ops."""
"""
A pass to check if the fused op contains only reshape ops.
TODO(@Jared) this is capturing the case where there is no any ops at all.
I had put a quick hack and require that it must have >= 1 reshape, but this must be masking some bigger problem.
Please fix - if you want to collapse reshape on reshape, perhaps you should do fusion as such.
"""
def __init__(self):
super().__init__()
self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"),
op.get("dyn.reshape")]
self.reshape_only = True
self.has_reshape = False

def visit_call(self, call):
if not self.reshape_only:
return
if call.op not in self._reshape_ops:
self.reshape_only = False
self.has_reshape = True
for arg in call.args:
self.visit(arg)

Expand All @@ -60,7 +67,7 @@ def is_reshape_only(func):
"""Check if the primitive function contains only reshape ops."""
check = CheckReshapeOnly()
check.visit(func)
return check.reshape_only
return check.reshape_only and check.has_reshape


class ManifestAllocPass(ExprMutator):
Expand Down
12 changes: 11 additions & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,12 @@ void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
}

void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
this->functions.Set(var, func);
auto* fn = func.as<tvm::relay::FunctionNode>();
if (fn) {
this->functions.Set(var, Downcast<tvm::relay::Function>(tvm::relay::DeDup(GetRef<tvm::relay::Function>(fn))));
} else {
this->functions.Set(var, func);
}

auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
Expand Down Expand Up @@ -284,6 +289,11 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
return (*it).second;
}

bool IRModuleNode::Exist(const GlobalVar& var) const {
auto it = functions.find(var);
return it != functions.end();
}

BaseFunc IRModuleNode::Lookup(const String& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
Expand Down
31 changes: 30 additions & 1 deletion src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,31 @@ transform::Sequential MemoryOpt(tvm::Target host_target) {
return transform::Sequential(pass_seqs);
}

Pass CheckPrimeFunc() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
struct CheckPrimeFuncVisitor : ExprVisitor {
bool inside_primitive = false;
void VisitExpr_(const ConstantNode* op) override {
CHECK_EQ(inside_primitive, false);
}
void VisitExpr_(const FunctionNode* op) override {
if (op->HasNonzeroAttr(attr::kPrimitive)) {
CHECK_EQ(inside_primitive, false);
inside_primitive = true;
VisitExpr(op->body);
inside_primitive = false;
} else {
VisitExpr(op->body);
}
}
} vis;
vis(f);
return f;
};
return CreateFunctionPass(pass_func, 1, "CheckPrimeFunc", {});
}

IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
const Target& target_host) {
Array<Pass> pass_seqs;
Expand Down Expand Up @@ -956,14 +981,19 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
}
*rv = false;
});

pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(transform::InlinePrimitives());


pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));

pass_seqs.push_back(transform::FoldConstant());
//pass_seqs.push_back(tvm::transform::PrintIR());
//pass_seqs.push_back(CheckPrimeFunc());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());
Expand Down Expand Up @@ -1004,7 +1034,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
}

void VMCompiler::PopulateGlobalMap() {
// First we populate global map.
size_t global_index = 0;
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
Expand Down
16 changes: 16 additions & 0 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,23 @@ class ConstantFolder : public ExprMutator {
}
}

bool inside_primitive = false;
Expr VisitExpr_(const FunctionNode* op) final {
if (op->HasNonzeroAttr(attr::kPrimitive)) {
CHECK_EQ(inside_primitive, false);
inside_primitive = true;
auto ret = ExprMutator::VisitExpr_(op);
inside_primitive = false;
return ret;
} else {
return ExprMutator::VisitExpr_(op);
}
}

Expr VisitExpr_(const CallNode* call) final {
if (inside_primitive) {
return GetRef<Expr>(call);
}
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");

std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
Expand Down
6 changes: 4 additions & 2 deletions src/relay/transforms/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/interpreter.h>
Expand Down Expand Up @@ -776,8 +777,9 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>

Func VisitFuncStatic(const Function& func, const Expr& var) {
CHECK(IsAtomic(var));
// todo: figure out primitive semantic
if (func->HasNonzeroAttr(attr::kPrimitive)) {
return ConstEvaluateFunc(func);
// return ConstEvaluateFunc(func);
}
std::vector<std::pair<Var, PStatic> > free_vars;
for (const auto& v : FreeVars(func)) {
Expand Down Expand Up @@ -1200,7 +1202,7 @@ namespace transform {
Pass PartialEval() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
return CreateModulePass(pass_func, 1, "PartialEval", {});
return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
}

TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval);
Expand Down

0 comments on commit 3ec296d

Please sign in to comment.