[Unity][TVMScript] Use explicit R.shape in TVMScript (apache#13979)
Since we introduced `arg_sinfo` in CallNode, the implicit shape constructor is
no longer widely used in TVMScript. This PR removes the implicit shape
construction, since it can cause confusion between shapes and tuples.
Hzfengsy authored Feb 14, 2023
1 parent 745222e commit 1bf0f1b
Showing 11 changed files with 93 additions and 43 deletions.
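
A minimal before/after sketch of the new convention (not part of the diff; assumes a TVM build containing this commit): a bare tuple of PrimExprs now parses as a relax.Tuple, so a shape must be written with the explicit R.shape constructor.

# Hedged sketch of the migration this commit asks for in user TVMScript.
from tvm.script import relax as R
from tvm.script import tir as T


@R.function
def make_shape(x: R.Tensor(("m", "n"), "float32")) -> R.Shape:
    m = T.var("int64")
    n = T.var("int64")
    # Before: `return (m * 2, n + 1)` implicitly built a ShapeExpr.
    # Now a bare tuple is a relax.Tuple, so construct the shape explicitly:
    return R.shape([m * 2, n + 1])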
16 changes: 3 additions & 13 deletions python/tvm/relax/utils.py
@@ -23,7 +23,7 @@
from ..runtime import String, convert_to_object
from ..tir import PrimExpr
from . import _ffi_api
from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm
from .expr import Expr, Function, PrimValue, StringImm
from .expr import Tuple as rx_Tuple


@@ -74,14 +74,12 @@ def convert_to_expr(value: Any) -> Expr:
1. Return the input itself if it's already a `relax.Expr`;
2. Return `relax.PrimValue` if the input is a `PrimExpr`;
3. Return `relax.StringImm` if the input is `tvm.String` or `str`;
4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/ int dtype;
5. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
4. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
Notes
-----
1. `tvm.tir.StringImm` is not allowed because of ambiguity,
which can be either `relax.StringImm` or `relax.PrimValue`.
2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr`
"""
if isinstance(value, int):
return PrimValue(tir.IntImm("int64", value))
@@ -102,16 +100,8 @@ def convert_to_expr(value: Any) -> Expr:
# Case 3
if isinstance(tvm_value, String):
return StringImm(value)
# Case 4 & 5
# Case 4
if isinstance(value, (tuple, list)):
# Note 2
if len(value) == 0:
return rx_Tuple([])
# Case 4
opt_prim_value = [convert_to_object(v) for v in value]
if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in opt_prim_value]):
return ShapeExpr(value)
# Case 5
# `convert_to_expr` ensures that all elements are `Expr` if no exception raises
return rx_Tuple([convert_to_expr(v) for v in value])
raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`")
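
A hedged example of the resulting convert_to_expr behavior (inferred from the hunk above): a list of integer PrimExprs is no longer special-cased into a ShapeExpr, so construct relax.ShapeExpr explicitly when a shape is intended.

from tvm import relax, tir
from tvm.relax.utils import convert_to_expr

m = tir.Var("m", "int64")
as_tuple = convert_to_expr([m, 8])  # now a relax.Tuple of PrimValue elements
as_shape = relax.ShapeExpr([m, 8])  # explicit construction when a shape is meant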
18 changes: 18 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -329,6 +329,23 @@ def tuple(*fields: Expr) -> Expr:
return relax.Tuple(fields) # type: ignore[attr-defined] # pylint: disable=no-member


############################### R.shape ################################


def shape(value: List[PrimExpr]) -> Expr:
"""Create a ShapeExpr.
Parameters
----------
value : List[PrimExpr]
The values of the shape.
Returns
-------
res : Expr
The resulting ShapeExpr.
"""
return relax.ShapeExpr(value) # pylint: disable=no-member # type: ignore


############################### PrimValue ##############################


@@ -407,6 +424,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"prim_value",
"print",
"reshape",
"shape",
"shape_of",
"str",
"tuple",
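
A usage sketch for the new builder entry point (module path taken from the file above; assumes this commit is present). In TVMScript it surfaces as R.shape([...]), as the updated tests below show.

from tvm import relax, tir
from tvm.script.ir_builder.relax.ir import shape

m = tir.Var("m", "int64")
s = shape([m, 8])  # builds a relax.ShapeExpr
assert isinstance(s, relax.ShapeExpr)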
22 changes: 17 additions & 5 deletions python/tvm/script/parser/relax/entry.py
@@ -22,6 +22,7 @@

from tvm.relax import (
Expr,
ShapeExpr,
FuncStructInfo,
Function,
ObjectStructInfo,
@@ -84,24 +85,31 @@ class TensorProxy(StructInfoProxy):

def __init__(
self,
shape: Optional[List[Union[PrimExpr, str]]] = None,
shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
dtype: Optional[str] = None,
ndim: int = -1,
) -> None:
self.shape = shape
if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr):
raise ValueError(
"Only ShapeExpr is allowed as shape expr, but got: "
f"{shape} with type: {type(shape)}"
)
self.dtype = dtype
self.ndim = ndim
super().__init__()

def get_symbolic_vars(self) -> Set[str]:
if self.shape is None:
if self.shape is None or isinstance(self.shape, Expr):
return {}
else:
return {s for s in self.shape if isinstance(s, str) and s.isidentifier()}

def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo:
if self.shape is None:
return TensorStructInfo(None, self.dtype, self.ndim)
elif isinstance(self.shape, ShapeExpr):
return TensorStructInfo(self.shape, self.dtype, self.ndim)
else:
if dict_globals is None and any([isinstance(s, str) for s in self.shape]):
raise ValueError(
@@ -113,7 +121,7 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Tenso


def Tensor(
shape: Optional[List[Union[PrimExpr, str]]] = None,
shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None,
dtype: Optional[str] = None,
ndim: int = -1,
) -> TensorProxy:
@@ -124,8 +132,12 @@ def Tensor(
dtype = shape
shape = None

if shape is not None and not isinstance(shape, (tuple, list)):
raise ValueError(f"shape must be a list or tuple, but got: {shape}")
if (
shape is not None
and not isinstance(shape, (tuple, list))
and not isinstance(shape, ShapeExpr)
):
raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}")
return TensorProxy(shape, dtype, ndim)


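
A small sketch of the relaxed R.Tensor signature (assuming this commit): a ShapeExpr is now accepted directly, in addition to a list/tuple of PrimExpr or str.

from tvm import relax
from tvm.script import relax as R

s = relax.ShapeExpr([4, 4])
proxy = R.Tensor(s, "float32")       # previously only a list/tuple was allowed
plain = R.Tensor((4, 4), "float32")  # list/tuple form still works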
2 changes: 1 addition & 1 deletion src/script/printer/relax/expr.cc
@@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
for (int i = 0, l = n->values.size(); i < l; ++i) {
values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d));
}
return TupleDoc(values_doc);
return Relax(d, "shape")->Call({ListDoc(values_doc)});
});

Optional<ExprDoc> SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) {
14 changes: 13 additions & 1 deletion src/script/printer/relax/struct_info.cc
@@ -89,7 +89,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Array<String> kwargs_keys;
Array<ExprDoc> kwargs_values;
if (n->shape.defined()) {
args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
// Need to dig into ShapeExpr to preserve the `R.shape` prefix
if (const auto* shape = n->shape.value().as<relax::ShapeExprNode>()) {
auto shape_expr = GetRef<relax::ShapeExpr>(shape);
ObjectPath shape_p = n_p->Attr("shape")->Attr("values");
Array<ExprDoc> shape_docs;
for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) {
shape_docs.push_back(
PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d));
}
args.push_back(TupleDoc(shape_docs));
} else {
args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
}
}
if (!n->IsUnknownDtype()) {
kwargs_keys.push_back("dtype");
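
A sketch of the printing asymmetry this hunk preserves (behavior inferred from the code above and the test updates below): the shape inside a TensorStructInfo annotation keeps the tuple form, while a standalone ShapeExpr value is printed with R.shape.

from tvm import relax

sinfo = relax.TensorStructInfo(relax.ShapeExpr([1, 2, 3]), "float32")
# printed roughly as: R.Tensor((1, 2, 3), dtype="float32")
standalone = relax.ShapeExpr([1, 2, 3])
# printed roughly as: R.shape([1, 2, 3])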
2 changes: 1 addition & 1 deletion tests/python/relax/test_backend_transform_shape_lower.py
@@ -167,7 +167,7 @@ def main(
n = T.Var("n", "int64")
k = T.Var("k", "int64")
z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
return (k + 1, m, 2)
return R.shape([k + 1, m, 2])

# slot assignment:
# 0: n, 1: m, 2:k, 3: k+1
2 changes: 1 addition & 1 deletion tests/python/relax/test_transform.py
@@ -109,7 +109,7 @@ class TestVMBuiltinLower:
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
m, n = T.var("int64"), T.var("int64")
alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32")
alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32")
_ = R.call_packed(
"test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
)
36 changes: 26 additions & 10 deletions tests/python/relax/test_tvmscript_parser.py
@@ -22,10 +22,9 @@
import tvm.script
import tvm.testing
from tvm import IRModule, relax, tir, topi
from tvm.relax import DynTensorType
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T


def _check(
@@ -202,6 +201,23 @@ def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):
_check(foo, bb.get()["foo"])


def test_relax_base_op():
@R.function
def foo(x: R.Tensor((4, 4), "float32")):
alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32")
shape = R.shape_of(alloc)
return shape

x = relax.Var("x", R.Tensor((4, 4), "float32"))
bb = relax.BlockBuilder()
with bb.function("foo", (x,)):
alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0))
shape = bb.emit(relax.op.shape_of(alloc))
bb.emit_func_output(shape)
# todo(yongwww): comment this check because 0 was changed to R.prim_value(0) in the printed IR
# _check(foo, bb.get()["foo"])


def test_symbolic_shape():
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
@@ -274,7 +290,7 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")):
y0 = R.match_cast(y, R.Tensor([n], "float32"))
gv = y0
R.output(gv)
return (x0, (m, n * 2))
return (x0, R.shape([m, n * 2]))

x = relax.Var("x", R.Tensor("float32"))
y = relax.Var("y", R.Tensor("float32"))
@@ -314,7 +330,7 @@ def test_tuple_return_2():
def foo(x: R.Tensor("float32", ndim=2)):
n, m = T.var("int64"), T.var("int64")
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
return (x0, (n + 1, m, 1))
return (x0, R.shape([n + 1, m, 1]))

x = relax.Var("x", R.Tensor("float32", ndim=2))
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
@@ -332,7 +348,7 @@ def foo(x: R.Tensor("float32", ndim=2)):
n, m = T.var("int64"), T.var("int64")
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
t0 = (x, x0)
t1 = (x, (n, m), t0)
t1 = (x, R.shape([n, m]), t0)
return t1

x = relax.Var("x", R.Tensor("float32", ndim=2))
@@ -965,9 +981,9 @@ def test_vm_ops():
def foo(x: R.Tensor(("m", "n"), dtype="float32")):
m = T.var("int64")
n = T.var("int64")
storage = R.vm.alloc_storage((4 * m * n,), dtype="float32", runtime_device_index=0)
alloc = R.vm.alloc_tensor(storage, (m, n), offset=0, dtype="float32")
tensor = R.builtin.alloc_tensor((m, n), dtype="float32", runtime_device_index=0)
storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", runtime_device_index=0)
alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, dtype="float32")
tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0)
_ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n)))
gv = tensor
return alloc, gv
4 changes: 2 additions & 2 deletions tests/python/relax/test_tvmscript_printer_relax.py
@@ -292,7 +292,7 @@ def test_tuple_get_item():

def test_shape_expr():
obj = relax.ShapeExpr([1, 2, 3])
_assert_print(obj, "(1, 2, 3)")
_assert_print(obj, "R.shape([1, 2, 3])")


def test_call():
@@ -304,7 +304,7 @@ def test_call():
"""
x = T.Var("x", "int64")
a: R.Tensor((1, x, 3), dtype="float32")
R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=(x,))
R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x]))
""",
)

6 changes: 3 additions & 3 deletions tests/python/relax/test_vm_build.py
@@ -88,7 +88,7 @@ class TestVMCompileStage2:
def foo(x: R.Tensor(dtype="float32")) -> R.Shape:
n, m = T.var("int64"), T.var("int64")
_ = R.match_cast(x, R.Tensor((n, m), "float32"))
return (n * 2, m * 3)
return R.shape([n * 2, m * 3])

mod = TestVMCompileStage2
target = tvm.target.Target("llvm", host="llvm")
@@ -511,9 +511,9 @@ class TestMemoryAllocStorageTensor:
@R.function
def main(x: R.Tensor((2, 3), dtype="float32")):
storage = R.memory.alloc_storage(
(24,), virtual_device_index=0, storage_scope="global", dtype="float32"
R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
)
y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32")
y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]), dtype="float32")
_ = copy(x, y)
return y

14 changes: 8 additions & 6 deletions tests/python/relax/test_vm_codegen_only.py
@@ -18,13 +18,15 @@
Restrictions: all shape lowered, explicit allocation.
"""
import tvm
import pytest
import numpy as np
from tvm import relax, TVMError
from tvm.script import relax as R, tir as T
import pytest
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode
from tvm.relax.testing.vm import check_saved_func
from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode
from tvm.script import relax as R
from tvm.script import tir as T

EXEC_MODE = ["bytecode"]

@@ -312,7 +314,7 @@ class TestVMBuiltinReshape:
def main(x: R.Tensor((3, 4), "float32")):
R.func_attr({"global_symbol": "main"})
y = R.call_packed(
"vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32")
"vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32")
)
return y

