Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix][ONNX] Skip constant If node generated by PyTorch #17383

Merged
merged 10 commits into from
Sep 22, 2024
18 changes: 18 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4565,6 +4565,23 @@ def _impl_v1(cls, inputs, attr, params):
"Attempting to unify ranks but this may produce incorrect results."
)
warnings.warn(warning_msg)
# Skip constant If node to avoid irrational broadcast
if isinstance(inputs[0], tvm.relay.expr.Constant):
predicate = inputs[0].data.asnumpy()[0]
node_name = attr["tvm_custom"]["name"]
warn_msg_begin = f"Predicate of If node {node_name} is always "
if predicate == np.bool_(True):
warnings.warn(
warn_msg_begin
+ "true so only then branch would be executed. Removing else branch. "
)
else_expr = then_expr
elif predicate == np.bool_(False):
warnings.warn(
warn_msg_begin
+ "false so only else branch would be executed. Removing then branch. "
)
then_expr = else_expr
if len(then_shape) < len(else_shape):
then_expr = _op.broadcast_to_like(then_expr, else_expr)
else:
Expand Down Expand Up @@ -6529,6 +6546,7 @@ def _impl_v11(cls, inputs, attr, params):
# compatible operators that do NOT require any conversion.
_identity_list = []


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
Expand Down
Loading