Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[MetaSchedule][Test] Add unittests for CAP (apache#12047)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and xinetzone committed Nov 25, 2022
1 parent 8d445b8 commit 30d80ff
Show file tree
Hide file tree
Showing 2 changed files with 296 additions and 0 deletions.
194 changes: 194 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,201 @@ def c3d_2(inputs: T.Buffer[(1, 16, 224, 224, 3), "float32"], weight: T.Buffer[(7
)


def test_cpu_cap():
    """Check the CPU design space generated for the CAP (capsule conv2d) workload.

    Three expected sketches are spelled out as TVMScript prim_funcs:

    * ``cap_0`` / ``cap_1`` — SSRSRS multi-level tiling with a ``global``
      cache-write buffer, differing only in where the padded input and the
      write-back block are computed (see the ``SampleComputeLocation``
      decisions below).
    * ``cap_2`` — the variant without a cache-write stage: padding stays at
      the root and the conv block writes the output buffer directly.

    The test asserts that ``PostOrderApply`` with the default schedule rules
    reproduces exactly these sketches and sampling decisions.
    """
    # fmt: off
    @T.prim_func
    def cap_0(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[(3, 3, 4, 4, 32, 32), "float32"], conv2d_capsule_nhwijc: T.Buffer[(1, 8, 8, 4, 4, 32), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            # Root annotations injected by the default CPU schedule rules:
            # parallelism budget, no explicit unrolling, vectorize up to 64 lanes.
            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64})
            PadInput = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32")
            # Cache-write staging buffer for the conv output.
            conv2d_capsule_nhwijc_global = T.alloc_buffer([1, 8, 8, 4, 4, 32], dtype="float32")
            for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i0_1, i1_1 in T.grid(1, 2, 1, 1, 1, 1, 1, 4):
                # PadInput computed-at loop 7 (i1_1) — decision_0's SampleComputeLocation.
                for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 3, 17, 4, 4, 32):
                    with T.block("PadInput"):
                        i0 = T.axis.spatial(1, ax0)
                        i1 = T.axis.spatial(18, i1_0 * 8 + i1_1 * 2 + ax1)
                        i2 = T.axis.spatial(18, ax2)
                        i3, i4, i5 = T.axis.remap("SSS", [ax3, ax4, ax5])
                        T.reads(inputs[i0, i1 - 1, i2 - 1, i3, i4, i5])
                        T.writes(PadInput[i0, i1, i2, i3, i4, i5])
                        # 1-pixel zero padding on the two spatial dims (18 = 16 + 2*1).
                        PadInput[i0, i1, i2, i3, i4, i5] = T.if_then_else(1 <= i1 and i1 < 17 and 1 <= i2 and i2 < 17, inputs[i0, i1 - 1, i2 - 1, i3, i4, i5], T.float32(0), dtype="float32")
                for i2_1, i3_1, i4_1, i5_1 in T.grid(4, 1, 4, 2):
                    for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
                        with T.block("conv2d_capsule_nhwijc"):
                            # Spatial axes rebuilt from the 4-level tile splits.
                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
                            h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
                            w = T.axis.spatial(8, (i2_0 * 4 + i2_1) * 2 + i2_2 + i2_3)
                            cap_i = T.axis.spatial(4, (i3_0 + i3_1 + i3_2) * 4 + i3_3)
                            cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3)
                            co = T.axis.spatial(32, (i5_0 * 2 + i5_1 + i5_2) * 16 + i5_3)
                            # Reduction axes: kernel h/w, capsule contraction, input channel.
                            rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
                            rw = T.axis.reduce(3, i7_0 + i7_1)
                            cap_k = T.axis.reduce(4, i8_0 + i8_1)
                            rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
                            T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
                            T.writes(conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co])
                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                            with T.init():
                                conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = T.float32(0)
                            # Stride-2 capsule convolution accumulate.
                            conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] + PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight[rh, rw, cap_k, cap_j, rc, co]
                    # Write-back from the cache buffer, reverse-computed-at this tile.
                    for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 1, 2, 4, 1, 16):
                        with T.block("conv2d_capsule_nhwijc_global"):
                            v0 = T.axis.spatial(1, ax0)
                            v1 = T.axis.spatial(8, i1_0 * 4 + i1_1 + ax1)
                            v2 = T.axis.spatial(8, i2_1 * 2 + ax2)
                            v3 = T.axis.spatial(4, ax3)
                            v4 = T.axis.spatial(4, i4_1 + ax4)
                            v5 = T.axis.spatial(32, i5_1 * 16 + ax5)
                            T.reads(conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5])
                            T.writes(conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5])
                            conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5]
    @T.prim_func
    def cap_1(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[(3, 3, 4, 4, 32, 32), "float32"], conv2d_capsule_nhwijc: T.Buffer[(1, 8, 8, 4, 4, 32), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64})
            PadInput = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32")
            conv2d_capsule_nhwijc_global = T.alloc_buffer([1, 8, 8, 4, 4, 32], dtype="float32")
            for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0 in T.grid(1, 2, 1, 1, 1, 1):
                for i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 in T.grid(1, 4, 4, 1, 4, 2):
                    # PadInput computed-at loop 11 (i5_1) — decision_1's SampleComputeLocation.
                    for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 3, 5, 4, 4, 32):
                        with T.block("PadInput"):
                            i0 = T.axis.spatial(1, ax0)
                            i1 = T.axis.spatial(18, i1_0 * 8 + i1_1 * 2 + ax1)
                            i2 = T.axis.spatial(18, i2_1 * 4 + ax2)
                            i3, i4, i5 = T.axis.remap("SSS", [ax3, ax4, ax5])
                            T.reads(inputs[i0, i1 - 1, i2 - 1, i3, i4, i5])
                            T.writes(PadInput[i0, i1, i2, i3, i4, i5])
                            PadInput[i0, i1, i2, i3, i4, i5] = T.if_then_else(1 <= i1 and i1 < 17 and 1 <= i2 and i2 < 17, inputs[i0, i1 - 1, i2 - 1, i3, i4, i5], T.float32(0), dtype="float32")
                    for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
                        with T.block("conv2d_capsule_nhwijc"):
                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
                            h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
                            w = T.axis.spatial(8, (i2_0 * 4 + i2_1) * 2 + i2_2 + i2_3)
                            cap_i = T.axis.spatial(4, (i3_0 + i3_1 + i3_2) * 4 + i3_3)
                            cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3)
                            co = T.axis.spatial(32, (i5_0 * 2 + i5_1 + i5_2) * 16 + i5_3)
                            rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
                            rw = T.axis.reduce(3, i7_0 + i7_1)
                            cap_k = T.axis.reduce(4, i8_0 + i8_1)
                            rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
                            T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
                            T.writes(conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co])
                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                            with T.init():
                                conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = T.float32(0)
                            conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] + PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight[rh, rw, cap_k, cap_j, rc, co]
                # Write-back sits one loop level higher than in cap_0
                # (only v1 depends on an outer tile index here).
                for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 4, 8, 4, 4, 32):
                    with T.block("conv2d_capsule_nhwijc_global"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(8, i1_0 * 4 + ax1)
                        v2, v3, v4, v5 = T.axis.remap("SSSS", [ax2, ax3, ax4, ax5])
                        T.reads(conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5])
                        T.writes(conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5])
                        conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5]
    @T.prim_func
    def cap_2(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[(3, 3, 4, 4, 32, 32), "float32"], conv2d_capsule_nhwijc: T.Buffer[(1, 8, 8, 4, 4, 32), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            # Note: this sketch samples unroll_explicit=16 (decision_2's SampleCategorical).
            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
            PadInput = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32")
            # PadInput left at the root (SampleComputeLocation == -1): padding
            # is materialized fully before the fused convolution loop nest.
            for i0, i1, i2, i3, i4, i5 in T.grid(1, 18, 18, 4, 4, 32):
                with T.block("PadInput"):
                    i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5])
                    T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1, i4_1, i5_1])
                    T.writes(PadInput[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1])
                    PadInput[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1, i4_1, i5_1], T.float32(0), dtype="float32")
            # No cache-write stage: the conv block writes the output buffer directly.
            for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_1_1, i5_1_1, i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 4, 4, 1, 4, 2, 1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
                with T.block("conv2d_capsule_nhwijc"):
                    n = T.axis.spatial(1, i0_3 + i0_2 + i0_1_1 + i0_0)
                    h = T.axis.spatial(8, i1_0 * 4 + i1_1_1 + i1_2 + i1_3)
                    w = T.axis.spatial(8, (i2_0 * 4 + i2_1_1) * 2 + i2_2 + i2_3)
                    cap_i = T.axis.spatial(4, (i3_0 + i3_1_1 + i3_2) * 4 + i3_3)
                    cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1_1 + i4_2 + i4_3)
                    co = T.axis.spatial(32, (i5_0 * 2 + i5_1_1 + i5_2) * 16 + i5_3)
                    rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
                    rw = T.axis.reduce(3, i7_0 + i7_1)
                    cap_k = T.axis.reduce(4, i8_0 + i8_1)
                    rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
                    T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
                    T.writes(conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co])
                    T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                    with T.init():
                        conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co] = T.float32(0)
                    conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co] + PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight[rh, rw, cap_k, cap_j, rc, co]
    # fmt: on
    # Expected sampling decisions for each sketch: ten SamplePerfectTile splits
    # (six spatial axes tiled 4-way, four reduction axes tiled 2-way), one
    # SampleCategorical (unroll option), one SampleComputeLocation (PadInput).
    decision_0 = [
        ("SamplePerfectTile", [1, 1, 1, 1]),
        ("SamplePerfectTile", [2, 4, 1, 1]),
        ("SamplePerfectTile", [1, 4, 2, 1]),
        ("SamplePerfectTile", [1, 1, 1, 4]),
        ("SamplePerfectTile", [1, 4, 1, 1]),
        ("SamplePerfectTile", [1, 2, 1, 16]),
        ("SamplePerfectTile", [1, 3]),
        ("SamplePerfectTile", [3, 1]),
        ("SamplePerfectTile", [4, 1]),
        ("SamplePerfectTile", [1, 32]),
        ("SampleCategorical", 0),
        ("SampleComputeLocation", 7),
    ]
    decision_1 = [
        ("SamplePerfectTile", [1, 1, 1, 1]),
        ("SamplePerfectTile", [2, 4, 1, 1]),
        ("SamplePerfectTile", [1, 4, 2, 1]),
        ("SamplePerfectTile", [1, 1, 1, 4]),
        ("SamplePerfectTile", [1, 4, 1, 1]),
        ("SamplePerfectTile", [1, 2, 1, 16]),
        ("SamplePerfectTile", [1, 3]),
        ("SamplePerfectTile", [3, 1]),
        ("SamplePerfectTile", [4, 1]),
        ("SamplePerfectTile", [1, 32]),
        ("SampleCategorical", 0),
        ("SampleComputeLocation", 11),
    ]
    decision_2 = [
        ("SamplePerfectTile", [1, 1, 1, 1]),
        ("SamplePerfectTile", [2, 4, 1, 1]),
        ("SamplePerfectTile", [1, 4, 2, 1]),
        ("SamplePerfectTile", [1, 1, 1, 4]),
        ("SamplePerfectTile", [1, 4, 1, 1]),
        ("SamplePerfectTile", [1, 2, 1, 16]),
        ("SamplePerfectTile", [1, 3]),
        ("SamplePerfectTile", [3, 1]),
        ("SamplePerfectTile", [4, 1]),
        ("SamplePerfectTile", [1, 32]),
        ("SampleCategorical", 1),
        ("SampleComputeLocation", -1),
    ]
    # Generate the design space for the CAP workload and compare it against
    # the three expected sketches plus their sampling decisions.
    mod = create_te_workload("CAP", 0)
    actual = ms.TuneContext(
        mod=mod,
        target=_target(),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules="default",
    ).generate_design_space()
    check_sketches(
        mod,
        sketches=actual,
        expected_mods=[cap_0, cap_1, cap_2],
        expected_decisions=[decision_0, decision_1, decision_2],
    )


if __name__ == "__main__":
    # Run every CPU design-space test in definition order.
    for _case in (test_cpu_c1d, test_cpu_c2d, test_cpu_c3d, test_cpu_cap):
        _case()
Loading

0 comments on commit 30d80ff

Please sign in to comment.