Skip to content

Commit

Permalink
[TIR][NarrowDataType] Bufferload's index should not inherit bits cons…
Browse files Browse the repository at this point in the history
…traint of value (#17411)

bufferload's index dtype narrowing should not inherit value bits constraint

Co-authored-by: wrongtest <[email protected]>
  • Loading branch information
wrongtest-intellif and wrongtest authored Sep 25, 2024
1 parent 2a87c4c commit a90fb8e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ class DataTypeVisitor final : public StmtExprVisitor {
}
}

void VisitExpr_(const BufferLoadNode* op) {
int tmp = bits_;
bits_ = target_bits_;
StmtExprVisitor::VisitExpr_(op);
bits_ = tmp;
}

void VisitStmt_(const ForNode* op) {
analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
Expand Down Expand Up @@ -245,7 +252,12 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
const CastNode* new_op = e.as<CastNode>();
ICHECK(new_op != nullptr) << "Expected type to be CastNode"
<< ", but get " << e->GetTypeKey();
return Cast(visitor_.vmap[op], new_op->value);
PrimExpr new_value = new_op->value;
DataType cast_type = visitor_.vmap[op];
if (new_value.dtype() != cast_type) {
new_value = Cast(cast_type, new_value);
}
return new_value;
}
return Parent::VisitExpr_(op);
}
Expand Down
17 changes: 17 additions & 0 deletions tests/python/tir-transform/test_tir_transform_narrow_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,5 +413,22 @@ def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,),
tvm.ir.assert_structural_equal(after["main"], expected_after.with_attr("global_symbol", "main"))


def test_narrow_i64_valued_bufferload_index_to_i32():
@T.prim_func
def before(A: T.Buffer((16,), "int64")):
for i in range(T.int64(15)):
A[i + T.int64(1)] = A[i] + T.int64(1)

@T.prim_func
def expect(A: T.Buffer((16,), "int64")):
for i in range(15):
A[i + 1] = A[i] + T.int64(1)

after = tvm.tir.transform.NarrowDataType(32)(
tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
)["main"]
tvm.ir.assert_structural_equal(after, expect.with_attr("global_symbol", "main"))


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit a90fb8e

Please sign in to comment.