diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index 5f4d7bb519d3..050cdfbfc8d9 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -283,11 +283,34 @@ class LoopCarryOverLoop : public IRMutator { // For each load, move the load index forwards by one loop iteration vector indices, next_indices, predicates, next_predicates; + // CSE-d versions of the above, so can_prove can be safely used on them. + vector indices_csed, next_indices_csed, predicates_csed, next_predicates_csed; for (const vector &v : loads) { indices.push_back(v[0]->index); next_indices.push_back(step_forwards(v[0]->index, linear)); predicates.push_back(v[0]->predicate); next_predicates.push_back(step_forwards(v[0]->predicate, linear)); + + if (indices.back().defined()) { + indices_csed.push_back(common_subexpression_elimination(indices.back())); + } else { + indices_csed.emplace_back(); + } + if (next_indices.back().defined()) { + next_indices_csed.push_back(common_subexpression_elimination(next_indices.back())); + } else { + next_indices_csed.emplace_back(); + } + if (predicates.back().defined()) { + predicates_csed.push_back(common_subexpression_elimination(predicates.back())); + } else { + predicates_csed.emplace_back(); + } + if (next_predicates.back().defined()) { + next_predicates_csed.push_back(common_subexpression_elimination(next_predicates.back())); + } else { + next_predicates_csed.emplace_back(); + } } // Find loads done on this loop iteration that will be @@ -299,11 +322,16 @@ class LoopCarryOverLoop : public IRMutator { if (i == j) { continue; } + // can_prove is stronger than graph_equal, because it doesn't require index expressions to be + // exactly the same, but evaluate to the same value. We keep the graph_equal check, because + // it's faster and should be executed before the more expensive check. if (loads[i][0]->name == loads[j][0]->name && next_indices[j].defined() && - graph_equal(indices[i], next_indices[j]) && + (graph_equal(indices[i], next_indices[j]) || + ((indices[i].type() == next_indices[j].type()) && can_prove(indices_csed[i] == next_indices_csed[j]))) && next_predicates[j].defined() && - graph_equal(predicates[i], next_predicates[j])) { + (graph_equal(predicates[i], next_predicates[j]) || + ((predicates[i].type() == next_predicates[j].type()) && can_prove(predicates_csed[i] == next_predicates_csed[j])))) { chains.push_back({j, i}); debug(3) << "Found carried value:\n" << i << ": -> " << Expr(loads[i][0]) << "\n" diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 07921a347425..cd66f21a346e 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -199,6 +199,7 @@ tests(GROUPS correctness likely.cpp load_library.cpp logical.cpp + loop_carry.cpp loop_invariant_extern_calls.cpp loop_level_generator_param.cpp lossless_cast.cpp diff --git a/test/correctness/loop_carry.cpp b/test/correctness/loop_carry.cpp new file mode 100644 index 000000000000..4cfba7d25f3f --- /dev/null +++ b/test/correctness/loop_carry.cpp @@ -0,0 +1,64 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +// Wrapper class to call loop_carry on a given statement. +class LoopCarryWrapper : public IRMutator { + using IRMutator::visit; + + int register_count_; + Stmt mutate(const Stmt &stmt) override { + return simplify(loop_carry(stmt, register_count_)); + } + +public: + LoopCarryWrapper(int register_count) + : register_count_(register_count) { + } +}; + +int main(int argc, char **argv) { + Func input; + Func g; + Func h; + Func f; + Var x, y, xo, yo, xi, yi; + + input(x, y) = x + y; + + Expr sum_expr = 0; + for (int ix = -100; ix <= 100; ix++) { + // Generate two chains of sums, but only one of them will be carried. + sum_expr += input(x, y + ix); + sum_expr += input(x + 13, y + 2 * ix); + } + g(x, y) = sum_expr; + h(x, y) = g(x, y) + 12; + f(x, y) = h(x, y); + + // Make a maximum number of the carried values very large for the purpose + // of this test. + constexpr int kMaxRegisterCount = 1024; + f.add_custom_lowering_pass(new LoopCarryWrapper(kMaxRegisterCount)); + + const int size = 128; + f.compute_root() + .bound(x, 0, size) + .bound(y, 0, size); + + h.compute_root() + .tile(x, y, xo, yo, xi, yi, 16, 16, TailStrategy::RoundUp); + + g.compute_at(h, xo) + .reorder(y, x) + .vectorize(x, 4); + + input.compute_root(); + + f.realize({size, size}); + + printf("Success!\n"); + return 0; +}