Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_flow_node_by_name #80

Merged
merged 1 commit into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions oneflow_onnx/oneflow2onnx/flow2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def FlowToOnnxNaive(graph, shape_override):
op_cnt = collections.Counter()
attr_cnt = collections.Counter()
onnx_nodes = []
flow_nodes = {}

def is_user_op(node):
return node.WhichOneof("op_type") == "user_conf"
Expand Down Expand Up @@ -176,11 +177,12 @@ def get_outputs(node):
op_type, input_names, output_names, name=node.name, **attr
)
onnx_nodes.append(onnx_node)
flow_nodes[node.name] = node
except Exception as ex:
logger.error("pass1 convert failed for %s, ex=%s", node, ex)
raise

return onnx_nodes, op_cnt, attr_cnt, dtypes, shape_override
return onnx_nodes, flow_nodes, op_cnt, attr_cnt, dtypes, shape_override


def FlowOnnxMapping(g, ops_mapping):
Expand Down Expand Up @@ -309,11 +311,11 @@ def ProcessFlowGraph(
if shape_override is None:
shape_override = {}

(onnx_nodes, op_cnt, attr_cnt, dtypes, output_shapes,) = FlowToOnnxNaive(
(onnx_nodes, flow_nodes, op_cnt, attr_cnt, dtypes, output_shapes,) = FlowToOnnxNaive(
flow_graph, shape_override
)

g = Graph(onnx_nodes, model_save_dir, output_shapes, dtypes, opset, extra_opset,)
g = Graph(onnx_nodes, flow_nodes, model_save_dir, output_shapes, dtypes, opset, extra_opset,)

# create ops mapping for the desired opsets
ops_mapping = handler.flow_op.CreateMapping(g.opset, g.extra_opset)
Expand Down
7 changes: 7 additions & 0 deletions oneflow_onnx/onnx_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ class Graph(object):
def __init__(
self,
nodes,
flow_nodes,
model_save_dir,
output_shapes=None,
dtypes=None,
Expand All @@ -408,6 +409,7 @@ def __init__(
self._param_dict = oneflow.load(self._model_save_dir)
self._output_shapes = output_shapes
self._opset = util.FindOpset(opset)
self._flow_nodes_by_name = flow_nodes

if extra_opset is not None:
util.MakeSure(isinstance(extra_opset, list), "invalid extra_opset")
Expand Down Expand Up @@ -796,6 +798,11 @@ def get_node_by_name(self, name):
ret = self._nodes_by_name.get(name)
return ret

def get_flow_node_by_name(self, name):
"""Get OneFlow op node by name."""
ret = self._flow_nodes_by_name.get(name)
return ret

def set_node_by_name(self, node):
"""Set node by name."""
self._nodes_by_name[node.name] = node
Expand Down