Skip to content

Commit

Permalink
[TVMScript] Encourage using T.Buffer directly (apache#13971)
Browse files Browse the repository at this point in the history
Previously there are two equivalent ways of declaring a buffer in
TVMScript:

```python
buffer = T.buffer_decl(...)
buffer = T.Buffer(...)
```

The two approaches are aliases to each other and are essentially the
same in implementation. Therefore, this PR encourages to use `T.Buffer`
as the recommended approach as it's a bit shorter. Meanwhile,
`T.buffer_decl` will continue to be valid in TVMScript, but a
deprecation warning will be emitted if its used.
  • Loading branch information
junrushao authored Feb 13, 2023
1 parent 82cf9f7 commit bea4919
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 19 deletions.
2 changes: 1 addition & 1 deletion python/tvm/script/ir_builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
"""Package tvm.script.ir_builder.tir"""
from .ir import * # pylint: disable=wildcard-import,redefined-builtin
from .ir import boolean as bool # pylint: disable=redefined-builtin
from .ir import buffer_decl as Buffer
from .ir import buffer as Buffer
23 changes: 18 additions & 5 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
# pylint: enable=unused-import


def buffer_decl(
def buffer(
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
dtype: str = "float32",
data: Var = None,
Expand Down Expand Up @@ -138,7 +138,7 @@ def buffer_decl(
The declared buffer.
"""
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
return _ffi_api.BufferDecl( # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member
shape,
dtype,
"",
Expand All @@ -153,6 +153,11 @@ def buffer_decl(
)


@deprecated("T.buffer_decl(...)", "T.Buffer(...)")
def buffer_decl(*args, **kwargs):
return buffer(*args, **kwargs)


def prim_func() -> frame.PrimFuncFrame:
"""The primitive function statement.
Expand Down Expand Up @@ -1177,7 +1182,11 @@ def env_thread(thread_tag: str) -> IterVar:
return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member


def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
def buffer_store(
buffer: Buffer, # pylint: disable=redefined-outer-name
value: PrimExpr,
indices: List[Union[PrimExpr, slice]],
) -> None:
"""Buffer store node.
Parameters
Expand Down Expand Up @@ -1211,7 +1220,10 @@ def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr,
)


def prefetch(buffer: Buffer, bounds: List[Range]) -> None:
def prefetch(
buffer: Buffer, # pylint: disable=redefined-outer-name
bounds: List[Range],
) -> None:
"""The prefetch hint for a buffer.
Parameters
Expand Down Expand Up @@ -1432,7 +1444,7 @@ def ptr(dtype: str, storage_scope: str = "global") -> Var:
return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member


@deprecated("T.buffer_var", "T.Ptr")
@deprecated("T.buffer_var", "T.handle")
def buffer_var(dtype: str, storage_scope: str = "global") -> Var:
"""The pointer declaration function.
Expand Down Expand Up @@ -1815,6 +1827,7 @@ def wrapped(*args, **kwargs):
"float16x64",
"float32x64",
"float64x64",
"buffer",
"buffer_decl",
"prim_func",
"arg",
Expand Down
12 changes: 4 additions & 8 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer_decl, ptr
from ...ir_builder.tir import buffer, ptr
from .._core import parse, utils


Expand Down Expand Up @@ -49,9 +49,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:


class BufferProxy:
"""Buffer proxy class for constructing tir buffer.
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer().
"""
"""Buffer proxy class for constructing tir buffer."""

def __call__(
self,
Expand All @@ -66,7 +64,7 @@ def __call__(
buffer_type="",
axis_separators=None,
) -> Buffer:
return buffer_decl(
return buffer(
shape,
dtype=dtype,
data=data,
Expand All @@ -89,9 +87,7 @@ def __getitem__(self, keys) -> Buffer:


class PtrProxy:
"""Ptr proxy class for constructing tir pointer.
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
"""
"""Ptr proxy class for constructing tir pointer."""

@deprecated("T.Ptr(...)", "T.handle(...)")
def __call__(self, dtype, storage_scope="global"):
Expand Down
3 changes: 1 addition & 2 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -593,8 +593,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
Namer::Name(var->var, name);
});

TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg")
.set_body_typed([](String name, ObjectRef obj) -> ObjectRef {
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_auto_scheduler_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ def tir_matmul(
) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
A_flat = T.buffer_decl([16384], dtype="float32", data=A.data)
B_flat = T.buffer_decl([16384], dtype="float32", data=B.data)
C_flat = T.buffer_decl([16384], dtype="float32", data=C.data)
A_flat = T.Buffer([16384], dtype="float32", data=A.data)
B_flat = T.Buffer([16384], dtype="float32", data=B.data)
C_flat = T.Buffer([16384], dtype="float32", data=C.data)
# body
for x, y in T.grid(128, 128):
C_flat[x * 128 + y] = T.float32(0)
Expand Down

0 comments on commit bea4919

Please sign in to comment.