Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Simplify TIR Var Definition #13970

Merged
merged 2 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,12 +415,12 @@ void Prefetch(Buffer buffer, Array<Range> bounds);
void Evaluate(PrimExpr value);

/*!
* \brief The pointer declaration function.
* \brief Create a TIR var that represents a pointer
* \param dtype The data type of the pointer.
* \param storage_scope The storage scope of the pointer.
* \return The pointer.
*/
PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
Var Handle(runtime::DataType dtype = runtime::DataType::Void(), String storage_scope = "global");

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \
Expand Down Expand Up @@ -455,7 +455,6 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());

#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
Expand Down
14 changes: 9 additions & 5 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,20 +1358,23 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
return _ffi_api.Boolean(expr) # type: ignore[attr-defined] # pylint: disable=no-member


def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
"""Construct a new tir.Var with type handle or cast expression to type handle.
def handle(dtype: str = "void", storage_scope: str = "global") -> Var:
"""Create a TIR var that represents a pointer.

Parameters
----------
expr: PrimExpr
The expression to be cast.
dtype: str
The data type of the pointer.

storage_scope: str
The storage scope of the pointer.

Returns
-------
res : Var
The new tir.Var that represents a pointer.
"""
return _ffi_api.Handle(expr) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Handle(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member


def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
Expand All @@ -1390,6 +1393,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
return _ffi_api.Void(expr) # type: ignore[attr-defined] # pylint: disable=no-member


@deprecated("T.var", "T.{dtype}")
def var(dtype: str, name: str = "") -> Var:
"""Construct a new tir.Var.

Expand Down
5 changes: 3 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __call__(
axis_separators=axis_separators,
)

@deprecated("T.Buffer(...)", "T.Buffer(...)")
@deprecated("T.Buffer[...]", "T.Buffer(...)")
def __getitem__(self, keys) -> Buffer:
if not isinstance(keys, tuple):
return self(keys)
Expand All @@ -93,12 +93,13 @@ class PtrProxy:
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr[].
"""

@deprecated("T.Ptr(...)", "T.handle(...)")
def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore

@deprecated("T.Ptr(...)", "T.Ptr(...)")
@deprecated("T.Ptr[...]", "T.handle(...)")
def __getitem__(self, keys):
if not isinstance(keys, tuple):
return self(keys)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
IRBuilder.name(var_name, value)
return value
elif isinstance(value, PrimExpr):
var = T.var(value.dtype)
var = Var("", value.dtype)
IRBuilder.name(var_name, var)
frame = T.let(var, value)
frame.add_callback(partial(frame.__exit__, None, None, None))
Expand Down
40 changes: 20 additions & 20 deletions python/tvm/tir/tensor_intrin/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:

@T.prim_func
def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
s0 = T.var("int32")
s1 = T.var("int32")
s0 = T.int32()
s1 = T.int32()
shared = T.match_buffer(
shared_handle,
shmem_shape,
Expand Down Expand Up @@ -385,8 +385,8 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:

@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
s0 = T.var("int32")
s1 = T.var("int32")
s0 = T.int32()
s1 = T.int32()

C_warp = T.match_buffer(
a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
Expand Down Expand Up @@ -530,10 +530,10 @@ def wmma_load_desc(a: T.handle, c: T.handle) -> None:

@T.prim_func
def wmma_load_impl(a: T.handle, c: T.handle) -> None:
s1 = T.var("int32")
s0 = T.var("int32")
d1 = T.var("int32")
d0 = T.var("int32")
s1 = T.int32()
s0 = T.int32()
d1 = T.int32()
d0 = T.int32()
A = T.match_buffer(
a,
(m_dim, n_dim),
Expand Down Expand Up @@ -593,8 +593,8 @@ def wmma_fill_desc(c: T.handle) -> None:

@T.prim_func
def wmma_fill_impl(c: T.handle) -> None:
d1 = T.var("int32")
d0 = T.var("int32")
d1 = T.int32()
d0 = T.int32()
C = T.match_buffer(
c,
(m_dim, n_dim),
Expand Down Expand Up @@ -643,10 +643,10 @@ def wmma_store_desc(a: T.handle, c: T.handle) -> None:

@T.prim_func
def wmma_store_impl(a: T.handle, c: T.handle) -> None:
s1 = T.var("int32")
s0 = T.var("int32")
d1 = T.var("int32")
d0 = T.var("int32")
s1 = T.int32()
s0 = T.int32()
d1 = T.int32()
d0 = T.int32()
A = T.match_buffer(
a,
(m_dim, n_dim),
Expand Down Expand Up @@ -726,12 +726,12 @@ def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:

@T.prim_func
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
a1 = T.var("int32")
a0 = T.var("int32")
b1 = T.var("int32")
b0 = T.var("int32")
c1 = T.var("int32")
c0 = T.var("int32")
a1 = T.int32()
a0 = T.int32()
b1 = T.int32()
b0 = T.int32()
c1 = T.int32()
c0 = T.int32()

A = T.match_buffer(
a,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/utils/roofline/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def estimate_peak_flops(
@T.prim_func
def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.int32) -> None:
# pylint: disable=invalid-name, missing-function-docstring
N = T.var("int32")
N = T.int32()
A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
B = T.match_buffer(b, [blocks, 4, warp_size], "float32")
for i in T.thread_binding(blocks, "blockIdx.x"):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/utils/roofline/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def estimate_peak_fma_flops(
@T.prim_func
def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.int32) -> None:
# pylint: disable=invalid-name, missing-function-docstring
N = T.var("int32")
N = T.int32()
A = T.match_buffer(a, [threads, N, 4, vec_width], "float32")
B = T.match_buffer(b, [threads, 4, vec_width], "float32")
# Parallelism is necessary to hit all cores/nodes
Expand Down
10 changes: 10 additions & 0 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,16 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
}

/*!
 * \brief Create an unnamed TIR var that represents a pointer.
 * \param dtype The data type of the pointee.
 * \param storage_scope The storage scope of the pointer.
 * \return A var annotated with the plain handle PrimType when `dtype` is void and
 *         `storage_scope` is "global"; otherwise a var annotated with
 *         PointerType(PrimType(dtype), storage_scope).
 */
Var Handle(runtime::DataType dtype, String storage_scope) {
  // Void dtype in the "global" scope carries no pointee information,
  // so it is annotated as a plain handle rather than as a typed pointer.
  if (dtype.is_void() && storage_scope == "global") {
    return tvm::tir::Var("", PrimType(runtime::DataType::Handle()));
  }
  return tvm::tir::Var("", PointerType(PrimType(dtype), storage_scope));
}

using tvm::script::ir_builder::details::Namer;

TVM_STATIC_IR_FUNCTOR(Namer, vtable)
Expand Down
25 changes: 20 additions & 5 deletions src/script/printer/tir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,29 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d)
if (Optional<Frame> opt_f = FindLowestVarDef(var, d)) {
ExprDoc lhs = DefineVar(var, opt_f.value(), d);
Type type = var->type_annotation;
ObjectPath type_p = var_p->Attr("type_annotation");
ExprDoc rhs{nullptr};
if (const auto* ptr_type = type.as<PointerTypeNode>()) {
ICHECK(ptr_type->element_type->IsInstance<PrimTypeNode>());
ExprDoc rhs = d->AsDoc<ExprDoc>(type, var_p->Attr("type_annotation"));
opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
ICHECK(prim_type);
ExprDoc element_type =
LiteralDoc::DataType(prim_type->dtype, type_p->Attr("element_type")->Attr("dtype"));
rhs = TIR(d, "handle");
rhs->source_paths.push_back(var_p->Attr("dtype"));
if (ptr_type->storage_scope == "") {
rhs = rhs->Call({element_type});
} else {
rhs = rhs->Call({element_type,
LiteralDoc::Str(ptr_type->storage_scope, //
type_p->Attr("storage_scope"))});
}
} else {
ExprDoc rhs = TIR(d, "var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))});
opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
rhs = TIR(d, DType2Str(var->dtype));
rhs->source_paths.push_back(var_p->Attr("dtype"));
rhs = rhs->Call({});
}
rhs->source_paths.push_back(type_p);
opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
} else {
LOG(WARNING) << "Didn't find variable definition for: " << var->name_hint;
}
Expand Down
6 changes: 3 additions & 3 deletions src/script/printer/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
element_type = d->AsDoc<ExprDoc>(ty->element_type, ty_p->Attr("element_type"));
}
if (ty->storage_scope == "") {
return TIR(d, "Ptr")->Call({element_type});
return TIR(d, "handle")->Call({element_type});
} else {
return TIR(d, "Ptr")->Call(
{element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
return TIR(d, "handle")
->Call({element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
}
});

Expand Down
76 changes: 38 additions & 38 deletions tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,16 +476,16 @@ class ModuleBefore:
def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None:
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True})
ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
nn = T.var("int32")
nn_1 = T.var("int32")
nn_2 = T.var("int32")
nn_3 = T.var("int32")
nn_4 = T.var("int32")
nn_5 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32()
nn = T.int32()
nn_1 = T.int32()
nn_2 = T.int32()
nn_3 = T.int32()
nn_4 = T.int32()
nn_5 = T.int32()
# body
placeholder_d_global = T.decl_buffer([208], "uint8")
placeholder_d_global_1 = T.decl_buffer([112], "uint8")
Expand Down Expand Up @@ -524,16 +524,16 @@ class ModuleAfter:
def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None:
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True})
ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
nn = T.var("int32")
nn_1 = T.var("int32")
nn_2 = T.var("int32")
nn_3 = T.var("int32")
nn_4 = T.var("int32")
nn_5 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32()
nn = T.int32()
nn_1 = T.int32()
nn_2 = T.int32()
nn_3 = T.int32()
nn_4 = T.int32()
nn_5 = T.int32()
# body
placeholder_d_global = T.decl_buffer([208], "uint8")
placeholder_d_global_1 = T.decl_buffer([112], "uint8")
Expand Down Expand Up @@ -579,15 +579,15 @@ class ModuleBefore:
def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, "int8"), ethosu_write: T.Buffer(46200, "int8")) -> None:
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True})
ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
nn = T.var("int32")
nn_1 = T.var("int32")
nn_2 = T.var("int32")
nn_3 = T.var("int32")
nn_4 = T.var("int32")
nn_5 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
nn = T.int32()
nn_1 = T.int32()
nn_2 = T.int32()
nn_3 = T.int32()
nn_4 = T.int32()
nn_5 = T.int32()
# body
placeholder_d_d_global = T.decl_buffer([208], "uint8")
placeholder_d_d_global_1 = T.decl_buffer([112], "uint8")
Expand Down Expand Up @@ -629,15 +629,15 @@ class ModuleAfter:
def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, "int8"), ethosu_write: T.Buffer(46200, "int8")) -> None:
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True})
ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
nn = T.var("int32")
nn_1 = T.var("int32")
nn_2 = T.var("int32")
nn_3 = T.var("int32")
nn_4 = T.var("int32")
nn_5 = T.var("int32")
ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
nn = T.int32()
nn_1 = T.int32()
nn_2 = T.int32()
nn_3 = T.int32()
nn_4 = T.int32()
nn_5 = T.int32()
# body
placeholder_d_d_global = T.decl_buffer([208], "uint8")
placeholder_d_d_global_1 = T.decl_buffer([112], "uint8")
Expand Down
Loading