Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stronger chain detection in LoopCarry pass #8016

Merged
merged 10 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/LoopCarry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,34 @@ class LoopCarryOverLoop : public IRMutator {

// For each load, move the load index forwards by one loop iteration
vector<Expr> indices, next_indices, predicates, next_predicates;
// CSE-d versions of the above, so can_prove can be safely used on them.
vector<Expr> indices_csed, next_indices_csed, predicates_csed, next_predicates_csed;
for (const vector<const Load *> &v : loads) {
indices.push_back(v[0]->index);
next_indices.push_back(step_forwards(v[0]->index, linear));
predicates.push_back(v[0]->predicate);
next_predicates.push_back(step_forwards(v[0]->predicate, linear));

if (indices.back().defined()) {
indices_csed.push_back(common_subexpression_elimination(indices.back()));
} else {
indices_csed.emplace_back();
}
if (next_indices.back().defined()) {
next_indices_csed.push_back(common_subexpression_elimination(next_indices.back()));
} else {
next_indices_csed.emplace_back();
}
if (predicates.back().defined()) {
predicates_csed.push_back(common_subexpression_elimination(predicates.back()));
} else {
predicates_csed.emplace_back();
}
if (next_predicates.back().defined()) {
next_predicates_csed.push_back(common_subexpression_elimination(next_predicates.back()));
} else {
next_predicates_csed.emplace_back();
}
}

// Find loads done on this loop iteration that will be
Expand All @@ -299,11 +322,16 @@ class LoopCarryOverLoop : public IRMutator {
if (i == j) {
continue;
}
// can_prove is stronger than graph_equal: it does not require the index expressions to be
// structurally identical, only that they evaluate to the same value. We keep the graph_equal
// check because it is cheaper and is evaluated first, before the more expensive can_prove call.
if (loads[i][0]->name == loads[j][0]->name &&
next_indices[j].defined() &&
graph_equal(indices[i], next_indices[j]) &&
(graph_equal(indices[i], next_indices[j]) ||
((indices[i].type() == next_indices[j].type()) && can_prove(indices_csed[i] == next_indices_csed[j]))) &&
next_predicates[j].defined() &&
graph_equal(predicates[i], next_predicates[j])) {
(graph_equal(predicates[i], next_predicates[j]) ||
((predicates[i].type() == next_predicates[j].type()) && can_prove(predicates_csed[i] == next_predicates_csed[j])))) {
chains.push_back({j, i});
debug(3) << "Found carried value:\n"
<< i << ": -> " << Expr(loads[i][0]) << "\n"
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ tests(GROUPS correctness
likely.cpp
load_library.cpp
logical.cpp
loop_carry.cpp
loop_invariant_extern_calls.cpp
loop_level_generator_param.cpp
lossless_cast.cpp
Expand Down
64 changes: 64 additions & 0 deletions test/correctness/loop_carry.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;
using namespace Halide::Internal;

// Custom lowering pass that runs loop_carry on the statement it is handed
// and then simplifies the result.
class LoopCarryWrapper : public IRMutator {
    using IRMutator::visit;

    // Forwarded unchanged as the second argument of loop_carry.
    int register_count_;

    // Replaces the default statement mutation: the whole statement is passed
    // to loop_carry and the outcome is run through simplify.
    Stmt mutate(const Stmt &stmt) override {
        return simplify(loop_carry(stmt, register_count_));
    }

public:
    // explicit: an int must never silently convert into a lowering pass.
    explicit LoopCarryWrapper(int register_count)
        : register_count_(register_count) {
    }
};

int main(int argc, char **argv) {
Func input;
Func g;
Func h;
Func f;
Var x, y, xo, yo, xi, yi;

input(x, y) = x + y;

Expr sum_expr = 0;
for (int ix = -100; ix <= 100; ix++) {
// Generate two chains of sums, but only one of them will be carried.
sum_expr += input(x, y + ix);
sum_expr += input(x + 13, y + 2 * ix);
}
g(x, y) = sum_expr;
h(x, y) = g(x, y) + 12;
f(x, y) = h(x, y);

// Make a maximum number of the carried values very large for the purpose
// of this test.
constexpr int kMaxRegisterCount = 1024;
f.add_custom_lowering_pass(new LoopCarryWrapper(kMaxRegisterCount));

const int size = 128;
f.compute_root()
.bound(x, 0, size)
.bound(y, 0, size);

h.compute_root()
.tile(x, y, xo, yo, xi, yi, 16, 16, TailStrategy::RoundUp);

g.compute_at(h, xo)
.reorder(y, x)
.vectorize(x, 4);

input.compute_root();

f.realize({size, size});

printf("Success!\n");
return 0;
}