Skip to content

Commit

Permalink
revert apache#9880 and add more testcases
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif committed Jan 30, 2022
1 parent 579d8d2 commit 0280297
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 45 deletions.
1 change: 1 addition & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ class IterMapRewriter : public ExprMutator {
if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base;
if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base;
}
if (expr->args.size() < 1) return expr;
Optional<IterSumExpr> opt = TryFuseIters(expr);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
Expand Down
16 changes: 2 additions & 14 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,10 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
ExprVisitor::VisitExpr_(op);
}

// Relax a single buffer access index into the integer set it may cover.
//
// The index is first evaluated over dom_map_. If hint_map_ holds any entries,
// that result is evaluated again over hint_map_ and the two sets are
// intersected, so bound hints on non-relaxed variables are taken into
// consideration.
// Example: for i * 4 + j with i >= 10 and j in [0, 4), where only j is in the
// domain scope, the index region is relaxed to [i*4, i*4+4) ^ [40, inf).
arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) {
  arith::IntSet dom_relaxed = arith::EvalSet(index, dom_map_);
  if (hint_map_.empty()) {
    // No hints recorded: the domain-relaxed set is the final answer.
    return dom_relaxed;
  }
  arith::IntSet hint_bound = arith::EvalSet(dom_relaxed, hint_map_);
  return arith::Intersect({dom_relaxed, hint_bound});
}

void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(RelaxAccessIndex(index));
relaxed_region.push_back(arith::EvalSet(index, dom_map_));
}
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
ExprVisitor::VisitExpr_(op);
Expand Down Expand Up @@ -213,7 +201,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(RelaxAccessIndex(index));
relaxed_region.push_back(arith::EvalSet(index, dom_map_));
}
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
StmtVisitor::VisitStmt_(op);
Expand Down
14 changes: 8 additions & 6 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,25 @@ class ScopeReconstructor : private StmtMutator {
PrimExpr predicate = const_true();
for (int i = 0; i < n_iters; ++i) {
Range iter_dom = iter_doms[i].dom.CoverRange(block_->iter_vars[i]->dom);
const arith::IntSet& pred_bound = iter_doms[i].bound;
if (preserve_unit_loops || !is_one(iter_dom->extent) || !pred_bound.IsNothing()) {
if (preserve_unit_loops || !is_one(iter_dom->extent)) {
Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32));
loop_vars.push_back(var);
loop_extents.push_back(iter_dom->extent);
iter_values.push_back(iter_dom->min + var);
analyzer->Bind(var, Range::FromMinExtent(0, iter_dom->extent));
} else {
iter_values.push_back(iter_dom->min);
}
const arith::IntSet& pred_bound = iter_doms[i].bound;
if (!pred_bound.IsNothing()) {
if (pred_bound.HasLowerBound()) {
PrimExpr lower_bound = iter_dom->min + var >= pred_bound.min();
PrimExpr lower_bound = iter_values[i] >= pred_bound.min();
predicate = predicate && lower_bound;
}
if (pred_bound.HasUpperBound()) {
PrimExpr upper_bound = iter_dom->min + var < pred_bound.max() + 1;
PrimExpr upper_bound = iter_values[i] < pred_bound.max() + 1;
predicate = predicate && upper_bound;
}
} else {
iter_values.push_back(iter_dom->min);
}
}
this->new_block_realize_ =
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ def test_predicate():
)
assert len(res) == 0

# irrelevant predicate
res = tvm.arith.detect_iter_map(
[i + j],
var_dom([(i, 1)]),
j <= 24,
)
assert_iter_sum_pattern(res[0], 1, j)

# constraint on nested fused iters
res = tvm.arith.detect_iter_map(
[i * 8 + j * 2 + k],
Expand Down
20 changes: 3 additions & 17 deletions tests/python/unittest/test_tir_analysis_get_block_access_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,29 +138,15 @@ def access_of_padding_pattern() -> None:
for i, j in T.grid(32, 32):
with T.block("padding"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(
[
X[
T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
]
]
)
T.reads([X[vi - 2, vj - 2]])
T.writes([X_pad[vi, vj]])
X_pad[vi, vj] = T.if_then_else(
2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
)
with T.block("padding_reverse"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]])
T.writes(
[
Y[
T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
]
]
)
T.reads([X_pad[vi, vj]])
T.writes([Y[vi - 2, vj - 2]])
if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
Y[vi - 2, vj - 2] = X_pad[vi, vj]

Expand Down
115 changes: 107 additions & 8 deletions tests/python/unittest/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,16 +765,12 @@ def tiled_pooling_read_cache(a: T.handle, b: T.handle) -> None:
for hh, ww in T.grid(224, 224):
with T.block("cache"):
h, w = T.axis.remap("SS", [hh, ww])
T.reads([X[h, w]])
T.writes([cache[h, w]])
cache[h, w] = X[h, w]
for hh_0, ww_0, hh_1, ww_1, khh, kww in T.grid(28, 28, 8, 8, 3, 3):
with T.block("compute"):
h = T.axis.spatial(224, hh_0 * 8 + hh_1)
w = T.axis.spatial(224, ww_0 * 8 + ww_1)
kh, kw = T.axis.remap("RR", [khh, kww])
T.reads([Y[h, w], cache[h + kh - 1, w + kw - 1]])
T.writes([Y[h, w]])
with T.init():
Y[h, w] = 0.0
Y[h, w] = T.max(Y[h, w], T.if_then_else(
Expand All @@ -795,16 +791,12 @@ def tiled_pooling_read_cache_after_compute_at(a: T.handle, b: T.handle) -> None:
h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225)
T.reads([X[h, w]])
T.writes([cache[h, w]])
cache[h, w] = X[h, w]
for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
with T.block("compute"):
h = T.axis.spatial(224, hh_0 * 8 + hh_1)
w = T.axis.spatial(224, ww_0 * 8 + ww_1)
kh, kw = T.axis.remap("RR", [khh, kww])
T.reads([Y[h, w], cache[h + kh - 1, w + kw - 1]])
T.writes([Y[h, w]])
with T.init():
Y[h, w] = 0.0
Y[h, w] = T.max(Y[h, w], T.if_then_else(
Expand All @@ -814,6 +806,93 @@ def tiled_pooling_read_cache_after_compute_at(a: T.handle, b: T.handle) -> None:
T.likely(w + kw < 225, dtype="bool"),
cache[h + kh - 1, w + kw - 1], 0.0, dtype="float32"))

# Fixture (input of the non-uniform tiling compute_at test).
# A "cache" block copies x into x_global element by element, then a "compute"
# block accumulates x_global * w into y. The output height/width (98) is
# iterated as 7 outer x 15 inner steps, which does not divide evenly.
@T.prim_func
def non_uniform_tiled_conv(x: T.Buffer[(1, 3, 100, 100), "float32"],
                           w: T.Buffer[(16, 3, 3, 3), "float32"],
                           y: T.Buffer[(1, 16, 98, 98), "float32"]) -> None:
    x_global = T.alloc_buffer([1, 3, 100, 100], dtype="float32")
    for ax0, ax1, ax2, ax3 in T.grid(1, 3, 100, 100):
        with T.block("cache"):
            v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
    for h_o, w_o, n, c_o, h_i, w_i, c_i, kh, kw in T.grid(7, 7, 1, 16, 15, 15, 3, 3, 3):
        with T.block("compute"):
            nn = T.axis.spatial(1, 0)
            cc = T.axis.spatial(16, c_o)
            hh = T.axis.spatial(98, h_o * 15 + h_i)
            ww = T.axis.spatial(98, w_o * 15 + w_i)
            rc, rh, rw = T.axis.remap("RRR", [c_i, kh, kw])
            # 7 * 15 = 105 exceeds the extent 98, so the tail iterations are masked out.
            T.where(h_o * 15 + h_i < 98 and w_o * 15 + w_i < 98)
            with T.init():
                y[nn, cc, hh, ww] = T.float32(0)
            y[nn, cc, hh, ww] = y[nn, cc, hh, ww] + \
                x_global[nn, cc // 16 * 3 + rc, hh + rh, ww + rw] * w[cc, rc, rh, rw]

# Fixture (expected output of the non-uniform tiling compute_at test).
# The "cache" block now sits under the (h_o, w_o) loops and iterates a
# 3 x 17 x 17 region per tile; the "compute" block is unchanged apart from
# losing its two outer loops.
@T.prim_func
def non_uniform_tiled_conv_after_compute_at(x: T.Buffer[(1, 3, 100, 100), "float32"],
                                            w: T.Buffer[(16, 3, 3, 3), "float32"],
                                            y: T.Buffer[(1, 16, 98, 98), "float32"]) -> None:
    x_global = T.alloc_buffer([1, 3, 100, 100], dtype="float32")
    for h_o, w_o in T.grid(7, 7):
        for ax0, ax1, ax2 in T.grid(3, 17, 17):
            with T.block("cache"):
                # The unit-extent batch axis has no loop; it is bound to the constant 0.
                v0 = T.axis.spatial(1, 0)
                v1 = T.axis.spatial(3, ax0)
                v2 = T.axis.spatial(100, h_o * 15 + ax1)
                v3 = T.axis.spatial(100, w_o * 15 + ax2)
                # Guard the last tile: h_o * 15 + ax1 can reach 6 * 15 + 16 = 106 > 99.
                T.where(h_o * 15 + ax1 < 100 and w_o * 15 + ax2 < 100)
                x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
        for n, c_o, h_i, w_i, c_i, kh, kw in T.grid(1, 16, 15, 15, 3, 3, 3):
            with T.block("compute"):
                nn = T.axis.spatial(1, 0)
                cc = T.axis.spatial(16, c_o)
                hh = T.axis.spatial(98, h_o * 15 + h_i)
                ww = T.axis.spatial(98, w_o * 15 + w_i)
                rc, rh, rw = T.axis.remap("RRR", [c_i, kh, kw])
                T.where(h_o * 15 + h_i < 98 and w_o * 15 + w_i < 98)
                with T.init():
                    y[nn, cc, hh, ww] = T.float32(0)
                y[nn, cc, hh, ww] = y[nn, cc, hh, ww] + \
                    x_global[nn, cc // 16 * 3 + rc, hh + rh, ww + rw] * w[cc, rc, rh, rw]

# Fixture (input of the concat compute_at test).
# Two elementwise producers (T_add_1 = x + 1 over 16 elements, T_add_2 = y + 2
# over 8 elements) feed a 24-element "T_concat" block that selects between them
# with T.if_then_else.
@T.prim_func
def concat_two_elemwise(x: T.Buffer[(16,), "float32"],
                        y: T.Buffer[(8,), "float32"],
                        T_concat: T.Buffer[(24,), "float32"]) -> None:
    T_add_1 = T.alloc_buffer([16], dtype="float32")
    T_add_2 = T.alloc_buffer([8], dtype="float32")
    for i in T.serial(16):
        with T.block("T_add_1"):
            ax = T.axis.spatial(16, i)
            T_add_1[ax] = x[ax] + T.float32(1)
    for i in T.serial(8):
        with T.block("T_add_2"):
            ax = T.axis.spatial(8, i)
            T_add_2[ax] = y[ax] + T.float32(2)
    for i in T.serial(24):
        with T.block("T_concat"):
            ax = T.axis.spatial(24, i)
            # NOTE(review): for 8 <= ax < 16 the else-branch indexes T_add_2 past its
            # 8 elements -- presumably deliberate for exercising region inference; confirm.
            T_concat[ax] = T.if_then_else(16 <= ax, T_add_1[ax - 16], T_add_2[ax], dtype="float32")

# Fixture (expected output of the concat compute_at test).
# Both producer blocks share the single 24-iteration loop of "T_concat"; each
# carries a T.where predicate restricting it to part of that loop.
@T.prim_func
def concat_two_elemwise_after_compute_at(x: T.Buffer[(16,), "float32"],
                                         y: T.Buffer[(8,), "float32"],
                                         T_concat: T.Buffer[(24,), "float32"]) -> None:
    T_add_1 = T.alloc_buffer([16], dtype="float32")
    T_add_2 = T.alloc_buffer([8], dtype="float32")
    for i in T.serial(24):
        with T.block("T_add_1"):
            ax = T.axis.spatial(16, i - 16)
            # Only iterations with i >= 16 map to a non-negative T_add_1 index.
            T.where(16 <= i)
            T_add_1[ax] = x[ax] + T.float32(1)
        with T.block("T_add_2"):
            ax = T.axis.spatial(8, i)
            # Only the first 8 iterations fall inside T_add_2's 8 elements.
            T.where(i < 8)
            T_add_2[ax] = y[ax] + T.float32(2)
        with T.block("T_concat"):
            ax = T.axis.spatial(24, i)
            T_concat[ax] = T.if_then_else(16 <= ax, T_add_1[ax - 16], T_add_2[ax], dtype="float32")

@T.prim_func
def floordiv_and_floormod_indices(a: T.handle, b: T.handle) -> None:
X = T.match_buffer(a, [16, 16])
Expand Down Expand Up @@ -929,6 +1008,26 @@ def test_compute_at_tiled_pooling_read_cache():
verify_trace_roundtrip(sch=sch, mod=tiled_pooling_read_cache)


def test_compute_at_non_uniform_tiled_conv():
    """Move the "cache" block under the second loop of "compute" and compare
    the scheduled module against the expected fixture, then replay the trace."""
    schedule = tir.Schedule(non_uniform_tiled_conv, debug_mask="all")
    compute_block = schedule.get_block("compute")
    cache_block = schedule.get_block("cache")
    # Index 1 is the second loop around the "compute" block.
    target_loop = schedule.get_loops(compute_block)[1]
    schedule.compute_at(cache_block, target_loop)
    tvm.ir.assert_structural_equal(
        non_uniform_tiled_conv_after_compute_at, schedule.mod["main"]
    )
    verify_trace_roundtrip(sch=schedule, mod=non_uniform_tiled_conv)


def test_compute_at_concat():
    """Move both elementwise producers under the first loop of "T_concat" and
    compare the scheduled module against the expected fixture, then replay the
    trace."""
    schedule = tir.Schedule(concat_two_elemwise, debug_mask="all")
    consumer = schedule.get_block("T_concat")
    producers = [schedule.get_block(block_name) for block_name in ("T_add_1", "T_add_2")]
    target_loop = schedule.get_loops(consumer)[0]
    # Apply compute_at in the same order as the producers were fetched.
    for producer in producers:
        schedule.compute_at(producer, target_loop)
    tvm.ir.assert_structural_equal(
        concat_two_elemwise_after_compute_at, schedule.mod["main"]
    )
    verify_trace_roundtrip(sch=schedule, mod=concat_two_elemwise)


def test_reverse_compute_at_tiled():
sch = tir.Schedule(tiled, debug_mask="all")
block = sch.get_block("C")
Expand Down

0 comments on commit 0280297

Please sign in to comment.