[MetaSchedule] No explicit for spatial PrimFunc (#11534)
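This change stops the ParallelizeVectorizeUnroll schedule rule from attaching an explicit-unroll annotation to PrimFuncs that are purely spatial. It introduces a tir::IsSpatialPrimFunc analysis that checks whether every block iteration variable in a PrimFunc is data-parallel, guards the unroll branch of the rule with that check, and adds a regression test (PureSpatial) asserting that the rule produces an empty design-space trace for such a function.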
junrushao authored Jun 2, 2022
1 parent 480fa74 commit 84eb78c
Showing 4 changed files with 211 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
@@ -26,6 +26,11 @@ bool IsRootBlock(const Schedule& sch, const BlockRV& block_rv) {
return block_sref->parent == nullptr;
}

bool CheckSpatialPrimFunc(const Schedule& sch, const BlockRV& root_block_rv) {
return IsSpatialPrimFunc(
GetRef<PrimFunc>(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr)));
}

} // namespace tir
} // namespace tvm

@@ -60,7 +65,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent));
}
// Unroll
-    if (!unroll_max_steps.empty()) {
+    if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) {
int n = unroll_max_steps.size();
double prob = 1.0 / n;
Array<FloatImm> probs(n, FloatImm(DataType::Float(64), prob));
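With this guard in place, the unroll sampling and annotation step is skipped entirely whenever every block in the PrimFunc is spatial; the new PureSpatial test case further below exercises exactly this path. CheckSpatialPrimFunc itself simply resolves the root block back to its enclosing PrimFunc and delegates to the new tir::IsSpatialPrimFunc analysis declared next.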
7 changes: 7 additions & 0 deletions src/tir/schedule/analysis.h
@@ -625,6 +625,13 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref);
*/
bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref);

/*!
 * \brief Checks if all the blocks in the PrimFunc are spatial
 * \param func The PrimFunc to be checked
 * \return A boolean indicating whether all the blocks in the PrimFunc are spatial
*/
bool IsSpatialPrimFunc(const PrimFunc& func);

/*!
* \brief Checks if the rfactor or cross thread reduction is beneficial to the given block.
* \param self The schedule state.
19 changes: 19 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
@@ -1957,6 +1957,25 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref
return total_unused_block_vars >= 1;
}

bool IsSpatialPrimFunc(const PrimFunc& func) {
bool result = true;
PreOrderVisit(func->body, [&result](const ObjectRef& obj) {
if (result == false) {
return false;
}
if (const auto* block = obj.as<BlockNode>()) {
for (const IterVar& iter_var : block->iter_vars) {
if (iter_var->iter_type != IterVarType::kDataPar) {
result = false;
return false;
}
}
}
return true;
});
return result;
}

std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref) {
Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
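IsSpatialPrimFunc walks the function body with PreOrderVisit and bails out with false as soon as it encounters a block iteration variable that is not data-parallel. As a rough illustration only (the functions and buffer names below are invented for this note and are not part of the commit), the following TVMScript sketch contrasts a PrimFunc the check would classify as spatial with one it would not, because the second carries a reduction axis:

from tvm.script import tir as T


# Every block iter var is spatial ("S"): IsSpatialPrimFunc would return true,
# so the schedule rule now skips the unroll annotation for this function.
@T.prim_func
def add_one(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128,), "float32")
    B = T.match_buffer(b, (128,), "float32")
    for i in T.serial(0, 128):
        with T.block("add_one"):
            vi = T.axis.spatial(128, i)
            B[vi] = A[vi] + T.float32(1)


# The "sum" block carries a reduction ("R") iter var: IsSpatialPrimFunc would
# return false, so the rule keeps sampling and annotating unroll steps.
@T.prim_func
def row_sum(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128,), "float32")
    for i, k in T.grid(128, 128):
        with T.block("sum"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + A[vi, vk]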
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.meta_schedule.testing.schedule_rule import parallel_vectorize_unroll
from tvm.meta_schedule.testing.space_generation import check_trace
@@ -61,6 +62,164 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


# from tvm.script import tir as T
@tvm.script.ir_module
class PureSpatial:
@T.prim_func
def main(placeholder: T.Buffer[(1, 13, 13, 3, 85), "float32"], placeholder_1: T.Buffer[(1, 26, 26, 3, 85), "float32"], placeholder_2: T.Buffer[(1, 52, 52, 3, 85), "float32"], T_expand_dims: T.Buffer[(1, 80, 10647), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
T_strided_slice_with_axes = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
T_sigmoid = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
T_strided_slice_with_axes_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
T_sigmoid_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
T_multiply = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
T_reshape = T.alloc_buffer([8112, 80], dtype="float32")
T_strided_slice_with_axes_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
T_sigmoid_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
T_strided_slice_with_axes_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
T_sigmoid_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
T_multiply_1 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
T_reshape_1 = T.alloc_buffer([2028, 80], dtype="float32")
T_strided_slice_with_axes_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
T_sigmoid_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
T_strided_slice_with_axes_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
T_sigmoid_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
T_multiply_2 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
T_reshape_2 = T.alloc_buffer([507, 80], dtype="float32")
T_concat = T.alloc_buffer([10647, 80], dtype="float32")
T_transpose = T.alloc_buffer([80, 10647], dtype="float32")
for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
with T.block("T_strided_slice_with_axes"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
with T.block("T_sigmoid"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
T.writes(T_sigmoid[ax0, ax1, ax2, ax3, ax4])
T_sigmoid[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4], dtype="float32")
for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
with T.block("T_strided_slice_with_axes_1"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
T.writes(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
with T.block("T_sigmoid_1"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
T.writes(T_sigmoid_1[ax0, ax1, ax2, ax3, ax4])
T_sigmoid_1[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4], dtype="float32")
for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
with T.block("T_multiply"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(T_sigmoid[ax0, ax1, ax2, ax3, 0], T_sigmoid_1[ax0, ax1, ax2, ax3, ax4])
T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4])
T_multiply[ax0, ax1, ax2, ax3, ax4] = T_sigmoid[ax0, ax1, ax2, ax3, 0] * T_sigmoid_1[ax0, ax1, ax2, ax3, ax4]
for i0, i1 in T.grid(8112, 80):
with T.block("T_reshape"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
T.writes(T_reshape[ax0, ax1])
T_reshape[ax0, ax1] = T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
with T.block("T_strided_slice_with_axes_2"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
T.writes(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
with T.block("T_sigmoid_2"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
T.writes(T_sigmoid_2[ax0, ax1, ax2, ax3, ax4])
T_sigmoid_2[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4], dtype="float32")
for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
with T.block("T_strided_slice_with_axes_3"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
T.writes(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
with T.block("T_sigmoid_3"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
T.writes(T_sigmoid_3[ax0, ax1, ax2, ax3, ax4])
T_sigmoid_3[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4], dtype="float32")
for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
with T.block("T_multiply_1"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(T_sigmoid_2[ax0, ax1, ax2, ax3, 0], T_sigmoid_3[ax0, ax1, ax2, ax3, ax4])
T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4])
T_multiply_1[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_2[ax0, ax1, ax2, ax3, 0] * T_sigmoid_3[ax0, ax1, ax2, ax3, ax4]
for i0, i1 in T.grid(2028, 80):
with T.block("T_reshape_1"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
T.writes(T_reshape_1[ax0, ax1])
T_reshape_1[ax0, ax1] = T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
with T.block("T_strided_slice_with_axes_4"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
T.writes(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
with T.block("T_sigmoid_4"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
T.writes(T_sigmoid_4[ax0, ax1, ax2, ax3, ax4])
T_sigmoid_4[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4], dtype="float32")
for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
with T.block("T_strided_slice_with_axes_5"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
T.writes(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
with T.block("T_sigmoid_5"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
T.writes(T_sigmoid_5[ax0, ax1, ax2, ax3, ax4])
T_sigmoid_5[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4], dtype="float32")
for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
with T.block("T_multiply_2"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(T_sigmoid_4[ax0, ax1, ax2, ax3, 0], T_sigmoid_5[ax0, ax1, ax2, ax3, ax4])
T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4])
T_multiply_2[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_4[ax0, ax1, ax2, ax3, 0] * T_sigmoid_5[ax0, ax1, ax2, ax3, ax4]
for i0, i1 in T.grid(507, 80):
with T.block("T_reshape_2"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
T.writes(T_reshape_2[ax0, ax1])
T_reshape_2[ax0, ax1] = T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
for i0, i1 in T.grid(10647, 80):
with T.block("T_concat"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(T_reshape[ax0 - 2535, ax1], T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1])
T.writes(T_concat[ax0, ax1])
T_concat[ax0, ax1] = T.if_then_else(2535 <= ax0, T_reshape[ax0 - 2535, ax1], T.if_then_else(507 <= ax0, T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1], dtype="float32"), dtype="float32")
for i0, i1 in T.grid(80, 10647):
with T.block("T_transpose"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(T_concat[ax1, ax0])
T.writes(T_transpose[ax0, ax1])
T_transpose[ax0, ax1] = T_concat[ax1, ax0]
for i0, i1, i2 in T.grid(1, 80, 10647):
with T.block("T_expand_dims"):
ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
T.reads(T_transpose[ax1, ax2])
T.writes(T_expand_dims[ax0, ax1, ax2])
T_expand_dims[ax0, ax1, ax2] = T_transpose[ax1, ax2]


# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on

@@ -101,5 +260,25 @@ def test_parallel_vectorize_unroll():
check_trace(spaces, expected)


def test_parallel_vectorize_unroll_spatial():
mod = PureSpatial
target = Target("llvm --num-cores=32")
ctx = _create_context(
mod=mod,
target=target,
rule=ms.schedule_rule.ParallelizeVectorizeUnroll(
max_jobs_per_core=-1,
max_vectorize_extent=-1,
unroll_max_steps=[1, 2, 4, 8, 16, 32, 64],
unroll_explicit=True,
),
)
spaces = ctx.space_generator.generate_design_space(mod=mod)
assert len(spaces) == 1
trace = spaces[0].trace.simplified(remove_postproc=True)
assert not trace.insts
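    # With the spatial check in place the rule skips the unroll annotation, and
    # max_jobs_per_core=-1 / max_vectorize_extent=-1 leave the parallel and vectorize
    # annotations disabled here, so the simplified trace is expected to be empty.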


if __name__ == "__main__":
test_parallel_vectorize_unroll()
test_parallel_vectorize_unroll_spatial()
