Skip to content

Commit

Permalink
[TIR] Fix Datatype in Lower TVM Builtin (#14347)
Browse files Browse the repository at this point in the history
Fix data type and add minimal reproducible test.

Co-authored-by: Sunghyun Park <[email protected]>
  • Loading branch information
zxybazh and Sunghyun Park authored Mar 21, 2023
1 parent 36b3097 commit 7f6da09
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
17 changes: 9 additions & 8 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ class BuiltinLower : public StmtExprMutator {
}
}
}
PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
PrimExpr total_bytes = make_const(DataType::UInt(64), nbytes);
for (size_t i = 0; i < op->extents.size(); ++i) {
// set total_bytes to uint64 to avoid overflow
total_bytes = total_bytes * op->extents[i];
}
ICHECK(device_type_.defined()) << "Unknown device type in current IR";
Expand All @@ -250,13 +251,13 @@ class BuiltinLower : public StmtExprMutator {
Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
throw_last_error),
op->body});
Stmt alloca = LetStmt(
op->buffer_var,
Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
{cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
IntImm(DataType::Int(32), op->dtype.bits())}),
body);
Stmt alloca =
LetStmt(op->buffer_var,
Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
{cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
total_bytes, IntImm(DataType::Int(32), op->dtype.code()),
IntImm(DataType::Int(32), op->dtype.bits())}),
body);

PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
{cast(DataType::Int(32), device_type_),
Expand Down
19 changes: 18 additions & 1 deletion tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# under the License.
import tvm
from tvm import te
from tvm.script import tir as T
import numpy as np
from tvm import testing


@tvm.register_func("tvm.test_matmul")
Expand Down Expand Up @@ -172,6 +172,23 @@ def build_tir():
tvm.testing.assert_allclose(a.numpy(), expected_value)


def test_lower_overflow_int32():
@T.prim_func
def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112)), "float32")):
T.func_attr({"global_symbol": "variance4", "tir.noalias": True})
rxplaceholder_red = T.allocate([32], "float32", "global")
T_subtract = T.allocate([822083584], "float32", "global")
rxplaceholder_red_1 = T.Buffer((T.int64(32),), data=rxplaceholder_red)
rxplaceholder_1 = T.Buffer((T.int64(822083584),), data=rxplaceholder.data)
T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract)
for ax1, ax2 in T.grid(32, 25690112):
cse_var_1: T.int32 = ax1 * 25690112 + ax2
T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] - rxplaceholder_red_1[ax1]

func = variance4
tvm.build(func, target="llvm") # should not crash


if __name__ == "__main__":
test_call_packed_return_non_i32()
test_lower_packed_func()

0 comments on commit 7f6da09

Please sign in to comment.