Skip to content

Commit

Permalink
Fix regression from halide#6734 (halide#6739)
Browse files Browse the repository at this point in the history
That change inadvertently required the RHS of an update stage that used `+=` (or similar operators) to match the LHS type, which should be required (implicit casting of the RHS is expected). Restructured to remove this, but still ensure that auto-injection of a pure definition matches the required types (if any), and updated tests.
  • Loading branch information
steven-johnson authored and ardier committed Mar 3, 2024
1 parent e09a3f0 commit b656c68
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 19 deletions.
39 changes: 23 additions & 16 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2952,7 +2952,7 @@ namespace {

// Inject a suitable base-case definition given an update
// definition. This is a helper for FuncRef::operator+= and co.
Func define_base_case(const Internal::Function &func, const vector<Expr> &a, const Tuple &e) {
Func define_base_case(const Internal::Function &func, const vector<Expr> &a, const vector<Expr> &rhs, int init_val) {
Func f(func);

if (func.has_pure_definition()) {
Expand All @@ -2971,28 +2971,32 @@ Func define_base_case(const Internal::Function &func, const vector<Expr> &a, con
}
}

f(pure_args) = e;
return f;
}
const auto &required_types = func.required_types();
internal_assert(required_types.empty() || required_types.size() == rhs.size());

vector<Expr> init_values(rhs.size());
for (size_t i = 0; i < rhs.size(); ++i) {
// If we have required types, cast the init_val to that type instead of the rhs type
const Type &t = required_types.empty() ? rhs[i].type() : required_types[i];
init_values[i] = cast(t, init_val);
}

Func define_base_case(const Internal::Function &func, const vector<Expr> &a, const Expr &e) {
return define_base_case(func, a, Tuple(e));
f(pure_args) = Tuple(init_values);
return f;
}

} // namespace

template<typename BinaryOp>
Stage FuncRef::func_ref_update(const Tuple &e, int init_val) {
func.check_types(e);
// Don't do this: we want to allow the RHS to be implicitly cast to the type of LHS.
// func.check_types(e);

internal_assert(e.size() > 1);

vector<Expr> init_values(e.size());
for (int i = 0; i < (int)init_values.size(); ++i) {
init_values[i] = cast(e[i].type(), init_val);
}
vector<Expr> expanded_args = args_with_implicit_vars(e.as_vector());
FuncRef self_ref = define_base_case(func, expanded_args, Tuple(init_values))(expanded_args);
const vector<Expr> &rhs = e.as_vector();
const vector<Expr> expanded_args = args_with_implicit_vars(rhs);
FuncRef self_ref = define_base_case(func, expanded_args, rhs, init_val)(expanded_args);

vector<Expr> values(e.size());
for (int i = 0; i < (int)values.size(); ++i) {
Expand All @@ -3003,9 +3007,12 @@ Stage FuncRef::func_ref_update(const Tuple &e, int init_val) {

template<typename BinaryOp>
Stage FuncRef::func_ref_update(Expr e, int init_val) {
func.check_types(e);
vector<Expr> expanded_args = args_with_implicit_vars({e});
FuncRef self_ref = define_base_case(func, expanded_args, cast(e.type(), init_val))(expanded_args);
// Don't do this: we want to allow the RHS to be implicitly cast to the type of LHS.
// func.check_types(e);

const vector<Expr> rhs = {e};
const vector<Expr> expanded_args = args_with_implicit_vars(rhs);
FuncRef self_ref = define_base_case(func, expanded_args, rhs, init_val)(expanded_args);
return self_ref = BinaryOp()(Expr(self_ref), e);
}

Expand Down
66 changes: 65 additions & 1 deletion test/correctness/typed_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,76 @@ int main(int argc, char **argv) {
f(x, y) = x + y;

auto r = f.realize({10, 10}); // will assert-fail for values other than 10x10
Buffer<int> b = r[0];
Buffer<int32_t> b = r[0];
b.for_each_element([&](int x, int y) {
assert(b(x, y) == x + y);
});
}

// Verify that update stages defined via += and friends *don't* require
// the RHS type to match the LHS type (whether or not the pure definition
// is implicitly defined)
{
Func f(Int(32), 2, "f");

f(x, y) = cast<int32_t>(1);
f(x, y) += cast<uint8_t>(x + y);

auto r = f.realize({10, 10});
Buffer<int32_t> b = r[0];
b.for_each_element([&](int x, int y) {
assert(b(x, y) == 1 + (uint8_t)(x + y));
});
}

{
Func f(Int(32), 2, "f");

// f(x, y) = cast<int32_t>(0); // leave out, so Halide injects the implicit init
f(x, y) += cast<uint8_t>(x + y);

auto r = f.realize({10, 10});
Buffer<int32_t> b = r[0];
b.for_each_element([&](int x, int y) {
assert(b(x, y) == 0 + (uint8_t)(x + y));
});
}

// Same, but with Tuples
{
Func f({Int(32), Int(8)}, 2, "f");

f(x, y) = Tuple(cast<int32_t>(1), cast<int8_t>(2));
f(x, y) += Tuple(cast<uint8_t>(x + y), cast<int8_t>(x - y));

auto r = f.realize({10, 10});
Buffer<int32_t> b0 = r[0];
Buffer<int8_t> b1 = r[1];
b0.for_each_element([&](int x, int y) {
assert(b0(x, y) == 1 + (uint8_t)(x + y));
});
b1.for_each_element([&](int x, int y) {
assert(b1(x, y) == 2 + (int8_t)(x - y));
});
}

{
Func f({Int(32), Int(8)}, 2, "f");

// f(x, y) = Tuple(cast<int32_t>(1), cast<int8_t>(2)); // leave out, so Halide injects the implicit init
f(x, y) += Tuple(cast<uint8_t>(x + y), cast<int8_t>(x - y));

auto r = f.realize({10, 10});
Buffer<int32_t> b0 = r[0];
Buffer<int8_t> b1 = r[1];
b0.for_each_element([&](int x, int y) {
assert(b0(x, y) == 0 + (uint8_t)(x + y));
});
b1.for_each_element([&](int x, int y) {
assert(b1(x, y) == 0 + (int8_t)(x - y));
});
}

printf("Success!\n");
return 0;
}
2 changes: 1 addition & 1 deletion test/error/func_expr_update_type_mismatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ int main(int argc, char **argv) {
Func f(Float(32), 2, "f");

f(x, y) = 0.f;
f(x, y) += cast<uint8_t>(0);
f(x, y) = cast<uint8_t>(0);

f.realize({100, 100});

Expand Down
2 changes: 1 addition & 1 deletion test/error/func_tuple_update_types_mismatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ int main(int argc, char **argv) {
Func f({UInt(8), Float(64)}, 2, "f");

f(x, y) = {cast<uint8_t>(0), cast<double>(0)};
f(x, y) += {cast<int>(0), cast<float>(0)};
f(x, y) = {cast<int>(0), cast<float>(0)};

f.realize({100, 100});

Expand Down

0 comments on commit b656c68

Please sign in to comment.