Skip to content

Commit

Permalink
[CodeGen] Add build config option disable_assert to control whether t…
Browse files Browse the repository at this point in the history
…o generate assert (apache#4340)
  • Loading branch information
FrozenGene authored and Xingyu Zhou committed Nov 15, 2019
1 parent a0df146 commit 41b65ef
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 2 deletions.
4 changes: 4 additions & 0 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;

/*! \brief Whether to disable assert stmt generation. */
bool disable_assert = false;

void VisitAttrs(AttrVisitor* v) {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
Expand All @@ -244,6 +247,7 @@ class BuildConfigNode : public Node {
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
v->Visit("disable_vectorize", &disable_vectorize);
v->Visit("disable_assert", &disable_assert);
}

static constexpr const char* _type_key = "BuildConfig";
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,13 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
*/
LoweredFunc InferFragment(LoweredFunc f);

/*!
* \brief skip assert stmt generation
* \param f The function to be transformed.
* \return Transformed function.
*/
LoweredFunc SkipAssert(LoweredFunc f);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ class BuildConfig(NodeBase):
"dump_pass_ir": False,
"instrument_bound_checkers": False,
"disable_select_rewriting": False,
"disable_vectorize": False
"disable_vectorize": False,
"disable_assert": False
}
_dump_ir = DumpIR()

Expand Down
1 change: 1 addition & 0 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << "disable_vectorize=" << op->disable_vectorize;
p->stream << "disable_assert=" << op->disable_assert;
p->stream << ")";
});

Expand Down
12 changes: 11 additions & 1 deletion src/codegen/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/build_module.h>
#include <dmlc/memory_io.h>
#include <sstream>
#include <iostream>
Expand All @@ -40,12 +41,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
if (pos != std::string::npos) {
mode = mode.substr(0, pos);
}
Array<LoweredFunc> transformed_funcs;
for (const auto& x : funcs) {
if (BuildConfig::Current()->disable_assert) {
auto func = ir::SkipAssert(x);
transformed_funcs.push_back(func);
}
}
std::string build_f_name = "codegen.build_" + mode;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
<< "Target " << target << " is not enabled";
runtime::Module m = (*bf)(funcs, target);
runtime::Module m = transformed_funcs.empty() ?
(*bf)(funcs, target) :
(*bf)(transformed_funcs, target);
return m;
}

Expand Down
47 changes: 47 additions & 0 deletions src/pass/skip_assert.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>

namespace tvm {
namespace ir {

class AssertSkipper : public IRMutator {
public:
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AssertStmt>();
return op->body;
}
};

Stmt SkipAssert(Stmt stmt) {
return AssertSkipper().Mutate(stmt);
}

LoweredFunc SkipAssert(LoweredFunc f) {
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = SkipAssert(f->body);
return LoweredFunc(n);
}

} // namespace ir
} // namespace tvm

0 comments on commit 41b65ef

Please sign in to comment.