Skip to content

Commit

Permalink
[TVMScript] Fix mismatched dtype of IterVar in T.thread_binding
Browse files Browse the repository at this point in the history
As reported by the [community](https://discuss.tvm.apache.org/t/int32-int64-issue-when-codegen-into-llvm-function/15915),
the dtype of IterVar in `T.thread_binding` is not consistent with the dtype of the corresponding
IterVar. This PR fixes this issue.
  • Loading branch information
Hzfengsy committed Nov 2, 2023
1 parent bd3e8bb commit 4345e32
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
9 changes: 5 additions & 4 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,14 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
PrimExpr extent = arith::Analyzer().Simplify(stop - start);
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
int bits = std::max(min.dtype().bits(), extent.dtype().bits());
n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))};
DataType dtype = DataType(min.dtype().code(), bits, 1);
n->vars = {Var("v", dtype)};
n->doms = {Range::FromMinExtent(min, extent)};
n->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms, Stmt body) -> For {
n->f_make_for_loop = [annotations, thread, dtype](Array<Var> vars, Array<Range> doms,
Stmt body) -> For {
ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1);
IterVar iter_var(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex,
thread);
IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
annotations.value_or(Map<String, ObjectRef>()));
};
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ class CrossThreadReductionTransformer : public StmtMutator {
/*kind=*/ForKind::kThreadBinding, //
/*body=*/body, //
/*thread_binding=*/
IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
IterVar(NullValue<Range>(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex,
"threadIdx." + dim_index));
}
return body;
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,5 +325,20 @@ def evaluated(A: T.Buffer((2, 128, 128), "int32")):
tvm.ir.assert_structural_equal(with_builtin, evaluated)


def test_thread_binding_dtype():
@T.prim_func(private=True)
def func(A: T.Buffer((128, 128)), B: T.Buffer((128, 128))):
for i in T.thread_binding(T.int64(128), "threadIdx.x"):
for j in T.thread_binding(128, "threadIdx.y"):
B[i, j] = A[i, j]

loop_i = func.body
loop_j = loop_i.body
assert loop_i.loop_var.dtype == "int64"
assert loop_i.thread_binding.var.dtype == "int64"
assert loop_j.loop_var.dtype == "int32"
assert loop_j.thread_binding.var.dtype == "int32"


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 4345e32

Please sign in to comment.