Skip to content

Commit

Permalink
[Fix][TIR] Fix dtype issues for match_buffer and ramp node (#16051)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubospica authored Nov 7, 2023
1 parent a302e0f commit df5d3b5
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 3 deletions.
3 changes: 2 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ def match_buffer(
raise ValueError("Shape must be specified when binding input param")
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None:
strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides]
idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else "int32"
strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides]
else:
strides = []
return _ffi_api.MatchBuffer( # type: ignore[attr-defined] # pylint: disable=no-member
Expand Down
1 change: 0 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,6 @@ def ptx_mma(
saturate : bool
The optional saturation at the output.
operator : Optional[Literal["xor", "and"]]
The 1-bit operator.
Expand Down
4 changes: 3 additions & 1 deletion src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,9 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
ICHECK(base.dtype().is_scalar());
ICHECK(stride.dtype().is_scalar());
ICHECK_GT(lanes, 1);
ICHECK_EQ(stride.dtype(), base.dtype());
if (stride.dtype() != base.dtype()) {
stride = cast(base.dtype(), stride);
}

ObjectPtr<RampNode> node = make_object<RampNode>();
node->dtype = base.dtype().with_lanes(lanes);
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3331,6 +3331,14 @@ def buffer_ramp_access(a: T.handle, b: T.handle, c: T.handle) -> None:
return buffer_ramp_access


def ramp_int64():
@T.prim_func
def func() -> None:
T.evaluate(T.Ramp(T.int64(0), 1, 3))

return func


def let_expression():
@T.prim_func
def func():
Expand All @@ -3346,6 +3354,7 @@ def test_void_ptr_vs_handle():
In the future, perhaps these should be de-duplicated by forbidding
one of the two C++ representations.
"""

# Generates PointerType(PrimType(DataType::Void()))
@T.prim_func
def void_ptr(out_ret_value: T.handle("void")):
Expand Down Expand Up @@ -3622,6 +3631,21 @@ def main(a: T.handle, b: T.handle):
return main


def string_stride_int64():
@T.prim_func
def main(a: T.handle, b: T.handle):
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
n = T.int64()
A_s0 = T.int64()
B_s0 = T.int64()
A = T.match_buffer(a, (n,), strides=(A_s0,), buffer_type="auto")
B = T.match_buffer(b, (n,), strides=(B_s0,), buffer_type="auto")
for i in range(n):
B[i] = A[i]

return main


def merge_shape_var_def():
@T.prim_func
def main(A: T.handle, B: T.handle):
Expand Down Expand Up @@ -4013,6 +4037,7 @@ def func():
pointer_type,
buffer_axis_separator,
buffer_ramp_access_as_slice_index,
ramp_int64,
let_expression,
void_ptr,
decl_buffer,
Expand All @@ -4035,6 +4060,7 @@ def func():
let_stmt_var,
let_stmt_value,
string_stride,
string_stride_int64,
merge_shape_var_def,
if_then_else_var,
tvm_shfl_builtins,
Expand Down

0 comments on commit df5d3b5

Please sign in to comment.