[Relax][Frontend] "tensor_ir_inplace" op (#16498)
This PR introduces the `tensor_ir_inplace_op` frontend op so that we can
leverage `call_tir_inplace` in the SLM model definition flow.

One unit test is added. The PR also fixes a few naming typos
(`input_indices` -> `inplace_indices`) in the docstrings of
`call_tir_inplace` and `call_inplace_packed`.
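For context, a minimal sketch of how the new op might be used in an SLM-style `nn.Module` (mirroring the unit test added below; `inplace_take` is assumed to be a `@T.prim_func` like the one in that test, and the module and tensor names are illustrative):

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op

class EmbedInplace(nn.Module):
    def forward(self, table: Tensor, ids: Tensor, dst: Tensor, offset: int):
        # Gathered rows are written directly into `dst` (args[2]);
        # no fresh output buffer is allocated for the result.
        return op.tensor_ir_inplace_op(
            inplace_take,         # assumed: a @T.prim_func as in the test below
            "inplace_take",
            args=[table, ids, dst, offset],
            inplace_indices=[2],  # output 0 aliases args[2] (`dst`)
            out=Tensor.placeholder(dst.shape, dst.dtype),
        )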
MasterJH5574 authored Feb 2, 2024
1 parent 7a303d9 commit 5c68932
Showing 3 changed files with 184 additions and 10 deletions.
69 changes: 69 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
@@ -1629,6 +1629,75 @@ def tensor_ir_op(
)


def tensor_ir_inplace_op(
func: _tir.PrimFunc,
name_hint: str,
args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
inplace_indices: Union[int, List[int]],
out: OutType,
) -> OutType:
"""Create a `call_tir_inplace` binding with given PrimFunc
Parameters
----------
func : _tir.PrimFunc
The PrimFunc to call.
name_hint : str
Name hint.
args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
The arguments to pass to the PrimFunc.
inplace_indices : Union[int, List[int]]
Specify which arguments should be used for in-place computations.
If `inplace_indices` is a single integer, it will be made into a singleton list.
Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
will be an alias of `args[j]`.
If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
At least one member of `inplace_indices` must not be -1.
out : Union[Tensor, List[Tensor]]
        The output tensors.

    Returns
    -------
    result : Tensor
        The result tensor.
"""
from tvm import relax as rx # pylint: disable=import-outside-toplevel

call_tir_args, tir_vars = [], []
if not isinstance(args, (tuple, list)):
args = [args]

for arg in args:
if isinstance(arg, Tensor):
call_tir_args.append(arg._expr)
elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
tir_vars.append(arg)
else:
raise TypeError(
"Unsupported type: tensor_ir_inplace_op args expect Tensor or ShapeExpr or"
f" PrimExpr, but got {type(arg)}"
)

if isinstance(out, Tensor):
out_sinfo = [out._expr.struct_info]
else:
out_sinfo = [x._expr.struct_info for x in out]

bb = BlockBuilder.current()
global_var = bb.add_func(func, name_hint)

return wrap_nested(
bb.emit(
rx.call_tir_inplace(global_var, call_tir_args, inplace_indices, out_sinfo, tir_vars)
),
name=name_hint,
)


def extern(
name: str,
args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
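To make the `inplace_indices` convention concrete, here is a hypothetical two-output call (a sketch only; `func`, `x`, and `state` are illustrative names, not part of this PR):

from tvm.relax.frontend.nn import Tensor, op

# Suppose `func` is a PrimFunc producing two outputs, and `x` and `state`
# are Tensors already in scope. inplace_indices=[1, -1] means:
#   output 0 overwrites args[1] (`state`) in place,
#   output 1 is a freshly allocated tensor.
new_state, fresh = op.tensor_ir_inplace_op(
    func,
    "fused_update",
    args=[x, state],
    inplace_indices=[1, -1],
    out=[
        Tensor.placeholder(state.shape, state.dtype),  # aliases `state`
        Tensor.placeholder(x.shape, x.dtype),          # new allocation
    ],
)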
20 changes: 10 additions & 10 deletions python/tvm/relax/op/base.py
@@ -198,13 +198,13 @@ def call_tir_inplace(
args : Expr
The input arguments.
-    input_indices : Union[int, List[int]]
+    inplace_indices : Union[int, List[int]]
         Specify which arguments should be used for in-place computations.
-        If `input_indices` is a single integer, it will be made into a singleton list.
-        Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
+        If `inplace_indices` is a single integer, it will be made into a singleton list.
+        Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
         will be an alias of `args[j]`.
-        If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
-        At least one member of `input_indices` must not be -1.
+        If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+        At least one member of `inplace_indices` must not be -1.
out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
The structure info of the call_tir_inplace output.
@@ -637,13 +637,13 @@ def call_inplace_packed(
args: Expr
The arguments for the PackedFunc.
-    input_indices : Union[int, List[int]]
+    inplace_indices : Union[int, List[int]]
         Specify which arguments should be used for in-place computations.
-        If `input_indices` is a single integer, it will be made into a singleton list.
-        Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
+        If `inplace_indices` is a single integer, it will be made into a singleton list.
+        Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
         will be an alias of `args[j]`.
-        If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
-        At least one member of `input_indices` must not be -1.
+        If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+        At least one member of `inplace_indices` must not be -1.
sinfo_args: Union[StructInfo, List[StructInfo]]
The list of structure info arguments (giving the structural info for the returned value).
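At the Relax level, the renamed parameter maps directly onto `call_tir_inplace`; a minimal sketch follows (assuming an active BlockBuilder scope, with `gvar` a GlobalVar bound to an in-place PrimFunc and `x` a Relax tensor expression; both names are assumptions for illustration):

from tvm import relax as rx

lv = rx.BlockBuilder.current().emit(
    rx.call_tir_inplace(
        gvar,                     # GlobalVar of the in-place PrimFunc (assumed)
        [x],                      # call arguments; args[0] is mutated
        inplace_indices=[0],      # output 0 aliases args[0]
        out_sinfo=x.struct_info,  # same struct info as the aliased input
    )
)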
105 changes: 105 additions & 0 deletions tests/python/relax/test_frontend_nn_op.py
@@ -619,6 +619,111 @@ def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset: R.Shape(["offse
tvm.ir.assert_structural_equal(irmodule, Expected)


def test_tensor_ir_inplace_op():
hidden_size = 4096
dtype = "float16"

@T.prim_func
def inplace_take(
var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
):
T.func_attr({"tir.noalias": T.bool(True)})
vocab_size = T.int64()
weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
seq_len = T.int64()
total_seq_len = T.int64()
pos = T.match_buffer(var_pos, (seq_len,), "int32")
embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
for ax0, ax1 in T.grid(seq_len, hidden_size):
with T.block("T_take"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(weight[pos[v0], v1], pos[v0])
                T.writes(embeddings[v0 + offset, v1])
embeddings[v0 + offset, v1] = weight[pos[v0], v1]

class Model(Module):
def test(
self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int
):
            tensor_expr_op_out = op.tensor_ir_inplace_op(
                inplace_take,
                "inplace_take",
                args=[embedding_table, input_ids, embedding_dst, offset],
                inplace_indices=[2],
                out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype),
            )
return tensor_expr_op_out

@I.ir_module
class Expected:
@T.prim_func
def inplace_take(
var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
):
T.func_attr({"tir.noalias": T.bool(True)})
vocab_size = T.int64()
weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
seq_len = T.int64()
total_seq_len = T.int64()
pos = T.match_buffer(var_pos, (seq_len,), "int32")
embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
for ax0, ax1 in T.grid(seq_len, hidden_size):
with T.block("T_take"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(weight[pos[v0], v1], pos[v0])
                    T.writes(embeddings[v0 + offset, v1])
embeddings[v0 + offset, v1] = weight[pos[v0], v1]

@R.function
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
lv: R.Tuple(R.Object) = (_io,)
gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv

@R.function
def test(
embedding_table: R.Tensor(("vocab_size", hidden_size), dtype),
input_ids: R.Tensor(("seq_len",), "int32"),
embedding_dst: R.Tensor(("total_seq_len", hidden_size), dtype),
offset: R.Shape(["offset_1"]),
packed_params: R.Tuple,
) -> R.Tensor(("total_seq_len", hidden_size), dtype):
total_seq_len = T.int64()
offset_1 = T.int64()
R.func_attr({"num_input": 4})
cls = Expected
with R.dataflow():
                lv1 = R.call_tir_inplace(
                    cls.inplace_take,
                    (embedding_table, input_ids, embedding_dst),
                    inplace_indices=[2],
                    out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
                    tir_vars=R.shape([offset_1]),
                )
gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
R.output(gv1)
return gv1

m = Model()
irmodule, _ = m.export_tvm(
spec={
"test": {
"embedding_table": spec.Tensor(["vocab_size", hidden_size], dtype),
"input_ids": spec.Tensor(["seq_len"], "int32"),
"embedding_dst": spec.Tensor(["total_seq_len", hidden_size], dtype),
"offset": int,
"$": {
"param_mode": "packed",
"effect_mode": "none",
},
},
},
debug=True,
)
tvm.ir.assert_structural_equal(irmodule, Expected)


def test_extern():
class Model(Module):
def test(self, q: Tensor, k: Tensor, v: Tensor):
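For reference, the `inplace_take` kernel exercised in the new test is roughly equivalent to the following NumPy update (an illustrative sketch, not part of the PR):

import numpy as np

def inplace_take_ref(weight, pos, embeddings, offset):
    """Gather rows of `weight` at indices `pos` and write them into
    `embeddings` starting at row `offset`, mutating `embeddings`."""
    seq_len = pos.shape[0]
    embeddings[offset : offset + seq_len, :] = weight[pos, :]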
