Skip to content

Commit

Permalink
[TVMScript] Unify T.handle and T.Ptr (#13969)
Browse files Browse the repository at this point in the history
While both represents a pointer type, `T.handle` was previously used to
refer to tir variables whose `type_annotation` is `PrimType`, while
`T.Ptr` instead specifically refers to `PointerType`. The divide is
unnecessary if we extend `T.handle` slightly.
  • Loading branch information
junrushao authored Feb 13, 2023
1 parent a5a6e7f commit dc626f3
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 64 deletions.
5 changes: 2 additions & 3 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,12 +415,12 @@ void Prefetch(Buffer buffer, Array<Range> bounds);
void Evaluate(PrimExpr value);

/*!
* \brief The pointer declaration function.
* \brief Create a TIR var that represents a pointer
* \param dtype The data type of the pointer.
* \param storage_scope The storage scope of the pointer.
* \return The pointer.
*/
PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
Var Handle(runtime::DataType dtype = runtime::DataType::Void(), String storage_scope = "global");

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \
Expand Down Expand Up @@ -455,7 +455,6 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());

#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,20 +1358,23 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
return _ffi_api.Boolean(expr) # type: ignore[attr-defined] # pylint: disable=no-member


def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
"""Construct a new tir.Var with type handle or cast expression to type handle.
def handle(dtype: str = "void", storage_scope: str = "global") -> Var:
"""Create a TIR var that represents a pointer.
Parameters
----------
expr: PrimExpr
The expression to be cast.
dtype: str
The data type of the pointer.
storage_scope: str
The storage scope of the pointer.
Returns
-------
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
return _ffi_api.Handle(expr) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Handle(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member


def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __call__(
axis_separators=axis_separators,
)

@deprecated("T.Buffer(...)", "T.Buffer(...)")
@deprecated("T.Buffer[...]", "T.Buffer(...)")
def __getitem__(self, keys) -> Buffer:
if not isinstance(keys, tuple):
return self(keys)
Expand All @@ -93,12 +93,13 @@ class PtrProxy:
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
"""

@deprecated("T.Ptr(...)", "T.handle(...)")
def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore

@deprecated("T.Ptr(...)", "T.Ptr(...)")
@deprecated("T.Ptr[...]", "T.handle(...)")
def __getitem__(self, keys):
if not isinstance(keys, tuple):
return self(keys)
Expand Down
10 changes: 10 additions & 0 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,16 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
}

Var Handle(runtime::DataType dtype, String storage_scope) {
Type type_annotation{nullptr};
if (dtype.is_void() && storage_scope == "global") {
type_annotation = PrimType(runtime::DataType::Handle());
} else {
type_annotation = PointerType(PrimType(dtype), storage_scope);
}
return tvm::tir::Var("", type_annotation);
}

using tvm::script::ir_builder::details::Namer;

TVM_STATIC_IR_FUNCTOR(Namer, vtable)
Expand Down
6 changes: 3 additions & 3 deletions src/script/printer/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
element_type = d->AsDoc<ExprDoc>(ty->element_type, ty_p->Attr("element_type"));
}
if (ty->storage_scope == "") {
return TIR(d, "Ptr")->Call({element_type});
return TIR(d, "handle")->Call({element_type});
} else {
return TIR(d, "Ptr")->Call(
{element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
return TIR(d, "handle")
->Call({element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_create_executor_metadata_single_func():
class Module:
@T.prim_func
def __tvm_main__(
a: T.handle, output: T.handle, workspace: T.Ptr(T.uint8), constants: T.Ptr(T.uint8)
a: T.handle, output: T.handle, workspace: T.handle("uint8"), constants: T.handle("uint8")
) -> None:
# function attr dict
T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind": "llvm", "tag": "", "keys": ["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": ["test_device"]})
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/aot/test_pass_aot_lower_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
def func(a: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": []})
tmp_read = T.Ptr("uint8", "")
tmp_read = T.handle("uint8", "")
# buffer definition
tmp_read_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_read)
a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16)
output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16)
# body
tmp_write: T.Ptr(T.uint8) = output_buffer.data
tmp_write: T.handle("uint8") = output_buffer.data
tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write)
for i in T.serial(140):
tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def test_buffer_conditional_lowering():
"""

@T.prim_func
def before(A: T.Ptr("float32")):
def before(A: T.handle("float32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i in range(1):
A_1 = T.Buffer((1,), data=A)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def main():
T.func_attr({"from_legacy_te_schedule": True})

# If a pointer defined using a LetStmt,
A_data: T.Ptr("int32") = T.call_extern("dummy_extern_function", dtype="handle")
A_data: T.handle("int32") = T.call_extern("dummy_extern_function", dtype="handle")

# and a buffer is backed by that pointer,
A = T.decl_buffer([1], dtype="float32", data=A_data)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_tir_transform_storage_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,12 +689,12 @@ class TestLetBufferRewrite(BaseCompare):
"""

def before() -> None:
A_data: T.Ptr("int32") = T.call_extern("dummy_func", dtype="handle")
A_data: T.handle("int32") = T.call_extern("dummy_func", dtype="handle")
A = T.Buffer([8], "int32", data=A_data)
A[0:8] = T.broadcast(42, 8)

def expected() -> None:
A_data: T.Ptr("int32x8") = T.call_extern("dummy_func", dtype="handle")
A_data: T.handle("int32x8") = T.call_extern("dummy_func", dtype="handle")
A = T.Buffer([1], "int32x8", data=A_data)
A[0] = T.broadcast(42, 8)

Expand Down
Loading

0 comments on commit dc626f3

Please sign in to comment.