Skip to content

Commit

Permalink
Delete pool index output (#97)
Browse files Browse the repository at this point in the history
* delete  maxpool index output

* refine

* format

* fix comment

* refine
  • Loading branch information
BBuf authored Sep 26, 2022
1 parent bb746f6 commit 48fa7c2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
2 changes: 2 additions & 0 deletions oneflow_onnx/oneflow2onnx/handlers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ def _Convert(cls, ctx, node, **kwargs):
else:
pads = node.attrs.get("padding_before", [0, 0]) + node.attrs.get("padding_after", [0, 0])
node.attrs["pads"] = pads
if len(node.output_tensor_names) > 1 and len(ctx.FindOutputConsumers(node.output_tensor_names[1])) == 0:
ctx.RemoveOutput(node, node.output_tensor_names[1])


@flow_op(["pad"], onnx_op="Pad")
Expand Down
15 changes: 15 additions & 0 deletions oneflow_onnx/onnx_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,21 @@ def RemoveInput(node, to_be_removed):
# don't remove output from parent since others might depend on it
return True

@staticmethod
def RemoveOutput(node, to_be_removed):
"""Remove output from Node.
Args:
node: the node we expect the output on
to_be_removed: the node name we want to remove
"""
assert isinstance(node, Node) and isinstance(to_be_removed, six.text_type)
for i, name in enumerate(node.output_tensor_names):
if name == to_be_removed:
del node.output_tensor_names[i]
break
# don't remove output from parent since others might depend on it
return True

def InsertNewNodeOnInput(self, node, op_type, input_name, name=None, domain=None, **kwargs):
"""Create and insert a new node into the graph.
Args:
Expand Down

0 comments on commit 48fa7c2

Please sign in to comment.