diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 5471288878f5..d4a7445b7d08 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -316,7 +316,8 @@ def match_buffer( raise ValueError("Shape must be specified when binding input param") shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape if strides is not None: - strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] + idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else "int32" + strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides] else: strides = [] return _ffi_api.MatchBuffer( # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index d7df2a4bb690..bb2530b12594 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1062,7 +1062,6 @@ def ptx_mma( saturate : bool The optional saturation at the output. - operator : Optional[Literal["xor", "and"]] The 1-bit operator. diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index d590f8b2dd8b..41500051fa89 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -433,7 +433,9 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) { ICHECK(base.dtype().is_scalar()); ICHECK(stride.dtype().is_scalar()); ICHECK_GT(lanes, 1); - ICHECK_EQ(stride.dtype(), base.dtype()); + if (stride.dtype() != base.dtype()) { + stride = cast(base.dtype(), stride); + } ObjectPtr node = make_object(); node->dtype = base.dtype().with_lanes(lanes); diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 105ea62fd572..5b3e68e22fa9 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3331,6 +3331,14 @@ def buffer_ramp_access(a: T.handle, b: T.handle, c: T.handle) -> None: return buffer_ramp_access +def ramp_int64(): + @T.prim_func + def func() -> None: + T.evaluate(T.Ramp(T.int64(0), 1, 3)) + + return func + + def let_expression(): @T.prim_func def func(): @@ -3346,6 +3354,7 @@ def test_void_ptr_vs_handle(): In the future, perhaps these should be de-duplicated by forbidding one of the two C++ representations. """ + # Generates PointerType(PrimType(DataType::Void())) @T.prim_func def void_ptr(out_ret_value: T.handle("void")): @@ -3622,6 +3631,21 @@ def main(a: T.handle, b: T.handle): return main +def string_stride_int64(): + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + n = T.int64() + A_s0 = T.int64() + B_s0 = T.int64() + A = T.match_buffer(a, (n,), strides=(A_s0,), buffer_type="auto") + B = T.match_buffer(b, (n,), strides=(B_s0,), buffer_type="auto") + for i in range(n): + B[i] = A[i] + + return main + + def merge_shape_var_def(): @T.prim_func def main(A: T.handle, B: T.handle): @@ -4013,6 +4037,7 @@ def func(): pointer_type, buffer_axis_separator, buffer_ramp_access_as_slice_index, + ramp_int64, let_expression, void_ptr, decl_buffer, @@ -4035,6 +4060,7 @@ def func(): let_stmt_var, let_stmt_value, string_stride, + string_stride_int64, merge_shape_var_def, if_then_else_var, tvm_shfl_builtins,