From ced51108a4d3cc6d2ad5d959338fead80dda5dfc Mon Sep 17 00:00:00 2001 From: Logan Weber <36520469+weberlo@users.noreply.github.com> Date: Tue, 21 May 2019 16:34:35 -0700 Subject: [PATCH] Add `SkipVectorize` pass (#3222) --- docs/api/python/dev.rst | 1 + include/tvm/build_module.h | 4 ++++ include/tvm/ir_pass.h | 35 +++++++++++++++++++++-------------- python/tvm/build_module.py | 8 ++++++-- src/codegen/build_module.cc | 7 ++++++- src/pass/vectorize_loop.cc | 22 ++++++++++++++++++++-- 6 files changed, 58 insertions(+), 19 deletions(-) diff --git a/docs/api/python/dev.rst b/docs/api/python/dev.rst index e4b207bf4cbc..7bb938ca7517 100644 --- a/docs/api/python/dev.rst +++ b/docs/api/python/dev.rst @@ -61,6 +61,7 @@ tvm.ir_pass tvm.ir_pass.CanonicalSimplify tvm.ir_pass.StorageFlatten tvm.ir_pass.VectorizeLoop + tvm.ir_pass.SkipVectorize tvm.ir_pass.UnrollLoop tvm.ir_pass.ThreadSync tvm.ir_pass.StorageRewrite diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 208f086f86c0..7fb456c823a7 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -246,6 +246,9 @@ class BuildConfigNode : public Node { /*! \brief Whether to disable select rewriting. */ bool disable_select_rewriting = false; + /*! \brief Whether to disable loop vectorization. 
*/ + bool disable_vectorize = false; + void VisitAttrs(AttrVisitor* v) final { v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); @@ -260,6 +263,7 @@ class BuildConfigNode : public Node { v->Visit("dump_pass_ir", &dump_pass_ir); v->Visit("instrument_bound_checkers", &instrument_bound_checkers); v->Visit("disable_select_rewriting", &disable_select_rewriting); + v->Visit("disable_vectorize", &disable_vectorize); } static constexpr const char* _type_key = "BuildConfig"; diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 5ef4dc4ed9d7..e1c92e50e6ad 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -250,35 +250,42 @@ Stmt UnrollLoop(Stmt stmt, /*! * \brief vectorize the constant loops - * \param stmt The statment to be vectorized. + * \param stmt The statement to be vectorized. * \return Transformed stmt. */ Stmt VectorizeLoop(Stmt stmt); +/*! + * \brief convert vectorized loops into serialized loops + * \param stmt The statement to skip vectorization on. + * \return Transformed stmt. + */ +Stmt SkipVectorize(Stmt stmt); + /*! * \brief instruments bound checkers. -* \param stmt The statment to be instrumented. -* \return Instrumented Stmt. +* \param stmt The statement to be instrumented. +* \return Instrumented stmt. */ Stmt InstrumentBoundCheckers(Stmt stmt); /*! * \brief Inject virtual thread loops into stmt. - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \return Transformed stmt. */ Stmt InjectVirtualThread(Stmt stmt); /*! * \brief Inject prefetch instructions into stmt. - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \return Transformed stmt. */ Stmt InjectPrefetch(Stmt stmt); /*! * \brief Inject double buffer into stmt. - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \param split_loop Loop splitting factor. * \return Transformed stmt. 
*/ @@ -287,7 +294,7 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); /*! * \brief Inject copy intrinsics with optional pad. * - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \param pragma_key The pragma key for hint of copy. * \param fintrin The function with signature * @@ -308,7 +315,7 @@ Stmt InjectCopyIntrin(Stmt stmt, * Trying to share space between allocations to make * a static allocation plan when possible. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt StorageRewrite(Stmt stmt); @@ -324,7 +331,7 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop); /*! * \brief Detect and insert sync points to co-processor. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt CoProcSync(Stmt stmt); @@ -332,7 +339,7 @@ Stmt CoProcSync(Stmt stmt); /*! * \brief Lift common attrs with attr_key to outer scope. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \param attr_key The attribute key to be checked. * \return Transformed stmt. */ @@ -340,7 +347,7 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key); /*! * \brief Detect and rewrite unsafe select that contains memory access. - * \param stmt The statment to be rewritten. + * \param stmt The statement to be rewritten. * \return Transformed stmt. */ Stmt RewriteUnsafeSelect(Stmt stmt); @@ -349,7 +356,7 @@ Stmt RewriteUnsafeSelect(Stmt stmt); * \brief Lower attached storage access information. * Do this pass after all storage access analysis finish. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt LowerStorageAccessInfo(Stmt stmt); @@ -358,7 +365,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt); * \brief Decorate the stmt with a device scope, this is helpful for * hardware accelerator without thread blocks. 
* - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt DecorateDeviceScope(Stmt stmt); @@ -381,7 +388,7 @@ Stmt DecorateDeviceScope(Stmt stmt); * \return a LoweredFunc with the specified signiture. * * \note - * The function signiture have two cases + * The function signature have two cases * * let num_packed_args = len(api_args) - num_unpacked_args; * diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 120bf629a959..a28ab98fb60e 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -143,7 +143,8 @@ class BuildConfig(NodeBase): "double_buffer_split_loop": 1, "dump_pass_ir": False, "instrument_bound_checkers": False, - "disable_select_rewriting": False + "disable_select_rewriting": False, + "disable_vectorize": False } _dump_ir = DumpIR() @@ -384,7 +385,10 @@ def lower(sch, # Phase 2 if not simple_mode: stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) - stmt = ir_pass.VectorizeLoop(stmt) + if cfg.disable_vectorize: + stmt = ir_pass.SkipVectorize(stmt) + else: + stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.InjectVirtualThread(stmt) stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) stmt = ir_pass.StorageRewrite(stmt) diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 9b30ced90c4f..ac6b797d9683 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -392,7 +392,11 @@ Stmt BuildStmt(Schedule sch, if (loop_partition) { stmt = ir::LoopPartition(stmt, config->partition_const_loop); } - stmt = ir::VectorizeLoop(stmt); + if (config->disable_vectorize) { + stmt = ir::SkipVectorize(stmt); + } else { + stmt = ir::VectorizeLoop(stmt); + } stmt = ir::InjectVirtualThread(stmt); stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); stmt = ir::StorageRewrite(stmt); @@ -642,6 +646,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "dump_pass_ir=" << 
op->dump_pass_ir << ", ";
   p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
-  p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
+  p->stream << "disable_select_rewriting=" << op->disable_select_rewriting << ", ";
+  p->stream << "disable_vectorize=" << op->disable_vectorize;
   p->stream << ")";
 });
 
diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc
index f87e80c2d030..8c3d383c1529 100644
--- a/src/pass/vectorize_loop.cc
+++ b/src/pass/vectorize_loop.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -519,5 +519,28 @@ Stmt VectorizeLoop(Stmt stmt) {
   return LoopVectorizer().Mutate(stmt);
 }
 
+/*!
+ * \brief Mutator that replaces each For marked ForType::Vectorized
+ *  with an otherwise identical ForType::Serial loop.
+ */
+class VectorizeSkipper : public IRMutator {
+ public:
+  Stmt Mutate_(const For* op, const Stmt& s) final {
+    // Run the base mutator first, then inspect the For node it returns.
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<For>();
+    if (op->for_type == ForType::Vectorized) {
+      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
+                       op->body);
+    } else {
+      return stmt;
+    }
+  }
+};
+
+Stmt SkipVectorize(Stmt stmt) {
+  return VectorizeSkipper().Mutate(stmt);
+}
+
 }  // namespace ir
 }  // namespace tvm