Skip to content

Commit

Permalink
[Bugfix][ONNX] Skip constant If node generated by PyTorch (apache#17383)
Browse files Browse the repository at this point in the history
* [Bugfix][VTA] Fix FSIM compile error on macOS.

VTA FSIM could not be built on macOS, for it leverages malloc.h and
memalign, yet both have been deprecated and are not provided by
macOS. This issue was captured in apache#13173.

This commit stops including malloc.h in VTA Runtime as stdlib.h has
provided functions we need.

This commit uses posix_memalign instead of memalign. It is a portable standard function.

* Fix format.

* [Bugfix][ONNX] Skip constant If node generated by PyTorch
This commit adds a check for If nodes for ONNX frontend of Relay
to skip the broadcast if the predicate is constant.
Sometimes PyTorch to ONNX inserts silly if nodes that produce dynamic
ranks, and ONNX frontend of TVM would broadcast the lower dimensions
between branches, which is irrational for some cases, e.g. 5×5×3×4 to
5×5×3×4×1. The predicate of silly if might be constant and reasonable
to skip to avoid the broadcast problem.
This issue was captured in apache#16898.

* Fix format.
  • Loading branch information
xhmelon authored Sep 22, 2024
1 parent 425e15b commit 72d542e
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4565,6 +4565,23 @@ def _impl_v1(cls, inputs, attr, params):
"Attempting to unify ranks but this may produce incorrect results."
)
warnings.warn(warning_msg)
# Skip constant If node to avoid irrational broadcast
if isinstance(inputs[0], tvm.relay.expr.Constant):
predicate = inputs[0].data.asnumpy()[0]
node_name = attr["tvm_custom"]["name"]
warn_msg_begin = f"Predicate of If node {node_name} is always "
if predicate == np.bool_(True):
warnings.warn(
warn_msg_begin
+ "true so only then branch would be executed. Removing else branch. "
)
else_expr = then_expr
elif predicate == np.bool_(False):
warnings.warn(
warn_msg_begin
+ "false so only else branch would be executed. Removing then branch. "
)
then_expr = else_expr
if len(then_shape) < len(else_shape):
then_expr = _op.broadcast_to_like(then_expr, else_expr)
else:
Expand Down Expand Up @@ -6529,6 +6546,7 @@ def _impl_v11(cls, inputs, attr, params):
# compatible operators that do NOT require any conversion.
_identity_list = []


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
Expand Down

0 comments on commit 72d542e

Please sign in to comment.