Skip to content

Commit

Permalink
Add SkipVectorize pass (apache#3222)
Browse files Browse the repository at this point in the history
  • Loading branch information
weberlo authored and Wei Chen committed Jun 26, 2019
1 parent 5a798c8 commit ced5110
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 19 deletions.
1 change: 1 addition & 0 deletions docs/api/python/dev.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ tvm.ir_pass
tvm.ir_pass.CanonicalSimplify
tvm.ir_pass.StorageFlatten
tvm.ir_pass.VectorizeLoop
tvm.ir_pass.SkipVectorize
tvm.ir_pass.UnrollLoop
tvm.ir_pass.ThreadSync
tvm.ir_pass.StorageRewrite
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable select rewriting. */
bool disable_select_rewriting = false;

/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
Expand All @@ -260,6 +263,7 @@ class BuildConfigNode : public Node {
v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
v->Visit("disable_vectorize", &disable_vectorize);
}

static constexpr const char* _type_key = "BuildConfig";
Expand Down
35 changes: 21 additions & 14 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,35 +250,42 @@ Stmt UnrollLoop(Stmt stmt,

/*!
* \brief vectorize the constant loops
* \param stmt The statment to be vectorized.
* \param stmt The statement to be vectorized.
* \return Transformed stmt.
*/
Stmt VectorizeLoop(Stmt stmt);

/*!
* \brief convert vectorized loops into serialized loops
* \param stmt The statement to skip vectorization on.
* \return Transformed stmt.
*/
Stmt SkipVectorize(Stmt stmt);

/*!
* \brief instruments bound checkers.
* \param stmt The statment to be instrumented.
* \return Instrumented Stmt.
* \param stmt The statement to be instrumented.
* \return Instrumented stmt.
*/
Stmt InstrumentBoundCheckers(Stmt stmt);

/*!
* \brief Inject virtual thread loops into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectVirtualThread(Stmt stmt);

/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectPrefetch(Stmt stmt);

/*!
* \brief Inject double buffer into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \param split_loop Loop splitting factor.
* \return Transformed stmt.
*/
Expand All @@ -287,7 +294,7 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
Expand All @@ -308,7 +315,7 @@ Stmt InjectCopyIntrin(Stmt stmt,
* Trying to share space between allocations to make
* a static allocation plan when possible.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt StorageRewrite(Stmt stmt);
Expand All @@ -324,23 +331,23 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop);
/*!
* \brief Detect and insert sync points to co-processor.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt CoProcSync(Stmt stmt);

/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \param attr_key The attribute key to be checked.
* \return Transformed stmt.
*/
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);

/*!
* \brief Detect and rewrite unsafe select that contains memory access.
* \param stmt The statment to be rewritten.
* \param stmt The statement to be rewritten.
* \return Transformed stmt.
*/
Stmt RewriteUnsafeSelect(Stmt stmt);
Expand All @@ -349,7 +356,7 @@ Stmt RewriteUnsafeSelect(Stmt stmt);
* \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LowerStorageAccessInfo(Stmt stmt);
Expand All @@ -358,7 +365,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt);
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt DecorateDeviceScope(Stmt stmt);
Expand All @@ -381,7 +388,7 @@ Stmt DecorateDeviceScope(Stmt stmt);
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
* The function signature have two cases
*
* let num_packed_args = len(api_args) - num_unpacked_args;
*
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ class BuildConfig(NodeBase):
"double_buffer_split_loop": 1,
"dump_pass_ir": False,
"instrument_bound_checkers": False,
"disable_select_rewriting": False
"disable_select_rewriting": False,
"disable_vectorize": False
}
_dump_ir = DumpIR()

Expand Down Expand Up @@ -384,7 +385,10 @@ def lower(sch,
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
stmt = ir_pass.VectorizeLoop(stmt)
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
Expand Down
7 changes: 6 additions & 1 deletion src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,11 @@ Stmt BuildStmt(Schedule sch,
if (loop_partition) {
stmt = ir::LoopPartition(stmt, config->partition_const_loop);
}
stmt = ir::VectorizeLoop(stmt);
if (config->disable_vectorize) {
stmt = ir::SkipVectorize(stmt);
} else {
stmt = ir::VectorizeLoop(stmt);
}
stmt = ir::InjectVirtualThread(stmt);
stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
stmt = ir::StorageRewrite(stmt);
Expand Down Expand Up @@ -642,6 +646,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << "disable_vectorize=" << op->disable_vectorize;
p->stream << ")";
});

Expand Down
22 changes: 20 additions & 2 deletions src/pass/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -519,5 +519,23 @@ Stmt VectorizeLoop(Stmt stmt) {
return LoopVectorizer().Mutate(stmt);
}

class VectorizeSkipper : public IRMutator {
public:
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<For>();
if (op->for_type == ForType::Vectorized) {
return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
op->body);
} else {
return stmt;
}
}
};

Stmt SkipVectorize(Stmt stmt) {
return VectorizeSkipper().Mutate(stmt);
}

} // namespace ir
} // namespace tvm

0 comments on commit ced5110

Please sign in to comment.