Skip to content

Commit

Permalink
[Arith] Update BufferDomainTouched to support vector access. (apache#…
Browse files Browse the repository at this point in the history
…11722)

* [Arith] Update BufferDomainTouched to support vector access.

* Add test checking that domain touched works on IR containing RampNodes.
  • Loading branch information
csullivan authored and zxybazh committed Jun 26, 2022
1 parent 624c85e commit eec521c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 31 deletions.
18 changes: 12 additions & 6 deletions src/arith/domain_touched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ class BufferTouchedDomain final : public StmtExprVisitor {
}

Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) {
Region ret;
auto kv = buffer_access_map_.find(buffer.get());
CHECK(kv != buffer_access_map_.end())
<< "The requested buffer is not contained in the provided stmt body.";
if (kv == buffer_access_map_.end()) {
LOG(WARNING) << "[arith::BufferDomainTouched] "
<< "The requested buffer is not contained in the provided stmt body: " << buffer;
return ret;
}

Region ret;
Range none;
BufferTouches bounds;
if (consider_loads && consider_stores) {
Expand Down Expand Up @@ -131,13 +134,16 @@ class BufferTouchedDomain final : public StmtExprVisitor {
}

private:
template <typename ArrayType>
void Touch(BufferTouches* bounds, const ArrayType& args) const {
void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) const {
if (args.size() > bounds->size()) {
bounds->resize(args.size());
}
for (size_t i = 0; i < args.size(); ++i) {
(*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
if (args[i].as<RampNode>()) {
(*bounds)[i].emplace_back(IntSet::Vector(args[i]));
} else {
(*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
}
}
}

Expand Down
63 changes: 38 additions & 25 deletions tests/python/unittest/test_arith_domain_touched.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,36 @@
# under the License.
import tvm
from tvm import te
from tvm.script import tir as T


@T.prim_func
def scalar_func(a: T.handle, b: T.handle):
m = T.var("int32")
n = T.int32(100)
A = T.match_buffer(a, (n, m), name="A")
B = T.match_buffer(b, (n, m), name="B")

for i, j in T.grid(n, m):
A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1]


@T.prim_func
def vector_func(a: T.handle, b: T.handle):
n = T.var("int32")
m = T.int32(128)
A = T.match_buffer(a, (n, m), name="A")
B = T.match_buffer(b, (n, m), name="B")

for i in T.serial(n):
for j in T.vectorized(m):
A[i, j] = A[i, j] + B[i, j]


def test_domain_touched():
i = te.var("i")
j = te.var("j")
n = tvm.runtime.convert(100)
m = te.var("m")

a = tvm.tir.decl_buffer((n, m), name="a")
b = tvm.tir.decl_buffer((n, m), name="b")

ir = tvm.tir.For(
i,
0,
n,
tvm.tir.ForKind.SERIAL,
tvm.tir.For(
j,
0,
m,
tvm.tir.ForKind.SERIAL,
tvm.tir.BufferStore(
a,
tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(a, [i - 1, j - 1]),
[i, j],
),
),
)
func = scalar_func
a, b = [func.buffer_map[var] for var in func.params]
ir = func.body

a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)

Expand Down Expand Up @@ -78,5 +80,16 @@ def test_domain_touched():
assert len(b_domain_w) == 0


def test_domain_touched_vector():
func = tvm.lower(vector_func)["main"]
a, b = [func.buffer_map[var] for var in func.params]

assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, True)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, b, True, False)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, b, True, False)[0].extent.value == 128


if __name__ == "__main__":
test_domain_touched()

0 comments on commit eec521c

Please sign in to comment.