[Relax] Handle binary operations between Tensor and PrimValue #16827

Merged
Changes from 5 commits
130 changes: 82 additions & 48 deletions python/tvm/relax/utils.py
@@ -14,13 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name,too-many-locals

"""Utility functions for Relax"""

import functools
import inspect
import itertools
import string

from typing import Tuple as typing_Tuple
from typing import Any, Callable, List, Dict, Optional, TypeVar

import tvm
from .. import tir
from ..tir import PrimExpr
from ..runtime import String, convert_to_object
@@ -302,9 +309,23 @@ def gen_call_tir_inputs(
out_sinfo, and tir_vars.
"""

def _convert_te_arg(
te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr]
) -> typing_Tuple[Any, List[te_Tensor]]:
tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}

call_tir_args = []
create_primfunc_args = []
# extra list of tir expression arguments
# that are not covered by Tensor
extra_tir_args_list = []

def _copy_undefined_var(expr: tir.PrimExpr):
def _visit_expr(e: tir.PrimExpr):
if isinstance(e, tir.Var) and e not in tir_var_map:
new_var = tir.Var(e.name, e.dtype)
tir_var_map[e] = new_var

tir.stmt_functor.post_order_visit(expr, _visit_expr)

def _convert_te_arg(te_args: Any) -> Any:
"""Helper function used to convert Relax expressions to TE tensor.

In the common case, the type of te_args is a Relax expression and is converted
@@ -335,23 +356,8 @@ def _convert_te_arg(
A tuple of the converted te_args, and a list of te tensors for each converted
Relax expression
"""
te_args_list = []
# extra list of tir expression arguments
# that are not covered by Tensor
extra_tir_args_list = []

def _copy_undefined_var(expr: tir.PrimExpr):
def _visit_expr(e: tir.PrimExpr):
if isinstance(e, tir.Var) and e not in tir_var_map:
new_var = tir.Var(e.name, e.dtype)
tir_var_map[e] = new_var

tir.stmt_functor.post_order_visit(expr, _visit_expr)

n_tensor = 0

def _convert_te_arg_helper(arg):
nonlocal n_tensor
if isinstance(arg, Expr): # type: ignore
if isinstance(arg.struct_info, TensorStructInfo):
assert isinstance(
@@ -360,21 +366,46 @@ def _convert_te_arg_helper(arg):
for shape_value in arg.struct_info.shape.values:
_copy_undefined_var(shape_value)

name = chr(ord("A") + n_tensor) if n_tensor < 26 else f"input{n_tensor}"
arg = te_tensor(arg, tir_var_map, name)
n_tensor += 1
te_args_list.append(arg)
return arg
n_args = len(create_primfunc_args)
if isinstance(arg, tvm.relax.Var):
name = arg.name_hint
elif n_args < len(string.ascii_uppercase):
name = string.ascii_uppercase[n_args]
else:
name = f"tensor_input_{n_args}"

te_arg = te_tensor(arg, tir_var_map, name)

call_tir_args.append(arg)
create_primfunc_args.append(te_arg)

return te_arg

if isinstance(arg.struct_info, ShapeStructInfo):
assert isinstance(
arg, ShapeExpr
), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
return [_convert_te_arg_helper(val) for val in arg.values]
if (
isinstance(arg.struct_info, PrimStructInfo)
and arg.struct_info.value is not None
):
return _convert_te_arg_helper(arg.struct_info.value)

if isinstance(arg.struct_info, PrimStructInfo):
if arg.struct_info.value is None:
n_args = len(create_primfunc_args)
if isinstance(arg, tvm.relax.Var):
name = arg.name_hint
elif n_args < len(string.ascii_lowercase):
name = string.ascii_lowercase[n_args]
else:
name = f"scalar_input_{n_args}"

tir_param = tir.Var(name, arg.struct_info.dtype)

call_tir_args.append(arg)
create_primfunc_args.append(tir_param)

return tir_param
else:
return _convert_te_arg_helper(arg.struct_info.value)

elif isinstance(arg, (list, Array)):
return [_convert_te_arg_helper(x) for x in arg]
elif isinstance(arg, tuple):
@@ -395,28 +426,36 @@ def _convert_te_arg_helper(arg):
raise TypeError("not supported type in emit_te: {}".format(type(arg)))

new_arg = _convert_te_arg_helper(te_args)
return new_arg, te_args_list, extra_tir_args_list
return new_arg
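For orientation, a minimal usage sketch (not part of this diff) of the path the new PrimStructInfo branch enables: a Relax Var with a value-free PrimStructInfo passed through BlockBuilder.emit_te should now be lowered to a scalar tir.Var parameter of the generated PrimFunc instead of being rejected. The compute function and variable names below are illustrative assumptions.

```python
from tvm import relax, te

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo([16], "float32"))
s = relax.Var("s", relax.PrimStructInfo("float32"))  # scalar operand with no known value

with bb.function("main", [x, s]):
    # After conversion, the scalar reaches the TE function as a tir.Var.
    def scale(data, factor):
        return te.compute(data.shape, lambda i: data[i] * factor, name="scale")

    out = bb.emit_te(scale, x, s)
    bb.emit_func_output(out)

print(bb.get())  # the generated PrimFunc should take x, the scalar s, and the output buffer
```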

def _get_unbound_tir_vars(
args: List[te_Tensor], extra_tir_args: List[PrimExpr]
) -> List[tir.Var]:
"""get unbound TIR vars (i.e TIR vars used in the shape but is not
itself a dimension of a shape)"""

bound_vars = set()
used_vars = set()

def _populate_bound_vars(expr):
if isinstance(expr, te_Tensor):
for dim in expr.shape:
_populate_bound_vars(dim)
elif isinstance(expr, tir.Var):
bound_vars.add(expr)

def _populate_used_vars(expr):
if isinstance(expr, tir.Var):
used_vars.add(expr)
if isinstance(expr, te_Tensor):
for dim in expr.shape:
_populate_used_vars(dim)
elif isinstance(expr, tir.PrimExpr):
used_vars.update(tir.analysis.undefined_vars(expr))

for val in extra_tir_args:
tir.stmt_functor.post_order_visit(val, _populate_used_vars)
for arg in itertools.chain(args, extra_tir_args):
_populate_used_vars(arg)

for x in args:
for s in x.shape:
tir.stmt_functor.post_order_visit(s, _populate_used_vars)
if isinstance(s, tir.Var):
bound_vars.add(s)
for arg in args:
_populate_bound_vars(arg)

diff = used_vars - bound_vars
return list(diff)
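To make the bound/used distinction concrete, here is a small self-contained sketch (illustrative, not part of the diff): a TIR variable that only appears inside a compound dimension expression such as n + 1 is used but never bound, so it is reported as unbound and becomes an extra PrimFunc parameter.

```python
from tvm import te, tir

n = tir.Var("n", "int64")
A = te.placeholder((n + 1,), name="A", dtype="float32")  # the dimension is the expression n + 1

bound_vars = set()  # vars that are themselves a dimension
used_vars = set()   # vars appearing anywhere inside a dimension expression
for dim in A.shape:
    used_vars.update(tir.analysis.undefined_vars(dim))
    if isinstance(dim, tir.Var):
        bound_vars.add(dim)

print(used_vars - bound_vars)  # {n}: unbound, so it must be passed to the PrimFunc explicitly
```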
@@ -448,30 +487,25 @@ def _shape_with_old_tir_var(

primfunc_attrs = kwargs.pop("primfunc_attrs", None)

tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
new_args, te_arg_list, tir_arg_list = _convert_te_arg(args, tir_var_map)
new_kwargs, te_kwarg_list, tir_kwarg_list = _convert_te_arg(kwargs, tir_var_map)

te_args = te_arg_list + te_kwarg_list
te_args = _convert_te_arg(args)
te_kwargs = _convert_te_arg(kwargs)

te_out = func(*new_args, **new_kwargs)
te_out = func(*te_args, **te_kwargs)
assert isinstance(te_out, te_Tensor) or (
isinstance(te_out, (tuple, list, Array)) and all(isinstance(t, te_Tensor) for t in te_out)
), "only support te.tensor or tuple/list/Array of te.tensor as function output"

outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list + tir_kwarg_list)
unbound_tir_vars = _get_unbound_tir_vars([*create_primfunc_args, *outs], extra_tir_args_list)

inputs = [*te_args] + outs + unbound_tir_vars
inputs = [*create_primfunc_args] + outs + unbound_tir_vars
tir_func = create_prim_func(inputs, "int64")

if primfunc_attrs:
tir_func = tir_func.with_attrs(primfunc_attrs)

tir_func = tir_func.without_attr("global_symbol")

call_tir_args = [x.op.value for x in te_args]

# Invert the TIR variable mapping, to convert the output shape back
# with old set of variables.
tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
81 changes: 60 additions & 21 deletions src/relax/op/op_common.h
@@ -239,52 +239,91 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
const Map<String, Array<String>>& desired_layouts,
const VarLayoutMap& var_layout_map);

/*!
* \brief Get the element dtype from StructInfo
*
* \param sinfo The StructInfo to inspect
* \return The inferred element dtype.
* \throw Throw exception if the StructInfo doesn't have an element type.
*/
inline DataType GetElementDType(const StructInfo& sinfo) {
if (const auto* prim = sinfo.as<PrimStructInfoNode>()) {
return prim->dtype;
} else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
return tensor->dtype;
} else {
LOG(FATAL) << "TypeError: "
Member:
Originally our error message would ask for TensorStructInfo. In this particular case, would this error message be less informative than before? Given this is a global change across all binary ops, it would be good to cross-confirm the usages here and make the error more informative.

Contributor Author:
Good point, and this no longer tells the user which operation it was. Updated.

<< "Only PrimStructInfo and TensorStructInfo "
<< "have an associated data type. "
<< "Cannot determine element type of " << sinfo;
}
}
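A quick Python-side illustration of why both branches above can report a dtype (constructors assumed from the public relax API, not shown in this diff): a tensor operand and a scalar (PrimValue-style) operand each carry an element dtype.

```python
from tvm import relax

t = relax.TensorStructInfo([8], "float16")  # element dtype of a tensor operand
p = relax.PrimStructInfo("float16")         # dtype of a scalar operand
print(t.dtype, p.dtype)                     # expected: float16 float16
```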

/*!
* \brief Infer the output datatype for binary arithmetic operators.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
* \param x1_sinfo The struct info of the first operand
* \param x2_sinfo The struct info of the second operand
* \param lhs_sinfo The struct info of the first operand
* \param rhs_sinfo The struct info of the second operand
* \return The inferred output dtype.
* \throw Throw exception if the dtypes of the two operands don't match
*/
inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx,
const TensorStructInfo& x1_sinfo,
const TensorStructInfo& x2_sinfo) {
if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) {
const StructInfo& lhs_sinfo,
const StructInfo& rhs_sinfo) {
auto lhs_dtype = GetElementDType(lhs_sinfo);
auto rhs_dtype = GetElementDType(rhs_sinfo);
if (lhs_dtype.is_void() || rhs_dtype.is_void()) {
return DataType::Void();
} else if (x1_sinfo->dtype != x2_sinfo->dtype) {
} else if (lhs_dtype != rhs_dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype
<< " must be equal for binary operators");
<< "TypeErorr: "
<< "Binary operators must have the same datatype for both operands. "
<< "However, " << call << " uses datatype " << lhs_dtype
<< " on the LHS (StructInfo of " << lhs_sinfo << "), and datatype "
<< rhs_dtype << " on the RHS (StructInfo of " << rhs_sinfo << ").");
}
return x1_sinfo->dtype;
return lhs_dtype;
}
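An end-to-end sketch of the rule this enforces (usage assumed, not shown in the diff): a binary op between a float32 Tensor and a float32 PrimValue should infer a float32 Tensor result, while mismatched element dtypes hit the error above.

```python
import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo([4, 4], "float32"))
s = relax.PrimValue(tvm.tir.FloatImm("float32", 2.0))  # scalar operand

with bb.function("main", [x]):
    y = bb.emit(relax.op.multiply(x, s))  # Tensor * PrimValue
    bb.emit_func_output(y)

print(y.struct_info)  # expected: R.Tensor((4, 4), dtype="float32")
```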

/*!
* \brief Infer the output virtual device for binary arithmetic operators.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
* \param x1_sinfo The struct info of the first operand
* \param x2_sinfo The struct info of the second operand
* \param lhs_sinfo The struct info of the first operand
* \param rhs_sinfo The struct info of the second operand
* \return The inferred output vdevice.
* \throw Throw exception if the vdevices of the two operands don't match
*/
inline Optional<VDevice> InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx,
const TensorStructInfo& x1_sinfo,
const TensorStructInfo& x2_sinfo) {
if (!x1_sinfo->vdevice.defined() || !x1_sinfo->vdevice.value()->target.defined()) {
return x2_sinfo->vdevice;
const StructInfo& lhs_sinfo,
const StructInfo& rhs_sinfo) {
auto get_vdevice = [&](const StructInfo& sinfo) -> Optional<VDevice> {
if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
return tensor->vdevice;
} else {
return NullOpt;
}
};

auto lhs_vdevice = get_vdevice(lhs_sinfo);
auto rhs_vdevice = get_vdevice(rhs_sinfo);

if (!lhs_vdevice.defined() || !lhs_vdevice.value()->target.defined()) {
return rhs_vdevice;
}
if (!x2_sinfo->vdevice.defined() || !x2_sinfo->vdevice.value()->target.defined()) {
return x1_sinfo->vdevice;
if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) {
return lhs_vdevice;
}
if (x1_sinfo->vdevice.value() != x2_sinfo->vdevice.value()) {
if (lhs_vdevice.value() != rhs_vdevice.value()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "VDevice " << x1_sinfo->vdevice.value() << " and "
<< x2_sinfo->vdevice.value() << " must be equal for binary operators");
<< "TypeErorr: "
<< "Binary operators with Tensor arguments "
<< "must have the same VDevice for both operands. "
<< "However, " << call << " has a LHS on VDevice " << lhs_vdevice
<< " and a RHS on VDevice " << rhs_vdevice);
}
return x1_sinfo->vdevice;
return lhs_vdevice;
}
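A rough sketch of the vdevice rule (the VDevice constructor and struct-info argument order are assumptions about the public API, not shown here): a non-Tensor operand such as a PrimValue contributes no vdevice, so the Tensor side's vdevice carries through to the result.

```python
import tvm
from tvm import relax

vdev = tvm.ir.VDevice("llvm")                     # assumed constructor form
t = relax.TensorStructInfo([4], "float32", vdev)  # tensor pinned to a vdevice (assumed arg order)
p = relax.PrimStructInfo("float32")               # scalar operand: no vdevice attribute at all

lhs_vdev = getattr(t, "vdevice", None)
rhs_vdev = getattr(p, "vdevice", None)
# Mirrors the rule above: prefer whichever side actually defines a vdevice.
print(lhs_vdev if rhs_vdev is None else rhs_vdev)
```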

/*!