diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ee7a5d6b329a..8da8a5b11262 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4565,6 +4565,23 @@ def _impl_v1(cls, inputs, attr, params): "Attempting to unify ranks but this may produce incorrect results." ) warnings.warn(warning_msg) + # Skip constant If node to avoid irrational broadcast + if isinstance(inputs[0], tvm.relay.expr.Constant): + predicate = inputs[0].data.asnumpy()[0] + node_name = attr["tvm_custom"]["name"] + warn_msg_begin = f"Predicate of If node {node_name} is always " + if predicate == np.bool_(True): + warnings.warn( + warn_msg_begin + + "true so only then branch would be executed. Removing else branch. " + ) + else_expr = then_expr + elif predicate == np.bool_(False): + warnings.warn( + warn_msg_begin + + "false so only else branch would be executed. Removing then branch. " + ) + then_expr = else_expr if len(then_shape) < len(else_shape): then_expr = _op.broadcast_to_like(then_expr, else_expr) else: @@ -6529,6 +6546,7 @@ def _impl_v11(cls, inputs, attr, params): # compatible operators that do NOT require any conversion. _identity_list = [] + # _convert_map defines maps of name to converter functor(callable) # for 1 to 1 mapping, use Renamer if nothing but name is different # use AttrCvt if attributes need to be converted