Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug] Fix false overflow alarm in struct fors under packed mode #6457

Merged
merged 3 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ SNode &SNode::create_node(std::vector<Axis> axes,
}
if (acc_shape > std::numeric_limits<int>::max()) {
TI_WARN(
"Snode index might be out of int32 boundary but int64 indexing is not "
"supported yet.");
"SNode index might be out of int32 boundary but int64 indexing is not "
"supported yet. Struct fors might not work either.");
}
new_node.num_cells_per_container = acc_shape;
// infer extractors (only for POT)
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/demote_dense_struct_fors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ void convert_to_range_for(OffloadedStmt *offloaded, bool packed) {
snode = snode->parent;
}
std::reverse(snodes.begin(), snodes.end());
TI_ASSERT(total_bits <= 30);

// general shape calculation - no dependence on POT
int64 total_n = 1;
Expand All @@ -38,6 +37,7 @@ void convert_to_range_for(OffloadedStmt *offloaded, bool packed) {
}
total_n *= s->num_cells_per_container;
}
TI_ASSERT(total_n <= std::numeric_limits<int>::max());

offloaded->const_begin = true;
offloaded->const_end = true;
Expand Down
15 changes: 15 additions & 0 deletions tests/python/test_struct_for_non_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,18 @@ def test_2d():
@test_utils.test(require=ti.extension.packed, packed=True)
def test_2d_packed():
_test_2d()


@test_utils.test(require=ti.extension.packed, packed=True)
def test_2d_overflow_if_not_packed():
n, m, p = 2**9 + 1, 2**9 + 1, 2**10 + 1
arr = ti.field(ti.u8, (n, m, p))

@ti.kernel
def count() -> ti.i32:
res = 0
for _ in ti.grouped(arr):
res += 1
return res

assert count() == n * m * p