diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index 6f26d07dc8a5..17aedbcff308 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -251,6 +251,11 @@ class Stage : public ObjectRef { * \return reference to self. */ TVM_DLL Stage& double_buffer(); // NOLINT(*) + /*! + * \brief Compute current stage with rolling buffering. + * \return reference to self. + */ + TVM_DLL Stage& rolling_buffer(); // NOLINT(*) /*! * \brief whether the stage has been scheduled. * \return whether the stage has been scheduled. @@ -493,6 +498,8 @@ class StageNode : public Object { bool is_output{false}; /*! \brief Whether apply double buffer optimization to this stage */ bool double_buffer{false}; + /*! \brief Whether apply rolling buffer optimization to this stage */ + bool rolling_buffer{false}; /*! * \brief The parent group of the current stage. * The stage cannot be assigned to stages outside the group. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4f5772822d9e..066496704e5f 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1284,6 +1284,8 @@ constexpr const char* double_buffer_scope = "double_buffer_scope"; * \brief Marks region used by double buffer write */ constexpr const char* double_buffer_write = "double_buffer_write"; +/*! \brief Mark realization for rolling buffer optimization */ +constexpr const char* rolling_buffer_scope = "rolling_buffer_scope"; /*! \brief Mark of scan update scope */ constexpr const char* scan_update_scope = "scan_update_scope"; /*! \brief Mark of scan init scope */ diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index 7bd7dceb03e5..55d07a57e3e4 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -511,6 +511,14 @@ def double_buffer(self): """ _ffi_api.StageDoubleBuffer(self) + def rolling_buffer(self): + """Compute the current stage via rolling buffering. + + This can only be applied to intermediate stage. + This will change the storage cost of the current stage. + """ + _ffi_api.StageRollingBuffer(self) + @tvm._ffi.register_object class SpecializedCondition(Object): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 722810e9aa5b..69e8821f7423 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -206,6 +206,17 @@ def InjectDoubleBuffer(): return _ffi_api.InjectDoubleBuffer() # type: ignore +def InjectRollingBuffer(): + """Inject rolling buffer statements. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectRollingBuffer() # type: ignore + + def StorageRewrite(): """Rewrite storage allocation pattern. diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index c73a6e0ce120..c3062045939a 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -484,7 +484,8 @@ ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Sta } ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap), debug_keep_trivial_loop); - ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter); + ret.init_predicates = + MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter); for (auto& e : ret.init_predicates) { e = likely(e); } diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 5d71c5345fd0..2f74d2905454 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -423,6 +423,13 @@ Stage& Stage::double_buffer() { return *this; } +Stage& Stage::rolling_buffer() { + StageNode* self = operator->(); + ICHECK(!self->is_output) << "Cannot apply rolling buffer on output"; + self->rolling_buffer = true; + return *this; +} + Stage CopyStage(const Stage& s) { ObjectPtr n = make_object(*s.operator->()); return Stage(n); @@ -886,6 +893,8 @@ TVM_REGISTER_GLOBAL("te.StageStorageAlign").set_body_method(&Stage::storage_alig TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffer); +TVM_REGISTER_GLOBAL("te.StageRollingBuffer").set_body_method(&Stage::rolling_buffer); + TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize); TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group); diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 825092d20ac0..1568df4670af 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -52,6 +52,10 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ pipeline = SeqStmt({producer, consumer}); } + if (s->rolling_buffer) { + pipeline = AttrStmt(s->op, tir::attr::rolling_buffer_scope, Bool(true), pipeline); + } + return s->op->BuildRealize(s, dom_map, pipeline, s->scope); } diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 439d0ff17255..7e8b12b6d61e 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -67,7 +67,8 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - if (op->attr_key == tir::attr::double_buffer_scope) { + if (op->attr_key == tir::attr::double_buffer_scope || + op->attr_key == tir::attr::rolling_buffer_scope) { Stmt body = op->body; Operation operation = Downcast(op->node); for (int i = operation->num_outputs(); i != 0; --i) { diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc new file mode 100644 index 000000000000..bc4012cf0556 --- /dev/null +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file inject_rolling_buffer.cc + * \brief Inject rolling buffer statements. + + Rolling buffers are buffers where one of the dimensions has been made into + a circular buffer. Two optimizations are implemented in order to accomplish + this: sliding window and storage folding. In particular, the sliding window + optimization is applied to the entire buffer (to avoid recomputing elements) + and storage folding is then applied to just the rolling dimension. + + Rolling buffers must be inside a loop with only part of the buffer used per + iteration. The outermost axis will be rolled over. + + For more information, see the RFC: + https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836 + */ +#include +#include +#include +#include + +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +using arith::IntSet; + +struct RollingBufferInfo { + int rolling_axis; + int rolling_extent; + std::vector axis_overlaps; + std::vector> axis_iter_vars; +}; + +class RollingBufferInjector : public StmtExprMutator { + std::vector for_loops{}; + std::set rolling_buffers{}; + std::map buffer_to_buffer_realize{}; + std::map> buffer_to_attrs{}; + std::map rolling_buffer_to_info{}; + // The actual key type is Var, ObjectRef has been used because + // of the ambiguous overload for ‘operator<’ + std::map> hoist_buffer_to_for{}; + + public: + RollingBufferInjector() {} + + Stmt Inject(Stmt stmt) { return ConvertSSA(operator()(std::move(stmt))); } + + Stmt VisitStmt_(const ForNode* op) final { + // Manage the stack of iter_vars + for_loops.push_back(GetRef(op)); + + auto stmt{StmtExprMutator::VisitStmt_(op)}; + op = stmt.as(); + + // Manage the stack of iter_vars + for_loops.pop_back(); + + auto it{hoist_buffer_to_for.find(op->loop_var)}; + if (it != hoist_buffer_to_for.end()) { + // If the loop corresponds to an iter_var that needs a BufferRealize + // hoisting to its scope, perform the hoisting + Stmt body{GetRef(op)}; + for (auto realise : it->second) { + auto attrs{buffer_to_attrs[realise->buffer]}; + Stmt new_realize{BufferRealize(realise->buffer, realise->bounds, realise->condition, body, + realise->span)}; + // The attributes attached to the BufferRealize need hoisting too + for (auto attr : attrs) { + if (attr->attr_key == attr::rolling_buffer_scope) { + continue; + } + new_realize = AttrStmt(attr->node, attr->attr_key, attr->value, new_realize, attr->span); + } + body = new_realize; + } + return body; + } else { + return stmt; + } + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (auto b = op->node.as()) { + auto buffer = GetRef(b); + // Keep a dictionary associating attribute statements with the buffers + // they reference. We'll need this if the buffer gets hoisted and we + // need to hoist all of its attributes at the same time. + buffer_to_attrs[buffer].push_back(GetRef(op)); + + if (op->attr_key == attr::rolling_buffer_scope && Downcast(op->value)->value) { + // If the attribute is indicating that a buffer should be a rolling + // buffer, then update the rolling_buffers set to include the buffer + rolling_buffers.insert(buffer); + + auto it{buffer_to_buffer_realize.find(buffer)}; + ICHECK(it != buffer_to_buffer_realize.end()) + << "Rolling buffer injection failed: no BufferRealize found"; + BufferRealize buffer_realize = it->second; + + // If a BufferRealize has been identified as needing to be made into + // a rolling buffer, begin the analysis. + std::vector> bound_iter_vars{}; + std::vector bound_overlaps{}; + // We use the bound information of the BufferRealize to calculate + // how we can legally roll + auto stride{0}; + Optional iter_var{}; + for (auto bound : buffer_realize->bounds) { + if (auto floor_div = bound->min.as()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + auto divisor{Downcast(floor_div->b)->value}; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + stride = std::ceil(stride / divisor); + } + if (bound->min.as()) { + // If the bound is an int, we can't roll over it + iter_var = nullptr; + } else if (auto var = bound->min.as()) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = GetRef(var); + stride = 1; + } else { + // Otherwise, it's the iter var multiplied by the stride + // If not we're in unknown behaviour, so assert + auto mul = bound->min.as(); + ICHECK(mul) << "Rolling buffer injection failed: the buffer striding is unsupported"; + auto a = mul->a.as(); + ICHECK(a) << "Rolling buffer injection failed: the buffer striding is unsupported"; + auto b = mul->b.as(); + ICHECK(b) << "Rolling buffer injection failed: the buffer striding is unsupported"; + iter_var = GetRef(a); + stride = b->value; + } + bound_iter_vars.push_back(iter_var); + if (iter_var) { + bound_overlaps.push_back(Downcast(bound->extent)->value - stride); + } else { + bound_overlaps.push_back(0); + } + } + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional roll_iter_var{}; + int roll_axis{1}; + for (auto loop : for_loops) { + auto loop_var{loop->loop_var}; + iter_var = loop_var; + + auto it{std::find_if( + bound_iter_vars.begin(), bound_iter_vars.end(), + [&](Optional var) { return var && (var.value().get() == loop_var.get()); })}; + + if (it != bound_iter_vars.end()) { + auto i{std::distance(bound_iter_vars.begin(), it)}; + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + // We must have found an axis to roll over + ICHECK(roll_iter_var) << "Rolling buffer injection failed: no rolling axis found"; + ICHECK(roll_axis != -1) << "Rolling buffer injection failed: no rolling axis found"; + + RollingBufferInfo rolling_buffer_info = { + roll_axis, + static_cast(Downcast(buffer_realize->bounds[roll_axis]->extent)->value), + bound_overlaps, + bound_iter_vars, + }; + rolling_buffer_to_info[buffer] = rolling_buffer_info; + Array new_bounds{}; + auto shape{buffer->shape}; + for (size_t i{0}; i < shape.size(); ++i) { + auto extent{shape[i]}; + if (static_cast(i) == rolling_buffer_info.rolling_axis) { + new_bounds.push_back(Range(0, rolling_buffer_info.rolling_extent)); + } else { + new_bounds.push_back(Range(0, extent)); + } + } + BufferRealize new_realize{BufferRealize(buffer, new_bounds, buffer_realize->condition, + buffer_realize->body, buffer_realize->span)}; + hoist_buffer_to_for[iter_var.value()].push_back(new_realize); + } + } + + auto stmt{StmtExprMutator::VisitStmt_(op)}; + op = stmt.as(); + + if (rolling_buffers.count(GetRef(op->node.as()))) { + // Remove the attribute statements attached to rolling buffers + // because they will have been hoisted to the relevant rolling + // scope + return op->body; + } else { + return stmt; + } + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + buffer_to_buffer_realize.insert({op->buffer, GetRef(op)}); + + auto stmt{StmtExprMutator::VisitStmt_(op)}; + op = stmt.as(); + + if (rolling_buffers.count(op->buffer)) { + // Remove the original BufferRealize for rolling buffers + // because they will have been hoisted to the relevant rolling + // scope + return op->body; + } else { + return stmt; + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto stmt{StmtExprMutator::VisitStmt_(op)}; + op = stmt.as(); + + auto it{rolling_buffer_to_info.find(op->buffer)}; + if (it != rolling_buffer_to_info.end()) { + auto rolling_buffer_info{it->second}; + std::vector indices{}; + // First modify the access indices to use modulo arithmetic + // for the rolling axis + for (size_t i{0}; i < op->indices.size(); ++i) { + auto index{op->indices[i]}; + if (static_cast(i) == rolling_buffer_info.rolling_axis) { + indices.push_back(FloorMod(index, rolling_buffer_info.rolling_extent)); + } else { + indices.push_back(index); + } + } + Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->span); + // Then wrap the BufferStores in some Ifs to avoid recomputing elements + for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) { + auto iter_var{rolling_buffer_info.axis_iter_vars[i]}; + if (iter_var && rolling_buffer_info.axis_overlaps[i] > 0) { + Var var{iter_var.value()}; + const Map dmap{std::make_pair(var, IntSet::Interval(0, 0))}; + auto term_2{arith::Analyzer{}.int_set(op->indices[i], dmap).min()}; + buffer_store = IfThenElse( + Or(LT(var, 1), GE(term_2, rolling_buffer_info.axis_overlaps[i])), buffer_store); + } + } + return buffer_store; + } else { + return stmt; + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto expr{StmtExprMutator::VisitExpr_(op)}; + op = expr.as(); + + auto it{rolling_buffer_to_info.find(op->buffer)}; + if (it != rolling_buffer_to_info.end()) { + auto rolling_buffer_info{it->second}; + std::vector indices{}; + // Modify the access indices to use modulo arithmetic + // for the rolling axis + for (size_t i{0}; i < op->indices.size(); ++i) { + auto index{op->indices[i]}; + if (static_cast(i) == rolling_buffer_info.rolling_axis) { + indices.push_back(FloorMod(index, rolling_buffer_info.rolling_extent)); + } else { + indices.push_back(index); + } + } + return BufferLoad(op->buffer, indices, op->span); + } else { + return expr; + } + } +}; // namespace tir + +namespace transform { + +Pass InjectRollingBuffer() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = RollingBufferInjector().Inject(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectRollingBuffer", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectRollingBuffer").set_body_typed(InjectRollingBuffer); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py new file mode 100644 index 000000000000..2298fe94da18 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py @@ -0,0 +1,265 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.script +from tvm.script import tir as T +from tvm import te +from tvm import topi +from tvm.driver.build_module import get_binds +import numpy as np + +import pytest + + +def _tile_nd(s, tensor, tile): + outer_indices = [] + inner_indices = [] + for i, size in enumerate(tile): + outer, inner = s[tensor].split(tensor.op.axis[i], size) + outer_indices.append(outer) + inner_indices.append(inner) + + s[tensor].reorder(*outer_indices, *inner_indices) + return outer_indices, inner_indices + + +def _lower_schedule(sch, args): + sch = sch.normalize() + bounds = tvm.te.schedule.InferBound(sch) + stmt = tvm.te.schedule.ScheduleOps(sch, bounds) + + compact = tvm.te.schedule.VerifyCompactBuffer(stmt) + binds, arg_list = get_binds(args, compact, None) + func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) + + func = func.with_attr("global_symbol", "main") + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({"main": func}) + return mod + + +def _verify_schedule(sch, inputs, output): + mod = _lower_schedule(sch, inputs + [output]) + mods = [] + mods.append(mod) + mod = tvm.tir.transform.InjectRollingBuffer()(mod) + + def _check(stmt): + if isinstance(stmt, tvm.tir.AttrStmt): + assert stmt.attr_key != "rolling_buffer_scope", "Failed to lower rolling buffers" + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _check) + mods.append(mod) + + outputs = [] + ctx = tvm.cpu(0) + input_data = [] + for tensor in inputs: + shape = [i.value for i in tensor.shape] + input_data.append( + tvm.nd.array(np.random.randint(low=-100, high=100, size=shape).astype("int8"), ctx) + ) + shape = [i.value for i in output.shape] + out = tvm.nd.array(np.zeros(shape, dtype="int8"), ctx) + for mod in mods: + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.NarrowDataType(32)(mod) + mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + # Build for CPU execution + f = tvm.build(mod) + f(*input_data, out) + outputs.append(out.asnumpy()) + + np.testing.assert_equal(outputs[0], outputs[1]) + + +@pytest.mark.parametrize("tile_shape", [(1, 4, 8, 16), (1, 8, 7, 11), (1, 8, 3, 8), (1, 7, 5, 3)]) +def test_tile_shapes(tile_shape): + A = te.placeholder((1, 12, 14, 16), name="A", dtype="int8") + pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + pool_b = topi.nn.pool2d(pool_a, (3, 5), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + + sch = tvm.te.create_schedule([pool_b.op]) + oi, ii = _tile_nd(sch, pool_b, tile_shape) + sch[pool_a].compute_at(sch[pool_b], oi[-1]) + sch[pool_a].rolling_buffer() + + _verify_schedule(sch, [A], pool_b) + + +def test_implied_split(): + A = te.placeholder((1, 12, 12, 16), name="A", dtype="int8") + pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + pool_b = topi.nn.pool2d(pool_a, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + + sch = tvm.te.create_schedule([pool_b.op]) + n, h, w, c = pool_b.op.axis + oi, ii = sch[pool_b].split(w, 4) + sch[pool_a].compute_at(sch[pool_b], oi) + sch[pool_a].rolling_buffer() + + _verify_schedule(sch, [A], pool_b) + + +def test_upscale(): + A = te.placeholder((1, 12, 12, 16), name="A", dtype="int8") + pool = topi.nn.pool2d(A, (1, 1), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + upscale = te.compute((1, 24, 24, 16), lambda nn, hh, ww, cc: pool[nn, hh // 2, ww // 2, cc]) + + sch = tvm.te.create_schedule([upscale.op]) + oi, ii = _tile_nd(sch, upscale, (1, 5, 5, 16)) + sch[pool].compute_at(sch[upscale], oi[-1]) + sch[pool].rolling_buffer() + + _verify_schedule(sch, [A], upscale) + + +@pytest.mark.parametrize("tile_shape", [(1, 4, 8, 16), (1, 8, 7, 11), (1, 8, 3, 8), (1, 7, 5, 3)]) +def test_3_tiled_poolings(tile_shape): + A = te.placeholder((1, 14, 14, 16), name="A", dtype="int8") + pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + pool_b = topi.nn.pool2d(pool_a, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + pool_c = topi.nn.pool2d(pool_b, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + + sch = tvm.te.create_schedule([pool_c.op]) + oi, ii = _tile_nd(sch, pool_c, tile_shape) + sch[pool_b].compute_at(sch[pool_c], oi[-1]) + sch[pool_b].rolling_buffer() + sch[pool_a].compute_at(sch[pool_c], oi[-1]) + sch[pool_a].rolling_buffer() + + _verify_schedule(sch, [A], pool_c) + + +@pytest.mark.parametrize("tile_shape", [(1, 4, 8, 16), (1, 8, 7, 11), (1, 8, 3, 8), (1, 7, 5, 3)]) +def test_tiled_added_poolings(tile_shape): + A = te.placeholder((1, 12, 12, 16), name="A", dtype="int8") + B = te.placeholder((1, 14, 14, 16), name="A", dtype="int8") + pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + pool_b = topi.nn.pool2d(B, (5, 5), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + add = topi.add(pool_a, pool_b) + pool_c = topi.nn.pool2d(add, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + + sch = tvm.te.create_schedule([pool_c.op]) + oi, ii = _tile_nd(sch, pool_c, tile_shape) + sch[add].compute_at(sch[pool_c], oi[-1]) + sch[add].rolling_buffer() + sch[pool_b].compute_at(sch[pool_c], oi[-1]) + sch[pool_b].rolling_buffer() + sch[pool_a].compute_at(sch[pool_c], oi[-1]) + sch[pool_a].rolling_buffer() + + _verify_schedule(sch, [A, B], pool_c) + + +@pytest.mark.parametrize("make_rolling", [(0, 0), (1, 0), (0, 1), (1, 1)]) +def test_mixed_buffers(make_rolling): + A = te.placeholder((1, 14, 14, 16), name="A", dtype="int8") + pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + pool_b = topi.nn.pool2d(pool_a, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + pool_c = topi.nn.pool2d(pool_b, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC") + + sch = tvm.te.create_schedule([pool_c.op]) + oi, ii = _tile_nd(sch, pool_c, (1, 4, 8, 16)) + sch[pool_b].compute_at(sch[pool_c], oi[-1]) + if make_rolling[0]: + sch[pool_b].rolling_buffer() + sch[pool_a].compute_at(sch[pool_c], oi[-1]) + if make_rolling[1]: + sch[pool_a].rolling_buffer() + + _verify_schedule(sch, [A], pool_c) + + +# fmt: off +@tvm.script.ir_module +class PreRollingBuffer: + @T.prim_func + def main(A: T.handle, tensor: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + tensor_2 = T.buffer_decl([1, 10, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + A_1 = T.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + tensor_1 = T.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "") + for ax1_outer in T.serial(0, 2): + T.realize(tensor_2[0:1, (ax1_outer*4):((ax1_outer*4) + 6), 0:12, 0:16], "") + T.attr(tensor_2, "rolling_buffer_scope", True) + for ax1 in T.serial(0, 6): + for ax2 in T.serial(0, 12): + for ax3 in T.serial(0, 16): + tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = T.int8(0) + for dh in T.serial(0, 3): + for dw in T.serial(0, 3): + tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = T.max(tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3], A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3]) + for ax1_inner in T.serial(0, 4): + for ax2_inner in T.serial(0, 8): + for ax3_inner in T.serial(0, 16): + tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.int8(0) + for dh_1 in T.serial(0, 3): + for dw_1 in T.serial(0, 5): + tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.max(tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner], tensor_2[0, ((ax1_inner + (ax1_outer*4)) + dh_1), (ax2_inner + dw_1), ax3_inner]) + __tvm_meta__ = None + + +@tvm.script.ir_module +class PostRollingBuffer: + @T.prim_func + def main(A: T.handle, tensor: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + tensor_2 = T.buffer_decl([1, 10, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + A_1 = T.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + tensor_1 = T.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "") + T.realize(tensor_2[0:1, 0:6, 0:12, 0:16], "") + for ax1_outer in T.serial(0, 2): + for ax1 in T.serial(0, 6): + for ax2 in T.serial(0, 12): + for ax3 in T.serial(0, 16): + if ((ax1_outer < 1) or (ax1 >= 2)): + tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = T.int8(0) + for dh in T.serial(0, 3): + for dw in T.serial(0, 3): + if ((ax1_outer < 1) or (ax1 >= 2)): + tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = T.max(tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3], A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3]) + for ax1_inner in T.serial(0, 4): + for ax2_inner in T.serial(0, 8): + for ax3_inner in T.serial(0, 16): + tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.int8(0) + for dh_1 in T.serial(0, 3): + for dw_1 in T.serial(0, 5): + tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.max(tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner], tensor_2[0, T.floormod(((ax1_inner + (ax1_outer*4)) + dh_1), 6), (ax2_inner + dw_1), ax3_inner]) + __tvm_meta__ = None +# fmt: on + + +def test_rolling_buffer_ir_transform(): + mod = PreRollingBuffer + mod = tvm.tir.transform.InjectRollingBuffer()(mod) + script = mod.script(show_meta=True) + mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(mod["main"], PostRollingBuffer["main"], True) + + +if __name__ == "__main__": + pytest.main([__file__])