diff --git a/python/tvm/script/ir_builder/tir/__init__.py b/python/tvm/script/ir_builder/tir/__init__.py index 563ac56f7b10..db2fc6aca095 100644 --- a/python/tvm/script/ir_builder/tir/__init__.py +++ b/python/tvm/script/ir_builder/tir/__init__.py @@ -17,4 +17,4 @@ """Package tvm.script.ir_builder.tir""" from .ir import * # pylint: disable=wildcard-import,redefined-builtin from .ir import boolean as bool # pylint: disable=redefined-builtin -from .ir import buffer_decl as Buffer +from .ir import buffer as Buffer diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 2c5a848e4ab4..5f4e9d4f2cf0 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -86,7 +86,7 @@ # pylint: enable=unused-import -def buffer_decl( +def buffer( shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], dtype: str = "float32", data: Var = None, @@ -138,7 +138,7 @@ def buffer_decl( The declared buffer. """ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape - return _ffi_api.BufferDecl( # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member shape, dtype, "", @@ -153,6 +153,11 @@ def buffer_decl( ) +@deprecated("T.buffer_decl(...)", "T.Buffer(...)") +def buffer_decl(*args, **kwargs): + return buffer(*args, **kwargs) + + def prim_func() -> frame.PrimFuncFrame: """The primitive function statement. @@ -1177,7 +1182,11 @@ def env_thread(thread_tag: str) -> IterVar: return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member -def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None: +def buffer_store( + buffer: Buffer, # pylint: disable=redefined-outer-name + value: PrimExpr, + indices: List[Union[PrimExpr, slice]], +) -> None: """Buffer store node. Parameters @@ -1211,7 +1220,10 @@ def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, ) -def prefetch(buffer: Buffer, bounds: List[Range]) -> None: +def prefetch( + buffer: Buffer, # pylint: disable=redefined-outer-name + bounds: List[Range], +) -> None: """The prefetch hint for a buffer. Parameters @@ -1432,7 +1444,7 @@ def ptr(dtype: str, storage_scope: str = "global") -> Var: return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member -@deprecated("T.buffer_var", "T.Ptr") +@deprecated("T.buffer_var", "T.handle") def buffer_var(dtype: str, storage_scope: str = "global") -> Var: """The pointer declaration function. @@ -1815,6 +1827,7 @@ def wrapped(*args, **kwargs): "float16x64", "float32x64", "float64x64", + "buffer", "buffer_decl", "prim_func", "arg", diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 51743e6b507b..411a7f8f3c83 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -21,7 +21,7 @@ from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc -from ...ir_builder.tir import buffer_decl, ptr +from ...ir_builder.tir import buffer, ptr from .._core import parse, utils @@ -49,9 +49,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: class BufferProxy: - """Buffer proxy class for constructing tir buffer. - Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer(). - """ + """Buffer proxy class for constructing tir buffer.""" def __call__( self, @@ -66,7 +64,7 @@ def __call__( buffer_type="", axis_separators=None, ) -> Buffer: - return buffer_decl( + return buffer( shape, dtype=dtype, data=data, @@ -89,9 +87,7 @@ def __getitem__(self, keys) -> Buffer: class PtrProxy: - """Ptr proxy class for constructing tir pointer. - Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr(). - """ + """Ptr proxy class for constructing tir pointer.""" @deprecated("T.Ptr(...)", "T.handle(...)") def __call__(self, dtype, storage_scope="global"): diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 9ab19b2e28a5..30102b687722 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -593,8 +593,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) Namer::Name(var->var, name); }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl").set_body_typed(BufferDecl); - +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl); TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg") .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py index ddd86347c2ec..c8edebfd3b87 100644 --- a/tests/python/unittest/test_auto_scheduler_feature.py +++ b/tests/python/unittest/test_auto_scheduler_feature.py @@ -209,9 +209,9 @@ def tir_matmul( ) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - A_flat = T.buffer_decl([16384], dtype="float32", data=A.data) - B_flat = T.buffer_decl([16384], dtype="float32", data=B.data) - C_flat = T.buffer_decl([16384], dtype="float32", data=C.data) + A_flat = T.Buffer([16384], dtype="float32", data=A.data) + B_flat = T.Buffer([16384], dtype="float32", data=B.data) + C_flat = T.Buffer([16384], dtype="float32", data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = T.float32(0)