Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TE] Raise error for non-bijective transformation (apache#12926)
Browse files Browse the repository at this point in the history
This is a fix for a bug introduced in
apache#12904.  Prior to then, an
exception was raised when the transformation wouldn't be bijective
over the transformed buffer's shape.  The PR replaced the bijective
check done as part of `DetectIterMap` with a check done on the
returned `padding_predicate`.  However, this check was not equivalent,
and some transformations could erroneously apply, rather than
raising an exception as being non-bijective.

This commit re-enables the bijectivity check in `DetectIterMap`, and
adds a test case for this behavior.
  • Loading branch information
Lunderberg authored and xinetzone committed Nov 25, 2022
1 parent 4a80692 commit 2556831
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 14 deletions.
4 changes: 3 additions & 1 deletion src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1739,7 +1739,9 @@ class IterMapToExprNormalizer : public ExprMutator {
bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) {
const auto* clhs = lhs.as<IntImmNode>();
const auto* crhs = rhs.as<IntImmNode>();
if (clhs && crhs) {
if (crhs && crhs->value == 0) {
return false;
} else if (clhs && crhs) {
return clhs->value % crhs->value == 0;
}

Expand Down
33 changes: 20 additions & 13 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,21 @@ IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(A
return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map));
}

std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
if ((*this)->inverse_index_map.defined()) {
std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
const Array<Range>& initial_ranges,
arith::IterMapLevel check_level) {
if (self->inverse_index_map.defined()) {
// return the pre-defined inverse index map if exists. In this
// case, the user-defined inverse is assumed to be correct and
// bijective.
PrimExpr padding_predicate = Bool(false);
return {Downcast<IndexMap>((*this)->inverse_index_map.value()), padding_predicate};
return {Downcast<IndexMap>(self->inverse_index_map.value()), padding_predicate};
}

// Dummy variables to represent the inverse's inputs.
Array<Var> output_vars;
for (size_t i = 0; i < (*this)->final_indices.size(); i++) {
PrimExpr index = (*this)->final_indices[i];
for (size_t i = 0; i < self->final_indices.size(); i++) {
PrimExpr index = self->final_indices[i];
// TODO(Lunderberg): Better names for these variables. A variable
// that is passed through unmodified (`index` is an element of
// `initial_indices`) should use that input index's name. A pair
Expand All @@ -79,16 +81,16 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia

// Dummy ranges for the extent of each input.
Map<Var, Range> input_iters;
ICHECK_EQ((*this)->initial_indices.size(), initial_ranges.size());
ICHECK_EQ(self->initial_indices.size(), initial_ranges.size());
for (size_t i = 0; i < initial_ranges.size(); i++) {
input_iters.Set((*this)->initial_indices[i], initial_ranges[i]);
input_iters.Set(self->initial_indices[i], initial_ranges[i]);
}

// Unpack the output indices into linear combinations of the initial
// indices.
arith::Analyzer analyzer;
auto padded_iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1,
/*check_level=*/arith::IterMapLevel::NoCheck, &analyzer,
auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /* predicate = */ 1,
/*check_level=*/check_level, &analyzer,
/*simplify_trivial_iterators=*/false);
CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. "
<< "Error: " << padded_iter_map->errors[0];
Expand All @@ -100,8 +102,8 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia

// Unpack the map to an array, maintaining the same parameter order.
Array<PrimExpr> inverse_exprs;
for (int i = 0, n = (*this)->initial_indices.size(); i < n; ++i) {
Var index = (*this)->initial_indices[i];
for (int i = 0, n = self->initial_indices.size(); i < n; ++i) {
Var index = self->initial_indices[i];
PrimExpr expr;
if (is_one(initial_ranges[i]->extent) && !inverse_exprs_map.count(index)) {
expr = initial_ranges[i]->min;
Expand All @@ -116,7 +118,7 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
padding_predicate = Substitute(padding_predicate, inverse_exprs_map);

{
auto output_ranges = (*this)->MapRanges(initial_ranges);
auto output_ranges = self->MapRanges(initial_ranges);
ICHECK_EQ(output_ranges.size(), output_vars.size());

arith::Analyzer analyzer;
Expand All @@ -131,8 +133,13 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
return {IndexMap(output_vars, inverse_exprs), padding_predicate};
}

std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck);
}

IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
auto [inverse, padding_predicate] = NonSurjectiveInverse(std::move(initial_ranges));
auto [inverse, padding_predicate] =
IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective);
arith::Analyzer analyzer;
CHECK(analyzer.CanProve(!padding_predicate))
<< "Bijective inverse should not contain padding, but inverse of " << *this << " over range "
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,5 +575,18 @@ def test_size_one_buffer(shape, transform):
s[B].transform_layout(transform)


def test_non_divisible_transform_raises_error():
A = te.placeholder([1, 3, 8, 8])
B = te.compute(A.shape, lambda *indices: A[indices])
s = te.create_schedule(B.op)

transform = lambda n, c, h, w: [n, c // 4, h, w, c % 4]
# Error occurs here, because the transformation would introduce
# padding. Padded transforms are supported in TIR-based
# schedules.
with pytest.raises(tvm.TVMError):
s[B].transform_layout(transform)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 2556831

Please sign in to comment.