EmitTE Staging Integration
Authors(alphabetical): @altanh, @hypercubestart, @junrushao1994, @tqchen, @YuchenJin, @ZihengJiang
This doc sketches an engineering plan for the TE integration with Relax. Relax supports directly embedding TIR functions through call_tir. However, it is still hard to manually construct TIR functions through TVMScript.
TE (tensor expression) is a DSL that we have traditionally used to construct many of the operators. While we are moving towards TIR for scheduling, TE is still very useful as a concise staging API to create TIR functions.
Right now we provide an API, create_prim_func, that allows us to create a TIR function from TE. The code block below gives an example.
import tvm
from tvm import te
from tvm.script import tir as T

# hand-written TIR function
@T.prim_func
def tir_element_wise(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    with T.block([128, 128]) as [i, j]:
        B[i, j] = A[i, j] * 2.0

# the same function created from TE
A = te.placeholder((128, 128), name="A")
B = te.compute((128, 128), lambda x, y: A[x, y] * 2, name="B")
func = te.create_prim_func([A, B])
tvm.ir.assert_structural_equal(func, tir_element_wise)
See the files below for additional reading:
- https://github.com/apache/tvm/blob/main/src/te/operation/create_primfunc.cc The current logic to create PrimFunc from TE.
- https://github.com/apache/tvm/blob/main/tests/python/unittest/test_te_create_primfunc.py
Importantly, we leverage TE to write a rich collection of operator libraries (e.g. topi). It is useful to reuse these libraries for quick workload creation and operator lowering.
This note outlines an approach to tightly reuse the TE mechanism for Relax AST staging and high-level operator lowering.
The code snippet below shows our design goals:
bb = rx.BlockBuilder()
n, m = tir.var("n"), tir.var("m")
x : rx.Var = rx.var("x", shape=[n, m])
a : rx.Expr = bb.emit(x + 1)
# create a te tensor
# requires a.shape to be a known ShapeExpr, otherwise match shape is needed here
A : te.Tensor = rx.te_tensor(a)
# construct a te expression via arbitrary topi/te calls
# This will create the prim_func that is related to te, and add a call_tir to it
b : rx.Expr = bb.emit(te.compute((n, m), lambda x, y: A[x, y] * 2, name="B"))
# light weight decorator style that does automatic conversion
# (turn rx arguments into rx.te_tensor if needed)
# otherwise, simply pass things as te
# this decorator can be applied to topi functions in the future so they can directly
# be used to stage relax functions as well.
def te_func(X: te.Tensor, Y: te.Tensor):
    return te.compute((128, 128), lambda x, y: X[x, y] + Y[x, y])
# directly call into te function to be able to construct the
bb.emit_te(te_func, x, a)
The key items are:
- D0: Allow wrapping a Relax Expr with a known symbolic shape as a te.Tensor.
- D1: Reuse TE-based libraries to create operator implementations.
- D2: Upon emit, call create_prim_func to create the TIR function, then generate a call_tir to that TIR function.
In order to support this type of DSL properly, there are some special considerations for dynamic shapes. See the following example:
n, m = tir.var("n"), tir.var("m")
x : rx.Var = rx.var("x", shape=[n, 2*floordiv(m, 2)])
A : te.Tensor = rx.te_tensor(x)
B = bb.emit(te.compute([n, 2*floordiv(m, 2)], lambda i, j: A[i, j] + 1))
If we simply follow the above algorithm, the generated call_tir can be shown as follows:
@T.prim_func
def ewise_fun(a: T.handle, b: T.handle) -> None:
    n = tir.var("n")
    m = tir.var("m")
    A = T.match_buffer(a, (n, 2*floordiv(m, 2)))
    B = T.match_buffer(b, (n, 2*floordiv(m, 2)))
    with T.block([n, 2*floordiv(m, 2)]) as [i, j]:
        B[i, j] = A[i, j] + 1.0

x : rx.Var = rx.var("x", shape=[n, 2*floordiv(m, 2)])
B = bb.call_tir([n, 2*floordiv(m, 2)], ewise_fun, [x])
The main problem with the above code is that the symbolic variable m in 2*floordiv(m, 2) is not defined in ewise_fun. This is because match_buffer only defines a variable when the shape entry is a plain variable such as n, not when it is an expression such as 2*floordiv(m, 2). Note that we could certainly enhance the matching behavior to support some more complicated patterns like m*2, which is sometimes desirable. However, a pattern like 2*floordiv(m, 2) would require us to know m at the beginning of the function.
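The binding rule described above can be made concrete with a small sketch that does not depend on TVM. The Var, Const, and BinOp classes and the unbound_vars helper are hypothetical stand-ins introduced only for illustration; the sketch assumes that a shape entry binds a variable only when the entry is a plain variable.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical mini expression tree; these classes stand in for TIR
# expression nodes and are not part of TVM.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class BinOp:
    op: str  # e.g. "mul" or "floordiv"
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Var, Const, BinOp]


def free_vars(expr):
    """Collect the names of all variables appearing anywhere in expr."""
    if isinstance(expr, Var):
        return {expr.name}
    if isinstance(expr, BinOp):
        return free_vars(expr.lhs) | free_vars(expr.rhs)
    return set()


def unbound_vars(buffer_shapes):
    """Names used in the shapes that no plain-variable dimension binds."""
    bound, used = set(), set()
    for shape in buffer_shapes:
        for dim in shape:
            used |= free_vars(dim)
            if isinstance(dim, Var):
                bound.add(dim.name)  # only a bare variable is bound by matching
    return sorted(used - bound)


n, m = Var("n"), Var("m")
even_m = BinOp("mul", Const(2), BinOp("floordiv", m, Const(2)))
shape = (n, even_m)

# Both buffers share the shape (n, 2*floordiv(m, 2)).
extra_params = unbound_vars([shape, shape])
print(extra_params)
```

In this sketch n is bound because it appears as a bare dimension, while m appears only inside an expression and is therefore reported as unbound.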
One solution is to apply match shape to create a new symbolic var z that matches 2*floordiv(m, 2). This would result in a loss of information for TIR (namely, that the shape is a multiple of 2).
Instead, we enhance the generated TIR function and use call_tir_dyn, which produces the following code:
@T.prim_func
def ewise_fun(a: T.handle, b: T.handle, m: T.int64) -> None:
    n = tir.var("n")
    A = T.match_buffer(a, (n, 2*floordiv(m, 2)))
    B = T.match_buffer(b, (n, 2*floordiv(m, 2)))
    with T.block([n, 2*floordiv(m, 2)]) as [i, j]:
        B[i, j] = A[i, j] + 1.0

x : rx.Var = rx.var("x", shape=[n, 2*floordiv(m, 2)])
B = bb.call_tir_dyn([n, 2*floordiv(m, 2)], ewise_fun, [x], ShapeExpr([m]))
We first infer unbound TIR variables such as m and add them as additional parameters to the function. We then use an enhanced call_tir_dyn convention (shown below, with the added symbolic_var_tuple argument) that also accepts a ShapeExpr indicating the symbolic integer values to unpack and pass into the function.
def call_tir_dyn(shape, lowlevel_func, inputs, symbolic_var_tuple):
    out = alloc_tensor(shape)
    lowlevel_func(*inputs, out, *symbolic_var_tuple)
    return out
A separate symbolic_var_tuple is needed mainly because:
- C0: We need to pass the symbolic int hints after the output value
- C1: The design constraint that relax symbolic integer itself does not appear as an rx.Expr
Note that C1 can be lifted if we eventually decide to allow further mixing of TIR/relax expr, but that would require more design considerations (see also other considerations).
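To make the call_tir_dyn convention concrete, here is a minimal pure-Python sketch. It is an illustration under stated assumptions rather than the Relax implementation: alloc_tensor is modeled with nested lists, and ewise_fun is a hypothetical stand-in for the generated TIR function that adds one to every element.

```python
def alloc_tensor(shape):
    """Stand-in allocator: a zero-filled nested list with the given 2-D shape."""
    rows, cols = shape
    return [[0.0] * cols for _ in range(rows)]


def ewise_fun(a, b, m):
    """Stand-in for the generated TIR function.

    n is recovered from the input buffer, while m arrives as an explicit
    trailing argument placed after the output buffer.
    """
    n = len(a)
    cols = 2 * (m // 2)  # the expression dimension is recomputed from m
    assert all(len(row) == cols for row in a), "shape mismatch"
    for i in range(n):
        for j in range(cols):
            b[i][j] = a[i][j] + 1.0


def call_tir_dyn(shape, lowlevel_func, inputs, symbolic_var_tuple):
    out = alloc_tensor(shape)
    # Inputs first, then the output, then the unpacked symbolic integers.
    lowlevel_func(*inputs, out, *symbolic_var_tuple)
    return out


m = 5
x = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # shape (2, 2 * (5 // 2))
result = call_tir_dyn((2, 2 * (m // 2)), ewise_fun, [x], (m,))
print(result)
```

Passing m = 5 gives a column count of 4, which shows why the callee needs m itself and not only the matched extent: the extent alone does not determine m.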
There are a few different ways to realize this in compilation and the VM. The approach currently in use is:
- Introduce an argument-unpack convention for registers: r1 is unpacked (when it is a ShapeTuple or Array) before being passed into the packed function. We implement this by introducing a runtime intrinsic (see vm.call_tir_dyn).
r1 = [2, 4]
call fn, r0, *r1
# semantics
call fn, r0, 2, 4
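The unpack semantics above can be mimicked with a small Python sketch. The register file is modeled as a dict, and call_with_unpack is a hypothetical helper invented for this illustration; it is not the VM intrinsic itself.

```python
def call_with_unpack(fn, registers, arg_specs):
    """Call fn with arguments read from registers.

    arg_specs is a list of (register_name, unpack) pairs. When unpack is
    True, the register must hold a tuple or list whose elements are passed
    as separate positional arguments.
    """
    args = []
    for name, unpack in arg_specs:
        value = registers[name]
        if unpack:
            if not isinstance(value, (tuple, list)):
                raise TypeError(f"register {name} cannot be unpacked")
            args.extend(value)
        else:
            args.append(value)
    return fn(*args)


def fn(*args):
    """Echo the received arguments so the calling convention is visible."""
    return args


registers = {"r0": "tensor", "r1": [2, 4]}

# call fn, r0, *r1
unpacked = call_with_unpack(fn, registers, [("r0", False), ("r1", True)])

# call fn, r0, 2, 4
direct = fn("tensor", 2, 4)
print(unpacked == direct)
```

The per-argument unpack flag in this sketch plays a role similar to the starred register in the pseudo-instruction.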
An alternative solution is to indicate the unpacked-args calling convention through register masking, which aligns with Python's argument-unpacking calling convention.
Note that many variants are possible that would make staging and overloading easier. This note outlines a possible first step with a manageable workload. We can think of possible follow-up steps in the future.
List of other possible ideas:
- Allow compute as macro sugar in the Relax script format, to be macro-expanded. This is certainly doable. However, as we start to build importers and lowering mechanisms, we will likely need a staging builder-style API.
- Allow rx.Expr to directly operate like a te.Tensor. This would result in a deeper mixing of TIR/Relax expressions. It is a longer-term consideration that needs more thought about design tradeoffs.