diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 738e726aa146..5d99f6845463 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -150,12 +150,17 @@ class Postproc : public runtime::ObjectRef { * \return The postprocessor created. */ TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false); - /*! * \brief Creates a postprocessor that verifies if the GPU code is correct * \return The postprocessor created */ TVM_DLL static Postproc VerifyGPUCode(); + /*! + * \brief Creates a postprocessor that rewrites the layout of input tensor + * \note Weight layout rewrite is supported so far, activation layout rewrite will be added. + * \return The postprocessor created + */ + TVM_DLL static Postproc RewriteLayout(); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); }; diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index ff6d82a0242c..97ac323662bb 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -17,45 +17,64 @@ # pylint: disable=unused-import, redefined-builtin """ Namespace for TVM Auto-scheduler. """ -from . import compute_dag -from . import dispatcher -from . import feature -from . import loop_state -from . import measure -from . import measure_record -from . import relay_integration -from . import search_policy -from . import search_task -from . import task_scheduler -from . import utils -from . import workload_registry +from . 
import ( + compute_dag, + dispatcher, + feature, + loop_state, + measure, + measure_record, + relay_integration, + search_policy, + search_task, + task_scheduler, + utils, + workload_registry, +) # Shortcut -from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout +from .compute_dag import ( + ComputeDAG, + LayoutRewriteOption, + get_shape_from_rewritten_layout, +) from .cost_model import RandomModel, XGBModel -from .dispatcher import DispatchContext, ApplyHistoryBest, ApplyHistoryBestOrSample +from .dispatcher import ApplyHistoryBest, ApplyHistoryBestOrSample, DispatchContext from .measure import ( - MeasureInput, - MeasureResult, LocalBuilder, + LocalRPCMeasureContext, LocalRunner, + MeasureInput, + MeasureResult, RPCRunner, - LocalRPCMeasureContext, register_task_input_check_func, ) -from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records +from .measure_record import ( + RecordReader, + RecordToFile, + load_best_record, + load_records, + save_records, +) from .relay_integration import ( extract_tasks, + is_auto_scheduler_enabled, remove_index_check, rewrite_compute_body, - is_auto_scheduler_enabled, + rewrite_tensor_shape, ) -from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule from .search_policy import ( EmptyPolicy, - SketchPolicy, - PreloadMeasuredStates, PreloadCustomSketchRule, + PreloadMeasuredStates, + SketchPolicy, +) +from .search_task import ( + HardwareParams, + SearchTask, + TuningOptions, + auto_schedule, + create_task, ) from .task_scheduler import TaskScheduler -from .workload_registry import register_workload, make_workload_key +from .workload_registry import make_workload_key, register_workload diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py index 34411bde057b..ff0120538133 100644 --- a/python/tvm/meta_schedule/default_config.py +++ 
b/python/tvm/meta_schedule/default_config.py @@ -262,6 +262,7 @@ def postprocs() -> List[Postproc]: M.DisallowDynamicLoop(), M.RewriteParallelVectorizeUnroll(), M.RewriteReductionBlock(), + M.RewriteLayout(), ] @staticmethod diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 39113bb90011..f70b740d7bd7 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -15,11 +15,12 @@ # specific language governing permissions and limitations # under the License. """The tvm.meta_schedule.postproc package.""" -from .postproc import Postproc, PyPostproc from .disallow_dynamic_loop import DisallowDynamicLoop +from .postproc import Postproc, PyPostproc from .rewrite_cooperative_fetch import RewriteCooperativeFetch +from .rewrite_layout import RewriteLayout from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll from .rewrite_reduction_block import RewriteReductionBlock +from .rewrite_tensorize import RewriteTensorize from .rewrite_unbound_block import RewriteUnboundBlock from .verify_gpu_code import VerifyGPUCode -from .rewrite_tensorize import RewriteTensorize diff --git a/python/tvm/meta_schedule/postproc/rewrite_layout.py b/python/tvm/meta_schedule/postproc/rewrite_layout.py new file mode 100644 index 000000000000..10addefee542 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_layout.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A postprocessor that rewrites the layout of input tensors."""
+
+from tvm._ffi.registry import register_object
+
+from .. import _ffi_api
+from .postproc import Postproc
+
+
+@register_object("meta_schedule.RewriteLayout")
+class RewriteLayout(Postproc):
+    """A postprocessor that rewrites the layout of input tensors; it takes no arguments."""
+
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(  # build the handle via the FFI constructor below
+            _ffi_api.PostprocRewriteLayout,  # type: ignore # pylint: disable=no-member
+        )
diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
new file mode 100644
index 000000000000..f4cbdfe737fb
--- /dev/null
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Collect the reading block, read index and suggested IndexMap of each given buffer.
+ * \note Expects one BufferLoad per buffer; if there are more, the last one visited is recorded.
+ */
+class BufferReadPosCollector : public StmtExprVisitor {
+ public:
+  explicit BufferReadPosCollector(const Array<Buffer>& buffers) {
+    for (const Buffer& buf : buffers) {
+      buffers_.insert(buf.get());
+    }
+  }
+
+  const std::unordered_map<const BufferNode*, std::pair<Block, int>>& GetBufferLocations() const {
+    return buffer_locs_;
+  }
+
+  const std::unordered_map<const BufferNode*, Optional<IndexMap>>& GetBufferIndexMap() const {
+    return buffer_index_maps_;
+  }
+
+ private:
+  void VisitStmt_(const ForNode* op) final {
+    loop_stack_.push_back(GetRef<For>(op));
+    StmtVisitor::VisitStmt_(op);
+    loop_stack_.pop_back();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* op) final {
+    BlockRealize outer_block_realize = GetRef<BlockRealize>(op);
+    std::swap(outer_block_realize, cur_realize_);  // cur_realize_ = op; keep the outer one
+    StmtVisitor::VisitStmt_(op);
+    std::swap(cur_realize_, outer_block_realize);  // restore the outer BlockRealize
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    const Buffer& buffer = op->buffer;
+    if (buffers_.count(buffer.get())) {
+      Map<Var, PrimExpr> subst_map;  // block iter var -> its binding in cur_realize_
+      for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) {
+        const Var& var = cur_realize_->block->iter_vars[i]->var;
+        const PrimExpr& value = cur_realize_->iter_values[i];
+        subst_map.Set(var, value);
+      }
+      Array<PrimExpr> subst_indices;
+      for (const PrimExpr& e : op->indices) {
+        subst_indices.push_back(Substitute(e, subst_map));
+      }
+      buffer_index_maps_[buffer.get()] = SuggestIndexMap(/*buffer=*/buffer,                      //
+                                                         /*indices=*/subst_indices,              //
+                                                         /*loops=*/loop_stack_,                  //
+                                                         /*predicate=*/cur_realize_->predicate,  //
+                                                         /*analyzer=*/&analyzer_);
+      int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer);
+      ICHECK(buffer_index != -1);
+      buffer_locs_[buffer.get()] = std::make_pair(cur_realize_->block, buffer_index);
+    }
+  }
+
+  static int GetReadBufferIndex(const Block& block, const Buffer& buffer) {
+    for (size_t i = 0; i < block->reads.size(); i++) {
+      if (block->reads[i]->buffer.same_as(buffer)) {
+        return i;
+      }
+    }
+    return -1;  // `buffer` is not among block->reads
+  }
+
+ private:
+  /*! \brief The buffers whose read positions are collected. */
+  std::unordered_set<const BufferNode*> buffers_;
+  /*! \brief The result mapping from buffer to its inner-most block and read index. */
+  std::unordered_map<const BufferNode*, std::pair<Block, int>> buffer_locs_;
+  /*! \brief The result mapping from buffer to its IndexMap. */
+  std::unordered_map<const BufferNode*, Optional<IndexMap>> buffer_index_maps_;
+
+  /*! \brief Loop stack for calculating IndexMap. */
+  Array<For> loop_stack_;
+  /*! \brief Arithmetic analyzer. */
+  arith::Analyzer analyzer_;
+  /*! \brief Current BlockRealize scope, used in recursive visit */
+  BlockRealize cur_realize_;
+};
+/*! \brief Rewrite layouts of `layout_free_buffers`; false unless every buffer's read is found. */
+bool RewriteLayout(const Schedule& sch) {
+  std::vector<std::pair<StmtSRef, String>> results;  // NOTE(review): never used; type assumed
+  for (const auto& kv : sch->mod()->functions) {
+    const GlobalVar& g_var = kv.first;
+    const String& func_name = g_var->name_hint;
+    const auto* prim_func = kv.second.as<PrimFuncNode>();
+    // Only consider PrimFunc
+    if (prim_func == nullptr) {
+      continue;
+    }
+    // Only rewrite PrimFuncs with attr "layout_free_buffers"
+    Array<Integer> layout_free_buffer_index =
+        prim_func->GetAttr(attr::layout_free_buffers, Array<Integer>()).value();
+
+    Array<Buffer> layout_free_buffers;
+    for (const Integer& index : layout_free_buffer_index) {
+      const Var& param = prim_func->params[index->value];
+      layout_free_buffers.push_back(prim_func->buffer_map.at(param));
+    }
+    // Collect Buffer read positions
+    BufferReadPosCollector collector(layout_free_buffers);
+    collector(prim_func->body);
+    const auto& locations = collector.GetBufferLocations();
+    const auto& index_maps = collector.GetBufferIndexMap();
+    // Check all buffers are collected
+    if (locations.size() != layout_free_buffers.size() ||
+        index_maps.size() != layout_free_buffer_index.size()) {
+      return false;  // the read of some layout-free buffer was not collected
+    }
+
+    for (const auto& kv : locations) {
+      const Buffer& buffer = GetRef<Buffer>(kv.first);
+      const Block& block = kv.second.first;
+      int buffer_index = kv.second.second;
+
+      // Get IndexMap
+      const Optional<IndexMap> index_map = index_maps.at(buffer.get());
+      if (!index_map.defined()) {
+        continue;  // no IndexMap was suggested for this buffer
+      }
+
+      // Apply schedule
+      BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
+      BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
+      sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value());
+      sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
+    }
+  }
+  return true;
+}
+
+}  // namespace tir
+
+namespace meta_schedule {
+/*! \brief Postprocessor that applies tir::RewriteLayout to the schedule. */
+class RewriteLayoutNode : public PostprocNode {
+ public:
+  // Inherited from PostprocNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from PostprocNode
+  bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); }
+
+  static constexpr const char* _type_key = "meta_schedule.RewriteLayout";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode);
+};
+
+Postproc Postproc::RewriteLayout() {
+  auto n = make_object<RewriteLayoutNode>();
+  return Postproc(n);
+}
+
+TVM_REGISTER_NODE_TYPE(RewriteLayoutNode);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout").set_body_typed(Postproc::RewriteLayout);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
new file mode 100644
index 000000000000..b3e112e0e704
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+import tvm
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.postproc import RewriteLayout
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _target() -> Target:  # target handed to the TuneContext built below
+    return Target("cuda", host="llvm")
+
+
+def _create_context(mod, target: Target) -> TuneContext:  # context with RewriteLayout only
+    return TuneContext(
+        mod=mod,
+        target=target,
+        postprocs=[
+            RewriteLayout(),
+        ],
+        task_name="test",
+    )
+
+
+@T.prim_func
+def tir_matmul(
+    A: T.Buffer[(16, 16), "float32"],
+    B: T.Buffer[(16, 16), "float32"],
+    C: T.Buffer[(16, 16), "float32"],
+) -> None:
+    T.func_attr({"layout_free_buffers": [1]})  # parameter index 1, i.e. B
+    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
+        with T.block("matmul"):
+            vi = T.axis.S(16, i0 * 4 + i1)
+            vj = T.axis.S(16, j)
+            vk = T.axis.R(16, k0 * 4 + k1)
+            with T.init():
+                C[vi, vj] = T.float32(0)
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@T.prim_func
+def rewritten_tir_matmul(
+    A: T.Buffer[(16, 16), "float32"],
+    B: T.Buffer[(16, 16), "float32"],
+    C: T.Buffer[(16, 16), "float32"],
+) -> None:
+    T.func_attr({"layout_free_buffers": [1]})
+    B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32")  # copy of B in the new layout
+    for ax0, ax1 in T.grid(16, 16):
+        with T.block("layout_rewrite"):
+            i0, i1 = T.axis.remap("SS", [ax0, ax1])
+            T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
+            B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1]
+    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
+        with T.block("matmul"):
+            vi = T.axis.spatial(16, i0 * 4 + i1)
+            vj = T.axis.spatial(16, j)
+            vk = T.axis.reduce(16, k0 * 4 + k1)
+            with T.init():
+                C[vi, vj] = T.float32(0)
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B_reindex[vj, vk // 4, vk % 4]
+
+
+def test_layout_rewrite() -> None:  # postproc on tir_matmul must give rewritten_tir_matmul
+    target = _target()
+    ctx = _create_context(tir_matmul, target)
+    sch = tvm.tir.Schedule(tir_matmul, debug_mask="all")
+    sch.enter_postproc()
+    assert ctx.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod["main"], rewritten_tir_matmul)
+
+
+if __name__ == "__main__":
+    test_layout_rewrite()