[TVMScript] Simplify TIR Var Definition (#13970)
This PR introduces a small tweak to the TVMScript printer that simplifies
variable definitions in TIR.

Previously, a TIR var was defined with `T.var(dtype)`, e.g.

```python
a = T.var("int32")
```

This PR shortens the definition to:

```python
a = T.int32()
```

This PR introduces no breaking changes: the legacy `T.var(dtype)` form
keeps working, though it is now marked as deprecated in favor of the shorthand.
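
The shorthand follows the same `T.{dtype}` pattern for every scalar dtype, and an opaque handle var shortens analogously. A quick sketch of the resulting spellings (the `int32` and `handle` forms appear in the diffs below; other dtypes are assumed to follow the same pattern):

```python
i = T.int32()    # was T.var("int32")
x = T.float32()  # was T.var("float32")
h = T.handle()   # was T.var("handle")
```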
junrushao authored Feb 13, 2023
1 parent dc626f3 commit 82cf9f7
Showing 28 changed files with 257 additions and 243 deletions.
1 change: 1 addition & 0 deletions python/tvm/script/ir_builder/tir/ir.py
```diff
@@ -1393,6 +1393,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
     return _ffi_api.Void(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
+@deprecated("T.var", "T.{dtype}")
 def var(dtype: str, name: str = "") -> Var:
     """Construct a new tir.Var.
```
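
The `deprecated` marker keeps `T.var` callable while nudging users toward `T.{dtype}`. A rough sketch of how such a decorator typically behaves (illustrative only, not TVM's actual helper):

```python
import functools
import warnings


def deprecated(old_name: str, new_name: str):
    """Mark a function as deprecated, pointing callers at its replacement."""

    def wrapper(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            warnings.warn(
                f"{old_name} is deprecated, use {new_name} instead",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return inner

    return wrapper
```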
2 changes: 1 addition & 1 deletion python/tvm/script/parser/tir/parser.py
```diff
@@ -143,7 +143,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
         IRBuilder.name(var_name, value)
         return value
     elif isinstance(value, PrimExpr):
-        var = T.var(value.dtype)
+        var = Var("", value.dtype)
         IRBuilder.name(var_name, var)
         frame = T.let(var, value)
         frame.add_callback(partial(frame.__exit__, None, None, None))
```
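
This branch of `bind_assign_value` fires whenever a TVMScript assignment binds a `PrimExpr` to a name: the parser creates a fresh `Var` and let-binds it. Constructing `Var("", value.dtype)` directly keeps this path from routing through the now-deprecated `T.var`. A minimal TVMScript sketch that would exercise it (a hypothetical example, not from the test suite):

```python
from tvm.script import tir as T


@T.prim_func
def scale(a: T.handle) -> None:
    A = T.match_buffer(a, (16,), "float32")
    for i in range(16):
        # RHS is a PrimExpr, so the parser let-binds a fresh Var named "x"
        x = A[i] * T.float32(2.0)
        A[i] = x + T.float32(1.0)
```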
40 changes: 20 additions & 20 deletions python/tvm/tir/tensor_intrin/cuda.py
```diff
@@ -146,8 +146,8 @@ def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
 
     @T.prim_func
     def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
-        s0 = T.var("int32")
-        s1 = T.var("int32")
+        s0 = T.int32()
+        s1 = T.int32()
         shared = T.match_buffer(
             shared_handle,
             shmem_shape,
@@ -385,8 +385,8 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
 
     @T.prim_func
     def mma_store_impl(a: T.handle, c: T.handle) -> None:
-        s0 = T.var("int32")
-        s1 = T.var("int32")
+        s0 = T.int32()
+        s1 = T.int32()
 
         C_warp = T.match_buffer(
             a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
@@ -530,10 +530,10 @@ def wmma_load_desc(a: T.handle, c: T.handle) -> None:
 
     @T.prim_func
     def wmma_load_impl(a: T.handle, c: T.handle) -> None:
-        s1 = T.var("int32")
-        s0 = T.var("int32")
-        d1 = T.var("int32")
-        d0 = T.var("int32")
+        s1 = T.int32()
+        s0 = T.int32()
+        d1 = T.int32()
+        d0 = T.int32()
         A = T.match_buffer(
             a,
             (m_dim, n_dim),
@@ -593,8 +593,8 @@ def wmma_fill_desc(c: T.handle) -> None:
 
     @T.prim_func
     def wmma_fill_impl(c: T.handle) -> None:
-        d1 = T.var("int32")
-        d0 = T.var("int32")
+        d1 = T.int32()
+        d0 = T.int32()
         C = T.match_buffer(
             c,
             (m_dim, n_dim),
@@ -643,10 +643,10 @@ def wmma_store_desc(a: T.handle, c: T.handle) -> None:
 
     @T.prim_func
     def wmma_store_impl(a: T.handle, c: T.handle) -> None:
-        s1 = T.var("int32")
-        s0 = T.var("int32")
-        d1 = T.var("int32")
-        d0 = T.var("int32")
+        s1 = T.int32()
+        s0 = T.int32()
+        d1 = T.int32()
+        d0 = T.int32()
         A = T.match_buffer(
             a,
             (m_dim, n_dim),
@@ -726,12 +726,12 @@ def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
 
     @T.prim_func
     def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
-        a1 = T.var("int32")
-        a0 = T.var("int32")
-        b1 = T.var("int32")
-        b0 = T.var("int32")
-        c1 = T.var("int32")
-        c0 = T.var("int32")
+        a1 = T.int32()
+        a0 = T.int32()
+        b1 = T.int32()
+        b0 = T.int32()
+        c1 = T.int32()
+        c0 = T.int32()
 
         A = T.match_buffer(
             a,
```
2 changes: 1 addition & 1 deletion python/tvm/utils/roofline/cuda.py
```diff
@@ -299,7 +299,7 @@ def estimate_peak_flops(
     @T.prim_func
     def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.int32) -> None:
         # pylint: disable=invalid-name, missing-function-docstring
-        N = T.var("int32")
+        N = T.int32()
         A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
         B = T.match_buffer(b, [blocks, 4, warp_size], "float32")
         for i in T.thread_binding(blocks, "blockIdx.x"):
```
2 changes: 1 addition & 1 deletion python/tvm/utils/roofline/x86.py
```diff
@@ -216,7 +216,7 @@ def estimate_peak_fma_flops(
     @T.prim_func
     def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.int32) -> None:
         # pylint: disable=invalid-name, missing-function-docstring
-        N = T.var("int32")
+        N = T.int32()
         A = T.match_buffer(a, [threads, N, 4, vec_width], "float32")
         B = T.match_buffer(b, [threads, 4, vec_width], "float32")
         # Parallelism is necessary to hit all cores/nodes
```
25 changes: 20 additions & 5 deletions src/script/printer/tir/expr.cc
```diff
@@ -29,14 +29,29 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d)
   if (Optional<Frame> opt_f = FindLowestVarDef(var, d)) {
     ExprDoc lhs = DefineVar(var, opt_f.value(), d);
     Type type = var->type_annotation;
+    ObjectPath type_p = var_p->Attr("type_annotation");
+    ExprDoc rhs{nullptr};
     if (const auto* ptr_type = type.as<PointerTypeNode>()) {
-      ICHECK(ptr_type->element_type->IsInstance<PrimTypeNode>());
-      ExprDoc rhs = d->AsDoc<ExprDoc>(type, var_p->Attr("type_annotation"));
-      opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+      const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
+      ICHECK(prim_type);
+      ExprDoc element_type =
+          LiteralDoc::DataType(prim_type->dtype, type_p->Attr("element_type")->Attr("dtype"));
+      rhs = TIR(d, "handle");
+      rhs->source_paths.push_back(var_p->Attr("dtype"));
+      if (ptr_type->storage_scope == "") {
+        rhs = rhs->Call({element_type});
+      } else {
+        rhs = rhs->Call({element_type,
+                         LiteralDoc::Str(ptr_type->storage_scope,  //
+                                         type_p->Attr("storage_scope"))});
+      }
     } else {
-      ExprDoc rhs = TIR(d, "var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))});
-      opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+      rhs = TIR(d, DType2Str(var->dtype));
+      rhs->source_paths.push_back(var_p->Attr("dtype"));
+      rhs = rhs->Call({});
     }
+    rhs->source_paths.push_back(type_p);
+    opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
   } else {
     LOG(WARNING) << "Didn't find variable definition for: " << var->name_hint;
   }
```
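
The two branches map directly onto the printed forms: a pointer-typed var now prints as `T.handle(element_dtype)`, with the storage scope appended when it is non-empty, while every other var prints as `T.{dtype}()` via `DType2Str`. The added `source_paths` entries keep the emitted doc traceable back to the `dtype` and `type_annotation` fields of the underlying `Var`. An illustrative sketch of the resulting output:

```python
n = T.int32()                      # PrimType     -> T.{dtype}()
p = T.handle("float32")            # PointerType, storage_scope == ""
q = T.handle("float32", "shared")  # PointerType with a storage scope
```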
76 changes: 38 additions & 38 deletions tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
```diff
@@ -476,16 +476,16 @@ class ModuleBefore:
     def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None:
         # function attr dict
         T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True})
-        ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
-        nn = T.var("int32")
-        nn_1 = T.var("int32")
-        nn_2 = T.var("int32")
-        nn_3 = T.var("int32")
-        nn_4 = T.var("int32")
-        nn_5 = T.var("int32")
+        ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32()
+        nn = T.int32()
+        nn_1 = T.int32()
+        nn_2 = T.int32()
+        nn_3 = T.int32()
+        nn_4 = T.int32()
+        nn_5 = T.int32()
         # body
         placeholder_d_global = T.decl_buffer([208], "uint8")
         placeholder_d_global_1 = T.decl_buffer([112], "uint8")
@@ -524,16 +524,16 @@ class ModuleAfter:
     def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None:
         # function attr dict
         T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True})
-        ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
-        nn = T.var("int32")
-        nn_1 = T.var("int32")
-        nn_2 = T.var("int32")
-        nn_3 = T.var("int32")
-        nn_4 = T.var("int32")
-        nn_5 = T.var("int32")
+        ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32()
+        nn = T.int32()
+        nn_1 = T.int32()
+        nn_2 = T.int32()
+        nn_3 = T.int32()
+        nn_4 = T.int32()
+        nn_5 = T.int32()
         # body
         placeholder_d_global = T.decl_buffer([208], "uint8")
         placeholder_d_global_1 = T.decl_buffer([112], "uint8")
@@ -579,15 +579,15 @@ class ModuleBefore:
     def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, "int8"), ethosu_write: T.Buffer(46200, "int8")) -> None:
         # function attr dict
         T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True})
-        ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
-        nn = T.var("int32")
-        nn_1 = T.var("int32")
-        nn_2 = T.var("int32")
-        nn_3 = T.var("int32")
-        nn_4 = T.var("int32")
-        nn_5 = T.var("int32")
+        ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+        nn = T.int32()
+        nn_1 = T.int32()
+        nn_2 = T.int32()
+        nn_3 = T.int32()
+        nn_4 = T.int32()
+        nn_5 = T.int32()
         # body
         placeholder_d_d_global = T.decl_buffer([208], "uint8")
         placeholder_d_d_global_1 = T.decl_buffer([112], "uint8")
@@ -629,15 +629,15 @@ class ModuleAfter:
     def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, "int8"), ethosu_write: T.Buffer(46200, "int8")) -> None:
         # function attr dict
         T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True})
-        ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
-        ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
-        nn = T.var("int32")
-        nn_1 = T.var("int32")
-        nn_2 = T.var("int32")
-        nn_3 = T.var("int32")
-        nn_4 = T.var("int32")
-        nn_5 = T.var("int32")
+        ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+        ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+        nn = T.int32()
+        nn_1 = T.int32()
+        nn_2 = T.int32()
+        nn_3 = T.int32()
+        nn_4 = T.int32()
+        nn_5 = T.int32()
         # body
         placeholder_d_d_global = T.decl_buffer([208], "uint8")
         placeholder_d_d_global_1 = T.decl_buffer([112], "uint8")
```
40 changes: 20 additions & 20 deletions tests/python/contrib/test_ethosu/test_merge_constants.py
```diff
@@ -650,18 +650,18 @@ class InputModule:
     def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8"), buffer4: T.Buffer((112,), "uint8"), buffer5: T.Buffer((32,), "uint8"), buffer6: T.Buffer((112,), "uint8"), buffer7: T.Buffer((32,), "uint8"), buffer8: T.Buffer((112,), "uint8"), buffer9: T.Buffer((32,), "uint8")) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        v1a = T.var("int32")
-        v1b = T.var("int32")
-        v1c = T.var("int32")
-        v2a = T.var("int32")
-        v2b = T.var("int32")
-        v2c = T.var("int32")
-        v3a = T.var("int32")
-        v3b = T.var("int32")
-        v3c = T.var("int32")
-        v4a = T.var("int32")
-        v4b = T.var("int32")
-        v4c = T.var("int32")
+        v1a = T.int32()
+        v1b = T.int32()
+        v1c = T.int32()
+        v2a = T.int32()
+        v2b = T.int32()
+        v2c = T.int32()
+        v3a = T.int32()
+        v3b = T.int32()
+        v3c = T.int32()
+        v4a = T.int32()
+        v4b = T.int32()
+        v4c = T.int32()
         buffer1 = T.Buffer([8192], "int8")
         buffer10 = T.Buffer([2048], "int8")
         # body
@@ -713,14 +713,14 @@ class ReferenceModule:
     def main(buffer2: T.Buffer((160,), "uint8"), buffer4: T.Buffer((144,), "uint8"), buffer6: T.Buffer((144,), "uint8"), buffer8: T.Buffer((144,), "uint8")) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        v1a = T.var("int32")
-        v1c = T.var("int32")
-        v2a = T.var("int32")
-        v2c = T.var("int32")
-        v3a = T.var("int32")
-        v3c = T.var("int32")
-        v4a = T.var("int32")
-        v4c = T.var("int32")
+        v1a = T.int32()
+        v1c = T.int32()
+        v2a = T.int32()
+        v2c = T.int32()
+        v3a = T.int32()
+        v3c = T.int32()
+        v4a = T.int32()
+        v4c = T.int32()
         buffer1 = T.Buffer([8192], "int8")
         buffer10 = T.Buffer([2048], "int8")
         # body
```
12 changes: 6 additions & 6 deletions tests/python/integration/test_lower.py
```diff
@@ -136,8 +136,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle)
                         axis_vk * 16 : axis_vk * 16 + 16,
                     ]
                 )
-                stride0 = T.var("int32")
-                stride1 = T.var("int32")
+                stride0 = T.int32()
+                stride1 = T.int32()
                 match_buffer_a0 = T.match_buffer(
                     shared_a[
                         new_axis_vi * 16 : new_axis_vi * 16 + 16,
@@ -198,8 +198,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle)
                         axis_vk * 16 : axis_vk * 16 + 16,
                     ]
                 )
-                stride0 = T.var("int32")
-                stride1 = T.var("int32")
+                stride0 = T.int32()
+                stride1 = T.int32()
                 match_buffer_b0 = T.match_buffer(
                     shared_b[
                         new_axis_vj * 16 : new_axis_vj * 16 + 16,
@@ -335,8 +335,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle)
                         new_axis_vj * 16 : new_axis_vj * 16 + 16,
                     ]
                 )
-                stride0 = T.var("int32")
-                stride1 = T.var("int32")
+                stride0 = T.int32()
+                stride1 = T.int32()
                 wmma_c2 = T.match_buffer(
                     wmma_c[
                         new_axis_vi * 16 : new_axis_vi * 16 + 16,
```
16 changes: 8 additions & 8 deletions tests/python/unittest/test_aot_legalize_packed_call.py
```diff
@@ -35,10 +35,10 @@ def tvm_test_cpacked(
 
     @T.prim_func
     def tir_packed_call() -> None:
-        A = T.var("handle")
-        B = T.var("handle")
-        C = T.var("handle")
-        device_context = T.var("handle")
+        A = T.handle()
+        B = T.handle()
+        C = T.handle()
+        device_context = T.handle()
         # body
         T.evaluate(
             T.tvm_call_cpacked(
@@ -65,10 +65,10 @@ def tvm_test_cpacked(
 
    @T.prim_func
     def tir_packed_call() -> None:
-        A = T.var("handle")
-        B = T.var("handle")
-        C = T.var("handle")
-        device_context = T.var("handle")
+        A = T.handle()
+        B = T.handle()
+        C = T.handle()
+        device_context = T.handle()
 
         # body
         T.evaluate(
```
4 changes: 2 additions & 2 deletions tests/python/unittest/test_arith_domain_touched.py
```diff
@@ -21,7 +21,7 @@
 
 @T.prim_func
 def scalar_func(a: T.handle, b: T.handle):
-    m = T.var("int32")
+    m = T.int32()
     n = 100
     A = T.match_buffer(a, (n, m))
     B = T.match_buffer(b, (n, m))
@@ -73,7 +73,7 @@ def test_domain_touched_vector():
 
     @T.prim_func
     def func(a: T.handle, b: T.handle):
-        n = T.var("int32")
+        n = T.int32()
         A = T.match_buffer(a, (n * m,))
         B = T.match_buffer(b, (n * m,))
 
```
(Diffs for the remaining changed files are not shown.)
