Loop partitioning: strip likely tags from predicate simplifications

Review notes (free-form text ahead of the first file header):

  * src/PartitionLoops.cpp, FindSimplifications: in the two visit methods
    shown, which test op->predicate with has_uncaptured_likely_tag(), the
    fourth argument passed to new_simplification() changes from
    op->predicate to remove_likelies(op->predicate).
  * src/PartitionLoops.cpp, MakeSimplifications::mutate: a new
    internal_assert checks that a matched simplification's likely_value is
    not the same node as its old_expr before recursing into
    mutate(s.likely_value). If the two were the same node, that recursive
    call would match the same simplification again.
  * test/correctness/partition_loops.cpp: the existing case is wrapped in
    its own scope, and a second case taken from
    https://github.com/halide/Halide/issues/7742 is added. The new case
    only has to run to completion; it does not check output values.

  NOTE(review): this patch was rebuilt from a whitespace-collapsed copy in
  which template argument lists were missing. cast<float> and Buffer<float>
  follow from the float comparison and the %f format in the test.
  Buffer<uint8_t> for `input` is an assumption - TODO confirm against the
  upstream file before applying. Indentation is reconstructed as 4 spaces
  per level; hunk line counts were re-checked against the hunk headers.

diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp
index 4ea2f067b3e5..760d49831951 100644
--- a/src/PartitionLoops.cpp
+++ b/src/PartitionLoops.cpp
@@ -411,7 +411,7 @@ class FindSimplifications : public IRVisitor {
         IRVisitor::visit(op);
         if (has_uncaptured_likely_tag(op->predicate)) {
             const int lanes = op->predicate.type().lanes();
-            new_simplification(op->predicate, op->predicate, const_true(lanes), op->predicate);
+            new_simplification(op->predicate, op->predicate, const_true(lanes), remove_likelies(op->predicate));
         }
     }
 
@@ -419,7 +419,7 @@ class FindSimplifications : public IRVisitor {
         IRVisitor::visit(op);
         if (has_uncaptured_likely_tag(op->predicate)) {
             const int lanes = op->predicate.type().lanes();
-            new_simplification(op->predicate, op->predicate, const_true(lanes), op->predicate);
+            new_simplification(op->predicate, op->predicate, const_true(lanes), remove_likelies(op->predicate));
         }
     }
 
@@ -472,6 +472,9 @@ class MakeSimplifications : public IRMutator {
     Expr mutate(const Expr &e) override {
         for (auto const &s : simplifications) {
             if (e.same_as(s.old_expr)) {
+                internal_assert(!s.likely_value.same_as(s.old_expr))
+                    << "Loop partitioning simplification does not mutate value: "
+                    << s.old_expr << "\n";
                 return mutate(s.likely_value);
             }
         }
diff --git a/test/correctness/partition_loops.cpp b/test/correctness/partition_loops.cpp
index 53d82b9c2e51..ba01492f958f 100644
--- a/test/correctness/partition_loops.cpp
+++ b/test/correctness/partition_loops.cpp
@@ -3,49 +3,61 @@
 using namespace Halide;
 
 int main(int argc, char *argv[]) {
-    Buffer<uint8_t> input(1024, 1024, 3);
+    {
+        Buffer<uint8_t> input(1024, 1024, 3);
 
-    for (int c = 0; c < input.channels(); c++) {
-        for (int y = 0; y < input.height(); y++) {
-            for (int x = 0; x < input.width(); x++) {
-                input(x, y, c) = x + y + c;
+        for (int c = 0; c < input.channels(); c++) {
+            for (int y = 0; y < input.height(); y++) {
+                for (int x = 0; x < input.width(); x++) {
+                    input(x, y, c) = x + y + c;
+                }
             }
         }
-    }
 
-    Var x("x"), y("y"), c("c");
-
-    Func clamped_input = Halide::BoundaryConditions::repeat_edge(input);
-
-    // One of the possible conditions for partitioning loop 'f.s0.x' is
-    // ((f.s0.x + g[0]) <= 1023) which depends on 'g'. Since 'g' is
-    // only allocated inside f.s0.x, partition loops should not use this
-    // condition to compute the epilogue/prologue.
-    Func f("f"), g("g"), h("h");
-    g(x, y, c) = x + y + c;
-    g(x, y, 0) = x;
-    h(x, y) = clamped_input(x + g(x, y, 0), y, 2);
-    f(x, y, c) = select(h(x, y) < x + y, x + y, y + c);
-
-    f.compute_root();
-
-    Func output("output");
-    output(x, y, c) = cast<float>(f(x, y, c));
-    Buffer<float> im = output.realize({1024, 1024, 3});
-
-    for (int y = 0; y < input.height(); y++) {
-        for (int x = 0; x < input.width(); x++) {
-            for (int c = 0; c < input.channels(); c++) {
-                float correct = (input(std::min(2 * x, input.width() - 1), y, 2) < x + y) ? x + y : y + c;
-                if (im(x, y, c) != correct) {
-                    printf("im(%d, %d, %d) = %f instead of %f\n",
-                           x, y, c, im(x, y, c), correct);
-                    return 1;
+        Var x("x"), y("y"), c("c");
+
+        Func clamped_input = Halide::BoundaryConditions::repeat_edge(input);
+
+        // One of the possible conditions for partitioning loop 'f.s0.x' is
+        // ((f.s0.x + g[0]) <= 1023) which depends on 'g'. Since 'g' is
+        // only allocated inside f.s0.x, partition loops should not use this
+        // condition to compute the epilogue/prologue.
+        Func f("f"), g("g"), h("h");
+        g(x, y, c) = x + y + c;
+        g(x, y, 0) = x;
+        h(x, y) = clamped_input(x + g(x, y, 0), y, 2);
+        f(x, y, c) = select(h(x, y) < x + y, x + y, y + c);
+
+        f.compute_root();
+
+        Func output("output");
+        output(x, y, c) = cast<float>(f(x, y, c));
+        Buffer<float> im = output.realize({1024, 1024, 3});
+
+        for (int y = 0; y < input.height(); y++) {
+            for (int x = 0; x < input.width(); x++) {
+                for (int c = 0; c < input.channels(); c++) {
+                    float correct = (input(std::min(2 * x, input.width() - 1), y, 2) < x + y) ? x + y : y + c;
+                    if (im(x, y, c) != correct) {
+                        printf("im(%d, %d, %d) = %f instead of %f\n",
+                               x, y, c, im(x, y, c), correct);
+                        return 1;
+                    }
                 }
             }
         }
     }
 
+    // A loop partitioning bug from https://github.com/halide/Halide/issues/7742
+    {
+        Var x, y, x_outer, x_inner, y_x_inner_fused;
+        Func f;
+        f(x, y) = x + y;
+        f.split(x, x_outer, x_inner, 2, TailStrategy::PredicateStores).fuse(y, x_inner, y_x_inner_fused);
+        Pipeline p(f);
+        p.realize({100, 100});
+    }
+
     printf("Success!\n");
     return 0;
 }