From 673a41a118863a36b7e79383089cc1ca2606b851 Mon Sep 17 00:00:00 2001 From: Ningxin Date: Tue, 2 Jun 2020 12:46:16 +0000 Subject: [PATCH 01/11] find new bugs --- src/sdk/pynni/nni/_graph_utils.py | 39 ++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py index b17aa02994..7370367613 100644 --- a/src/sdk/pynni/nni/_graph_utils.py +++ b/src/sdk/pynni/nni/_graph_utils.py @@ -178,6 +178,7 @@ def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=Non def add_nodes(self, node_cpps): for node_cpp in node_cpps: nodepy = NodePyOP(node_cpp) + # TODO bug here, wrong way to get the name nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '') self.nodes.append(nodepy) @@ -411,19 +412,35 @@ def _build_graph(self): if module_name in self.leaf_modules: module_to_nodes[module_name].append(node) else: + print(node) func_to_nodes[node.scopeName()].append(node) - + # print(module_to_nodes.keys()) + print('######') + # print(func_to_nodes.keys()) # build node group for module for module_name, node_cpps in module_to_nodes.items(): - node_group = self._build_module_node_group( - module_name, module_to_type[module_name], node_cpps, input_to_node, output_to_node - ) - _logger.debug('node_group: %s', node_group) - nodes_py.nodes_op.append(node_group) - + print(module_name, len(node_cpps)) + print(node_cpps) + if len(node_cpps) == 1: + node_group = self._build_module_node_group( + unique_name, module_to_type[module_name], node_cpps, input_to_node, output_to_node + ) + nodes_py.nodes_op.append(node_group) + continue + for useid, node in enumerate(node_cpps): + unique_name = module_name + '.%d' % useid + node_group = self._build_module_node_group( + unique_name, module_to_type[module_name], [node], input_to_node, output_to_node + ) + _logger.debug('node_group: %s', node_group) + nodes_py.nodes_op.append(node_group) + print('$$$$$$$$$$$$') # each scope_name may have multiple funcs, we split them and create node for each of them # build node group for torch.nn.functional - for _, nodes in func_to_nodes.items(): + for tname, nodes in func_to_nodes.items(): + print('###', tname) + print(len(nodes)) + used = set() # extract non prim:: nodes non_prim_nodes = list() for node in nodes: @@ -432,13 +449,19 @@ def _build_graph(self): # for each non prim node, expand it for node in non_prim_nodes: node_group = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node) + used.update(node_group.node_cpps) nodes_py.nodes_op.append(node_group) # get shape infor for view (aten::view) func if node_group.op_type in ['aten::view', 'aten::flatten']: node_group.auxiliary = self._extract_shape_info(node) + print(len(set(nodes)-used)) + print(set(nodes)-used) + for node in graph.outputs(): # Create sink nodes for output ops node_py = NodePyIO(node, 'output') nodes_py.append(node_py) + + self.nodes_py = nodes_py # build index From b81e9c7cd7af1e919cacb5b2bfb555c1212681ed Mon Sep 17 00:00:00 2001 From: Ningxin Date: Wed, 3 Jun 2020 05:05:45 +0000 Subject: [PATCH 02/11] Fix the bug for issue2485. 
Signed-off-by: Ningxin --- src/sdk/pynni/nni/_graph_utils.py | 174 ++++++++++++++++-------------- 1 file changed, 92 insertions(+), 82 deletions(-) diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py index 7370367613..07f956172a 100644 --- a/src/sdk/pynni/nni/_graph_utils.py +++ b/src/sdk/pynni/nni/_graph_utils.py @@ -8,19 +8,25 @@ from collections import defaultdict import torch from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy - +from tensorboard.compat.proto.config_pb2 import RunMetadata +from tensorboard.compat.proto.graph_pb2 import GraphDef +from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats +from tensorboard.compat.proto.versions_pb2 import VersionDef CLASSTYPE_KIND = 'ClassType' GETATTR_KIND = 'prim::GetAttr' _logger = logging.getLogger(__name__) + def build_module_graph(model, dummy_input): return TorchModuleGraph(model, dummy_input) + def build_graph(model, dummy_input, verbose=False): g = TorchProtoGraph(model, dummy_input, verbose) return g.graph_def, g.stepstats + def parse_traced_name(module_name): prefix = 'TracedModule[' suffix = ']' @@ -28,11 +34,13 @@ def parse_traced_name(module_name): module_name = module_name[len(prefix):-len(suffix)] return module_name + class TorchGraph: """ This class is to extract pytorch model topology graph by tracing """ - def __init__(self, model, dummy_input): + + def __init__(self, model=None, dummy_input=None, traced_model=None): """ Parameters ---------- @@ -42,36 +50,42 @@ def __init__(self, model, dummy_input): The dummy input for ```jit.trace```, users should put it on right device before pass in """ assert torch.__version__ >= '1.3.1' - - self.bound_model = model - self._trace(model, dummy_input) - + # check if the input is legal + if traced_model is not None: + assert isinstance(traced_model, torch.jit.TopLevelTracedModule) + self.trace = traced_model + # it's ok if the graph is already unpacked + torch._C._jit_pass_inline(self.trace.graph) + elif model is not None and dummy_input is not None: + self.bound_model = model + self._trace(model, dummy_input) + else: + raise Exception('Please provide model & dummy_input or the traced_model as inputs') def _trace(self, model, dummy_input): with torch.onnx.set_training(model, False): self.trace = torch.jit.trace(model, dummy_input) torch._C._jit_pass_inline(self.trace.graph) + class TorchProtoGraph(TorchGraph): """ - Generates model graph for pytorch models in protobuf, this implementation is borrowed from pytorch v1.4.0, - and fixed following issues: + Generates model graph for pytorch models in protobuf, this implementation + is borrowed from pytorch v1.4.0, and fixed following issues: https://github.com/pytorch/pytorch/issues/33691 https://github.com/pytorch/pytorch/issues/33670 """ + def __init__(self, model, dummy_input, verbose=False): super().__init__(model, dummy_input) - from tensorboard.compat.proto.config_pb2 import RunMetadata - from tensorboard.compat.proto.graph_pb2 import GraphDef - from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats - from tensorboard.compat.proto.versions_pb2 import VersionDef - list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input) if verbose: print(self.trace.graph) - self.stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])) - self.graph_def = GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)) + self.stepstats = RunMetadata(step_stats=StepStats( + 
dev_stats=[DeviceStepStats(device="/device:CPU:0")])) + self.graph_def = GraphDef( + node=list_of_nodes, versions=VersionDef(producer=22)) def parse(self, graph, trace, args=None, omit_useless_nodes=True): """This method parses an optimized PyTorch model graph and produces @@ -94,16 +108,19 @@ def parse(self, graph, trace, args=None, omit_useless_nodes=True): nodes_py.append(NodePyIO(node, 'input')) attr_to_scope = dict() - node_to_name = lambda d: str(d).split(":")[0].strip() + def node_to_name(d): + return str(d).split(":")[0].strip() for node in graph.nodes(): if node.kind() == GETATTR_KIND: attr_name = node.s('name') node_name = node_to_name(node) parent = node.input().node() - if parent.kind() == GETATTR_KIND: # If the parent node is not the top-level "self" node + # If the parent node is not the top-level "self" node + if parent.kind() == GETATTR_KIND: parent_scope = attr_to_scope[node_to_name(parent)] attr_scope = parent_scope.split('/')[-1] - attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name) + attr_to_scope[node_name] = '{}/{}.{}'.format( + parent_scope, attr_scope, attr_name) else: attr_to_scope[node_name] = '__module.{}'.format(attr_name) # We don't need classtype nodes; scope will provide this information @@ -114,7 +131,8 @@ def parse(self, graph, trace, args=None, omit_useless_nodes=True): else: nodes_py.append(NodePyOP(node)) - for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops + # Create sink nodes for output ops + for i, node in enumerate(graph.outputs()): node_py = NodePyIO(node, 'output') node_py.debugName = "output.{}".format(i + 1) node_py.inputs = [node.debugName()] @@ -136,17 +154,21 @@ def parse(self, graph, trace, args=None, omit_useless_nodes=True): node.scopeName = base_name else: module_name += '.' + alias - node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias) + node.scopeName += '/' + \ + (alias_to_name[module_name] + if module_name in alias_to_name else alias) nodes_py.populate_namespace_from_OP_to_IO() return nodes_py.to_proto() + class NodePyGroup(NodePy): """ This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph, there are multiple nodes are traced for one torch.nn.Module object, we group them together to form a single node to represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node. """ + def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=None): """ Parameters: @@ -187,7 +209,8 @@ def sub_node_names(self): def __repr__(self): return 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'.format( - self.name, self.type, self.op_type, self.sub_node_names(), self.inputs, self.outputs, self.auxiliary + self.name, self.type, self.op_type, self.sub_node_names( + ), self.inputs, self.outputs, self.auxiliary ) @@ -195,12 +218,13 @@ class TorchModuleGraph(TorchGraph): """ Generates model graph, each node is created from single or multiple jit trace nodes. 
""" + def __init__(self, model, dummy_input): super().__init__(model, dummy_input) self.global_count = 0 self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph() - def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node): + def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, module_type, node_name=None): """ For trace graph nodes, some nodes are not in modules, these nodes are usually generated by the functions directly called in module ```forward```. For such nodes, some of them are @@ -225,7 +249,9 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node): the expanded non-prim node """ # TODO: scope name could be empty - node_name = '.'.join([self._get_module_name(node.scopeName()), node.kind(), str(self.global_count)]) + if not node_name: + node_name = '.'.join([self._get_module_name( + node.scopeName()), node.kind(), str(self.global_count)]) _logger.debug("expand non-prim node, node name: %s", node_name) self.global_count += 1 op_type = node.kind() @@ -240,39 +266,20 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node): for _input in curr_node.inputs(): input_name = _input.debugName() if input_name in output_to_node and output_to_node[input_name] in nodes: - predecessor_node = output_to_node[input_name] - if predecessor_node.kind().startswith('prim::'): - node_group.append(predecessor_node) - node_queue.put(predecessor_node) - else: - inputs.append(input_name) + predecessor_node = output_to_node[input_name] + if predecessor_node.kind().startswith('prim::'): + node_group.append(predecessor_node) + node_queue.put(predecessor_node) + else: + inputs.append(input_name) else: inputs.append(input_name) for output in node.outputs(): outputs.append(output.debugName()) - nodepy = NodePyGroup(node_name, 'func', op_type, node_group, inputs=inputs, outputs=outputs) + nodepy = NodePyGroup(node_name, module_type, op_type, + node_group, inputs=inputs, outputs=outputs) return nodepy - def _build_module_node_group(self, module_name, op_type, node_cpps, input_to_node, output_to_node): - graph = self.trace.graph - inputs, outputs = [], [] - for n in node_cpps: - for i in n.inputs(): - name = i.debugName() - if not name in output_to_node and i in graph.inputs(): - inputs.append(name) - elif output_to_node[name] not in node_cpps: - inputs.append(name) - for o in n.outputs(): - name = o.debugName() - if not name in input_to_node and o in graph.outputs(): - outputs.append(name) - elif input_to_node[name] not in node_cpps: - outputs.append(name) - - return NodePyGroup(module_name, 'module', op_type, node_cpps, inputs, outputs) - - def _extract_shape_info(self, node): """ Extract the shape information of ```aten::view``` node @@ -319,11 +326,12 @@ def is_parent(name1, name2): parts1, parts2 = name1.split('.'), name2.split('.') if len(parts1) >= len(parts2): return False - for i in range(len(parts1)): + for i, _ in enumerate(parts1): if parts2[i] != parts1[i]: return False return True - module_names = sorted([x[0] for x in self.trace.named_modules() if x[0]]) + module_names = sorted([x[0] + for x in self.trace.named_modules() if x[0]]) leaf_nodes = [] for i, name in enumerate(module_names): if i + 1 >= len(module_names) or not is_parent(name, module_names[i + 1]): @@ -386,9 +394,11 @@ def _build_graph(self): graph = self.trace.graph _logger.debug(graph) # build output mapping, from output debugName to its node - output_to_node = {x.debugName(): n for n in graph.nodes() for x in n.outputs()} + 
output_to_node = {x.debugName(): n for n in graph.nodes() + for x in n.outputs()} # build input mapping, from input debugName to its node - input_to_node = {x.debugName(): n for n in graph.nodes() for x in n.inputs()} + input_to_node = {x.debugName(): n for n in graph.nodes() + for x in n.inputs()} # build module mapping, from module name to all nodes (as list) under this module scope module_to_nodes = defaultdict(list) # the mapping of function (non-module in forward) to nodes, key is scope name @@ -404,7 +414,8 @@ def _build_graph(self): nodes_py.append(NodePyIO(node, 'input')) self.leaf_modules = self._extract_leaf_modules() - module_to_type = {name: parse_traced_name(module._name) for name, module in self.trace.named_modules()} + module_to_type = {name: parse_traced_name( + module._name) for name, module in self.trace.named_modules()} # associate module name with their trace graph nodes for node in graph.nodes(): @@ -414,32 +425,30 @@ def _build_graph(self): else: print(node) func_to_nodes[node.scopeName()].append(node) - # print(module_to_nodes.keys()) - print('######') - # print(func_to_nodes.keys()) # build node group for module for module_name, node_cpps in module_to_nodes.items(): print(module_name, len(node_cpps)) print(node_cpps) - if len(node_cpps) == 1: - node_group = self._build_module_node_group( - unique_name, module_to_type[module_name], node_cpps, input_to_node, output_to_node - ) - nodes_py.nodes_op.append(node_group) - continue - for useid, node in enumerate(node_cpps): - unique_name = module_name + '.%d' % useid - node_group = self._build_module_node_group( - unique_name, module_to_type[module_name], [node], input_to_node, output_to_node - ) - _logger.debug('node_group: %s', node_group) - nodes_py.nodes_op.append(node_group) - print('$$$$$$$$$$$$') + use_count = 0 + for node in node_cpps: + if not node.kind().startswith('prim::'): + # modules that have same scope name may have different locations in the + # graph. Futhermore, there are also lots of prim:: nodes that in node_cpps, + # so we also need to call the expand_non_prim_node. 
+                    unique_name = module_name
+                    if use_count > 0:
+                        unique_name = module_name + '.%d' % use_count
+                    node_group = self._expand_non_prim_node(
+                        node, node_cpps, input_to_node, output_to_node,
+                        module_to_type[module_name], node_name=unique_name)
+                    nodes_py.nodes_op.append(node_group)
+                    use_count += 1
+                    print(node_group.name, use_count,
+                          len(node_group.node_cpps))
+
         # each scope_name may have multiple funcs, we split them and create node for each of them
         # build node group for torch.nn.functional
-        for tname, nodes in func_to_nodes.items():
-            print('###', tname)
-            print(len(nodes))
+        for _, nodes in func_to_nodes.items():
             used = set()
             # extract non prim:: nodes
             non_prim_nodes = list()
             for node in nodes:
@@ -448,20 +457,17 @@ def _build_graph(self):
                 non_prim_nodes.append(node)
             # for each non prim node, expand it
             for node in non_prim_nodes:
-                node_group = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
+                node_group = self._expand_non_prim_node(
+                    node, nodes, input_to_node, output_to_node, 'func')
                 used.update(node_group.node_cpps)
                 nodes_py.nodes_op.append(node_group)
                 # get shape infor for view (aten::view) func
                 if node_group.op_type in ['aten::view', 'aten::flatten']:
                     node_group.auxiliary = self._extract_shape_info(node)
-            print(len(set(nodes)-used))
-            print(set(nodes)-used)
-
+
         for node in graph.outputs():  # Create sink nodes for output ops
             node_py = NodePyIO(node, 'output')
             nodes_py.append(node_py)
-
-
         self.nodes_py = nodes_py

         # build index
@@ -506,7 +512,11 @@ def find_successors(self, module_name):
         """
         successors = []
         for output in self.name_to_node[module_name].outputs:
-            assert output in self.input_to_node, "No node with input {}".format(output)
+            # assert output in self.input_to_node, "No node with input {}(from {})".format(
+            #     output, module_name)
+            if output not in self.input_to_node:
+                # may reach the output of the whole graph
+                continue
             nodes_py = self.input_to_node[output]
             for node_py in nodes_py:
                 successors.append(node_py.name)

From 9ef109f9ba5fb71155f87b642e1db2411c877b78 Mon Sep 17 00:00:00 2001
From: Ningxin
Date: Wed, 3 Jun 2020 06:34:34 +0000
Subject: [PATCH 03/11] update

Signed-off-by: Ningxin
---
 src/sdk/pynni/nni/_graph_utils.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py
index 07f956172a..28aafe6cb9 100644
--- a/src/sdk/pynni/nni/_graph_utils.py
+++ b/src/sdk/pynni/nni/_graph_utils.py
@@ -423,12 +423,9 @@ def _build_graph(self):
             if module_name in self.leaf_modules:
                 module_to_nodes[module_name].append(node)
             else:
-                print(node)
                 func_to_nodes[node.scopeName()].append(node)
         # build node group for module
         for module_name, node_cpps in module_to_nodes.items():
-            print(module_name, len(node_cpps))
-            print(node_cpps)
             use_count = 0
             for node in node_cpps:
                 if not node.kind().startswith('prim::'):
@@ -443,8 +440,6 @@ def _build_graph(self):
                         module_to_type[module_name], node_name=unique_name)
                     nodes_py.nodes_op.append(node_group)
                     use_count += 1
-                    print(node_group.name, use_count,
-                          len(node_group.node_cpps))

         # each scope_name may have multiple funcs, we split them and create node for each of them
         # build node group for torch.nn.functional
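Note (editorial illustration, not part of the patch series): the core of the fix up to
this point is that a torch.nn.Module reused in ``forward`` produces several trace nodes
that all share one scope name, so grouping nodes by scope name alone collapses distinct
uses into a single NodePyGroup. The sketch below, assuming only PyTorch >= 1.3.1 as the
patches themselves do, shows how such reuse looks in a raw trace; the class name `Reuse`
is illustrative only.

    import torch
    import torch.nn as nn

    class Reuse(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 10)
            self.fc2 = nn.Linear(10, 10)
            self.relu = nn.ReLU()  # a single module object, called twice in forward

        def forward(self, x):
            return self.relu(self.fc2(self.relu(self.fc1(x))))

    traced = torch.jit.trace(Reuse(), torch.randn(1, 10))
    torch._C._jit_pass_inline(traced.graph)  # same inline pass the patches use

    # Both aten::relu nodes report the same scope name, which is why the patches
    # append a '.<use_count>' suffix to build a unique name per use.
    scopes = [n.scopeName() for n in traced.graph.nodes() if n.kind() == 'aten::relu']
    print(scopes)  # two identical entries (exact format depends on the torch version)

This is why PATCH 02 switched from one NodePyGroup per module name to one per
(module name, use count) pair.

From 7c46eecb2b7e5b4946405586bd203e34b77ee403 Mon Sep 17 00:00:00 2001
From: Ningxin
Date: Thu, 4 Jun 2020 07:25:14 +0000
Subject: [PATCH 04/11] move the tensorboard related package importing back.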
Signed-off-by: Ningxin --- src/sdk/pynni/nni/_graph_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py index 28aafe6cb9..79d124acc9 100644 --- a/src/sdk/pynni/nni/_graph_utils.py +++ b/src/sdk/pynni/nni/_graph_utils.py @@ -8,10 +8,6 @@ from collections import defaultdict import torch from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy -from tensorboard.compat.proto.config_pb2 import RunMetadata -from tensorboard.compat.proto.graph_pb2 import GraphDef -from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats -from tensorboard.compat.proto.versions_pb2 import VersionDef CLASSTYPE_KIND = 'ClassType' GETATTR_KIND = 'prim::GetAttr' @@ -79,6 +75,11 @@ class TorchProtoGraph(TorchGraph): def __init__(self, model, dummy_input, verbose=False): super().__init__(model, dummy_input) + from tensorboard.compat.proto.config_pb2 import RunMetadata + from tensorboard.compat.proto.graph_pb2 import GraphDef + from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats + from tensorboard.compat.proto.versions_pb2 import VersionDef + list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input) if verbose: print(self.trace.graph) From e9ba10b70ea473424fbe9f164dbf45f70dd3bdf5 Mon Sep 17 00:00:00 2001 From: Ningxin Date: Thu, 4 Jun 2020 08:12:57 +0000 Subject: [PATCH 05/11] update --- src/sdk/pynni/nni/_graph_utils.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py index 79d124acc9..b1aa18625b 100644 --- a/src/sdk/pynni/nni/_graph_utils.py +++ b/src/sdk/pynni/nni/_graph_utils.py @@ -56,7 +56,9 @@ def __init__(self, model=None, dummy_input=None, traced_model=None): self.bound_model = model self._trace(model, dummy_input) else: - raise Exception('Please provide model & dummy_input or the traced_model as inputs') + raise Exception( + 'Please provide model & dummy_input or the traced_model as inputs') + def _trace(self, model, dummy_input): with torch.onnx.set_training(model, False): self.trace = torch.jit.trace(model, dummy_input) @@ -109,6 +111,7 @@ def parse(self, graph, trace, args=None, omit_useless_nodes=True): nodes_py.append(NodePyIO(node, 'input')) attr_to_scope = dict() + def node_to_name(d): return str(d).split(":")[0].strip() for node in graph.nodes(): @@ -201,7 +204,7 @@ def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=Non def add_nodes(self, node_cpps): for node_cpp in node_cpps: nodepy = NodePyOP(node_cpp) - # TODO bug here, wrong way to get the name + # TODO may be a bug here, need confirmation with chengmin~ nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '') self.nodes.append(nodepy) @@ -225,7 +228,8 @@ def __init__(self, model, dummy_input): self.global_count = 0 self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph() - def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, module_type, node_name=None): + def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, + module_type, node_name=None, op_type=None): """ For trace graph nodes, some nodes are not in modules, these nodes are usually generated by the functions directly called in module ```forward```. 
For such nodes, some of them are
@@ -243,6 +247,12 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, modu
             key: input name, value: a node that uses this input
         output_to_node : dict
             key: output name, value: a node that generates this output
+        module_type : str
+            can be 'module' or 'func'
+        node_name : str
+            specify the node_name for NodePyGroup
+        op_type : str
+            specify the op_type for the NodePyGroup

         Returns
         -------
@@ -255,8 +265,8 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, modu
             node.scopeName()), node.kind(), str(self.global_count)])
         _logger.debug("expand non-prim node, node name: %s", node_name)
         self.global_count += 1
-        op_type = node.kind()
-
+        if not op_type:
+            op_type = node.kind()
         node_group = [node]
         inputs = list()
         outputs = list()
@@ -438,14 +448,13 @@ def _build_graph(self):
                     unique_name = module_name + '.%d' % use_count
                 node_group = self._expand_non_prim_node(
                     node, node_cpps, input_to_node, output_to_node,
-                    module_to_type[module_name], node_name=unique_name)
+                    'module', node_name=unique_name, op_type=module_to_type[module_name])
                 nodes_py.nodes_op.append(node_group)
                 use_count += 1

         # each scope_name may have multiple funcs, we split them and create node for each of them
         # build node group for torch.nn.functional
         for _, nodes in func_to_nodes.items():
-            used = set()
             # extract non prim:: nodes
             non_prim_nodes = list()
             for node in nodes:
@@ -455,7 +464,6 @@ def _build_graph(self):
             for node in non_prim_nodes:
                 node_group = self._expand_non_prim_node(
                     node, nodes, input_to_node, output_to_node, 'func')
-                used.update(node_group.node_cpps)
                 nodes_py.nodes_op.append(node_group)
                 # get shape infor for view (aten::view) func
                 if node_group.op_type in ['aten::view', 'aten::flatten']:

From 525d67625dc96809e654e7d7796cb792bb596bdd Mon Sep 17 00:00:00 2001
From: Ningxin
Date: Thu, 4 Jun 2020 11:48:37 +0000
Subject: [PATCH 06/11] update

---
 src/sdk/pynni/nni/_graph_utils.py | 44 ++++++++++++++++++++-----------
 1 file changed, 29 insertions(+), 15 deletions(-)

diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py
index b1aa18625b..a8a1a5add4 100644
--- a/src/sdk/pynni/nni/_graph_utils.py
+++ b/src/sdk/pynni/nni/_graph_utils.py
@@ -173,12 +173,18 @@ class NodePyGroup(NodePy):
     represent the torch.nn.Module object. We also group some functional call trace nodes together
     to form a new node.
     """
-    def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=None):
+
+    def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None):
         """
         Parameters:
         -----------
         name: str
             node name, such as `conv1`, `backbone.classifier`
+        unique_name: str
+            A globally unique name for the current node. Because some modules,
+            such as relu, may be reused several times, the scope name alone
+            is not suitable as a globally unique identifier, so we add a
+            unique_name for each node as the globally unique identifier.
+            We should use the unique_name to traverse the module graph.
node_type: str `module` or `func` op_type: str @@ -193,6 +199,7 @@ def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=Non super(NodePyGroup, self).__init__(name, []) self.node_cpps = node_cpps self.name = name + self.unique_name = unique_name self.op_type = op_type self.type = node_type self.nodes = [] @@ -229,7 +236,7 @@ def __init__(self, model, dummy_input): self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph() def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, - module_type, node_name=None, op_type=None): + module_type, node_name=None, op_type=None, unique_name=None): """ For trace graph nodes, some nodes are not in modules, these nodes are usually generated by the functions directly called in module ```forward```. For such nodes, some of them are @@ -253,6 +260,8 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, specify the node_name for NodePyGroup op_type : str specify the op_type for the NodePyGroup + unique_name: str + unique_name for the NodePyGroup Returns ------- @@ -263,6 +272,9 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, if not node_name: node_name = '.'.join([self._get_module_name( node.scopeName()), node.kind(), str(self.global_count)]) + if not unique_name: + # if unique name is None, use the same value as node_name + unique_name = node_name _logger.debug("expand non-prim node, node name: %s", node_name) self.global_count += 1 if not op_type: @@ -287,7 +299,7 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, inputs.append(input_name) for output in node.outputs(): outputs.append(output.debugName()) - nodepy = NodePyGroup(node_name, module_type, op_type, + nodepy = NodePyGroup(node_name, unique_name, module_type, op_type, node_group, inputs=inputs, outputs=outputs) return nodepy @@ -374,7 +386,7 @@ def _build_index(self, nodes_op): input_to_node = defaultdict(list) output_to_node = dict() for node in nodes_op: - name_to_node[node.name] = node + name_to_node[node.unique_name] = node for _input in node.inputs: input_to_node[_input].append(node) for output in node.outputs: @@ -448,7 +460,9 @@ def _build_graph(self): unique_name = module_name + '.%d' % use_count node_group = self._expand_non_prim_node( node, node_cpps, input_to_node, output_to_node, - 'module', node_name=unique_name, op_type=module_to_type[module_name]) + 'module', node_name=module_name, + op_type=module_to_type[module_name], + unique_name=unique_name) nodes_py.nodes_op.append(node_group) use_count += 1 @@ -477,14 +491,14 @@ def _build_graph(self): # build index return self._build_index(self.nodes_py.nodes_op) - def find_predecessors(self, module_name): + def find_predecessors(self, unique_name): """ Find predecessor node of the given node Parameters ---------- - module_name : str - The name of the node + unique_name : str + The unique name of the node Returns ------- @@ -492,22 +506,22 @@ def find_predecessors(self, module_name): a list of nodes who are the given node's predecessor """ predecessors = [] - for _input in self.name_to_node[module_name].inputs: + for _input in self.name_to_node[unique_name].inputs: if not _input in self.output_to_node: _logger.debug("cannot find node with %s as its output", _input) else: node_py = self.output_to_node[_input] - predecessors.append(node_py.name) + predecessors.append(node_py.unique_name) return predecessors - def find_successors(self, module_name): + def find_successors(self, unique_name): """ Find 
successor nodes of the given node Parameters ---------- - module_name : str - The name of the node + unique_name : str + The unique name of the node Returns ------- @@ -515,7 +529,7 @@ def find_successors(self, module_name): a list of nodes who are the given node's successor """ successors = [] - for output in self.name_to_node[module_name].outputs: + for output in self.name_to_node[unique_name].outputs: # assert output in self.input_to_node, "No node with input {}(from {})".format( # output, module_name) if output not in self.input_to_node: @@ -523,5 +537,5 @@ def find_successors(self, module_name): continue nodes_py = self.input_to_node[output] for node_py in nodes_py: - successors.append(node_py.name) + successors.append(node_py.unique_name) return successors From 758e7f1dcd256ba542236bf463254e2505e120bb Mon Sep 17 00:00:00 2001 From: Ningxin Date: Thu, 4 Jun 2020 12:51:38 +0000 Subject: [PATCH 07/11] Add _expand_module_node to fix the bug. Signed-off-by: Ningxin --- src/sdk/pynni/nni/_graph_utils.py | 79 +++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py index a8a1a5add4..e0feb06660 100644 --- a/src/sdk/pynni/nni/_graph_utils.py +++ b/src/sdk/pynni/nni/_graph_utils.py @@ -212,7 +212,8 @@ def add_nodes(self, node_cpps): for node_cpp in node_cpps: nodepy = NodePyOP(node_cpp) # TODO may be a bug here, need confirmation with chengmin~ - nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '') + # nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '') + nodepy.name = node_cpp.scopeName() + '_' + node_cpp.kind() self.nodes.append(nodepy) def sub_node_names(self): @@ -303,6 +304,76 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, node_group, inputs=inputs, outputs=outputs) return nodepy + def _expand_module_node(self, node, nodes, input_to_node, output_to_node, + module_type, node_name=None, op_type=None, unique_name=None): + """ + merge the adjacent nodes of the module. The difference between the + _expand_module_node and _expand_non_prim_node is that, the _expand_non_prim_node + only merge the prim:: nodes into the aten:: node, in contrast,the _expand_module_node + will merge all adjacent nodes into a same nodepy group. 
+ + Parameters + ---------- + node : trace graph node + The non-prim node to expand + nodes : list of trace graph node + All the trace graph nodes within the same scope as the non-prim node + input_to_node : dict + key: input name, value: a node that uses this input + output_to_node : dict + key: output name, value: a node that generates this output + module_type : str + can be 'module' or 'func' + node_name : str + specify the node_name for NodePyGroup + op_type : str + specify the op_type for the NodePyGroup + unique_name: str + unique_name for the NodePyGroup + + Returns + ------- + node + the expanded non-prim node + + """ + _logger.debug("expand module node, node name: %s", node_name) + self.global_count += 1 + if not op_type: + op_type = node.kind() + node_group = [node] + inputs = list() + outputs = list() + node_queue = queue.Queue() + node_queue.put(node) + visited = {node} + while not node_queue.empty(): + curr_node = node_queue.get() + for _input in curr_node.inputs(): + input_name = _input.debugName() + if input_name in output_to_node and output_to_node[input_name] in nodes: + predecessor_node = output_to_node[input_name] + if predecessor_node not in visited: + node_group.append(predecessor_node) + node_queue.put(predecessor_node) + visited.add(predecessor_node) + else: + inputs.append(input_name) + for _output in curr_node.outputs(): + output_name = _output.debugName() + if output_name in input_to_node and input_to_node[output_name] in nodes: + successor_node = input_to_node[output_name] + if successor_node not in visited: + node_group.append(successor_node) + node_queue.put(successor_node) + visited.add(successor_node) + else: + outputs.append(output_name) + + nodepy = NodePyGroup(node_name, unique_name, module_type, op_type, + node_group, inputs=inputs, outputs=outputs) + return nodepy + def _extract_shape_info(self, node): """ Extract the shape information of ```aten::view``` node @@ -450,21 +521,23 @@ def _build_graph(self): # build node group for module for module_name, node_cpps in module_to_nodes.items(): use_count = 0 + merged = set() for node in node_cpps: - if not node.kind().startswith('prim::'): + if node not in merged: # modules that have same scope name may have different locations in the # graph. Futhermore, there are also lots of prim:: nodes that in node_cpps, # so we also need to call the expand_non_prim_node. unique_name = module_name if use_count > 0: unique_name = module_name + '.%d' % use_count - node_group = self._expand_non_prim_node( + node_group = self._expand_module_node( node, node_cpps, input_to_node, output_to_node, 'module', node_name=module_name, op_type=module_to_type[module_name], unique_name=unique_name) nodes_py.nodes_op.append(node_group) use_count += 1 + merged.update(node_group.node_cpps) # each scope_name may have multiple funcs, we split them and create node for each of them # build node group for torch.nn.functional From 3ebf1e11d2dd316e9f0304820444ab647294a6ac Mon Sep 17 00:00:00 2001 From: Ningxin Date: Thu, 4 Jun 2020 13:45:15 +0000 Subject: [PATCH 08/11] update interface. 
Signed-off-by: Ningxin --- src/sdk/pynni/nni/_graph_utils.py | 43 +++++++++++-------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py index e0feb06660..db89d00230 100644 --- a/src/sdk/pynni/nni/_graph_utils.py +++ b/src/sdk/pynni/nni/_graph_utils.py @@ -237,7 +237,7 @@ def __init__(self, model, dummy_input): self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph() def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, - module_type, node_name=None, op_type=None, unique_name=None): + module_type): """ For trace graph nodes, some nodes are not in modules, these nodes are usually generated by the functions directly called in module ```forward```. For such nodes, some of them are @@ -257,12 +257,6 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, key: output name, value: a node that generates this output module_type : str can be 'module' or 'func' - node_name : str - specify the node_name for NodePyGroup - op_type : str - specify the op_type for the NodePyGroup - unique_name: str - unique_name for the NodePyGroup Returns ------- @@ -270,16 +264,12 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, the expanded non-prim node """ # TODO: scope name could be empty - if not node_name: - node_name = '.'.join([self._get_module_name( - node.scopeName()), node.kind(), str(self.global_count)]) - if not unique_name: - # if unique name is None, use the same value as node_name - unique_name = node_name + node_name = '.'.join([self._get_module_name( + node.scopeName()), node.kind(), str(self.global_count)]) + unique_name = node_name _logger.debug("expand non-prim node, node name: %s", node_name) self.global_count += 1 - if not op_type: - op_type = node.kind() + op_type = node.kind() node_group = [node] inputs = list() outputs = list() @@ -304,8 +294,8 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node, node_group, inputs=inputs, outputs=outputs) return nodepy - def _expand_module_node(self, node, nodes, input_to_node, output_to_node, - module_type, node_name=None, op_type=None, unique_name=None): + def _expand_module_node(self, node, node_name, unique_name, op_type, nodes, + input_to_node, output_to_node, module_type): """ merge the adjacent nodes of the module. 
The difference between the _expand_module_node and _expand_non_prim_node is that, the _expand_non_prim_node @@ -316,6 +306,12 @@ def _expand_module_node(self, node, nodes, input_to_node, output_to_node, ---------- node : trace graph node The non-prim node to expand + node_name : str + specify the node_name for NodePyGroup + unique_name : str + unique_name for the NodePyGroup + op_type : str + specify the op_type for the NodePyGroup nodes : list of trace graph node All the trace graph nodes within the same scope as the non-prim node input_to_node : dict @@ -324,13 +320,6 @@ def _expand_module_node(self, node, nodes, input_to_node, output_to_node, key: output name, value: a node that generates this output module_type : str can be 'module' or 'func' - node_name : str - specify the node_name for NodePyGroup - op_type : str - specify the op_type for the NodePyGroup - unique_name: str - unique_name for the NodePyGroup - Returns ------- node @@ -531,10 +520,8 @@ def _build_graph(self): if use_count > 0: unique_name = module_name + '.%d' % use_count node_group = self._expand_module_node( - node, node_cpps, input_to_node, output_to_node, - 'module', node_name=module_name, - op_type=module_to_type[module_name], - unique_name=unique_name) + node, module_name, unique_name, module_to_type[module_name], + node_cpps, input_to_node, output_to_node, 'module') nodes_py.nodes_op.append(node_group) use_count += 1 merged.update(node_group.node_cpps) From 188d95c66098aa48b9672c190c02571ba716d113 Mon Sep 17 00:00:00 2001 From: Ningxin Date: Tue, 9 Jun 2020 06:55:03 +0000 Subject: [PATCH 09/11] Add a test case for the module reuse scenario. Signed-off-by: Ningxin --- src/sdk/pynni/nni/_graph_utils.py | 4 +-- src/sdk/pynni/tests/test_graph_utils.py | 42 ++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py index db89d00230..d6e3b36a69 100644 --- a/src/sdk/pynni/nni/_graph_utils.py +++ b/src/sdk/pynni/nni/_graph_utils.py @@ -231,8 +231,8 @@ class TorchModuleGraph(TorchGraph): Generates model graph, each node is created from single or multiple jit trace nodes. 
""" - def __init__(self, model, dummy_input): - super().__init__(model, dummy_input) + def __init__(self, model=None, dummy_input=None, traced_model=None): + super().__init__(model, dummy_input, traced_model) self.global_count = 0 self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph() diff --git a/src/sdk/pynni/tests/test_graph_utils.py b/src/sdk/pynni/tests/test_graph_utils.py index 5960a41733..d1317c56de 100644 --- a/src/sdk/pynni/tests/test_graph_utils.py +++ b/src/sdk/pynni/tests/test_graph_utils.py @@ -15,7 +15,7 @@ import unittest from unittest import TestCase, main -from nni._graph_utils import build_module_graph, build_graph +from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph class BackboneModel1(nn.Module): def __init__(self): @@ -153,6 +153,46 @@ def forward(self, x): torch.randn(4, 5), os.path.join(os.path.dirname(__file__), "expect", "test_graph_module3.expect") ) + + @unittest.skipIf(torch.__version__ < "1.4.0", "not supported") + def test_module_reuse(self): + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.liner1 = nn.Linear(10, 10) + self.relu = nn.ReLU(inplace=True) + self.liner2 = nn.Linear(10, 20) + self.liner3 = nn.Linear(20, 10) + + def forward(self, x): + x = self.liner1(x) + x = self.relu(x) + x = self.liner2(x) + x = self.relu(x) + x = self.liner3(x) + x = self.relu(x) + return x + + data = torch.rand(10, 10) + net = MyModule() + traced = torch.jit.trace(net, data) + modulegraph = TorchModuleGraph(traced_model=traced) + # Traverse the TorchModuleGraph, due the resue of the relu module, + # there will be three cpp_nodes corrspoding to the same module. + # During traversing the graph, there should be only one + # successor of each cpp-node (including the cpp_nodes that corresponds + # to the same relu module). + for name, nodeio in modulegraph.nodes_py.nodes_io.items(): + if nodeio.input_or_output == 'input': + # Find the first node of the whole graph + start_nodes = modulegraph.input_to_node[name] + # We have only one single path top-down + assert len(start_nodes) == 1 + node = start_nodes[0].unique_name + while modulegraph.find_successors(node): + nodes = modulegraph.find_successors(node) + assert len(nodes) == 1 + node = nodes[0] if __name__ == '__main__': main() From cc9485e60844950e7d79d342aeb74c2761c99163 Mon Sep 17 00:00:00 2001 From: Ningxin Date: Wed, 10 Jun 2020 09:13:06 +0000 Subject: [PATCH 10/11] update --- src/sdk/pynni/nni/_graph_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py index d6e3b36a69..cc89cd2b38 100644 --- a/src/sdk/pynni/nni/_graph_utils.py +++ b/src/sdk/pynni/nni/_graph_utils.py @@ -515,7 +515,7 @@ def _build_graph(self): if node not in merged: # modules that have same scope name may have different locations in the # graph. Futhermore, there are also lots of prim:: nodes that in node_cpps, - # so we also need to call the expand_non_prim_node. + # so we also need to call the expand_module_node. 
unique_name = module_name
                     if use_count > 0:
                         unique_name = module_name + '.%d' % use_count
                     node_group = self._expand_module_node(
                         node, module_name, unique_name, module_to_type[module_name],
                         node_cpps, input_to_node, output_to_node, 'module')
                     nodes_py.nodes_op.append(node_group)
                     use_count += 1
                     merged.update(node_group.node_cpps)
@@ -590,8 +590,6 @@ def find_successors(self, unique_name):
         """
         successors = []
         for output in self.name_to_node[unique_name].outputs:
-            # assert output in self.input_to_node, "No node with input {}(from {})".format(
-            #     output, module_name)
             if output not in self.input_to_node:
                 # may reach the output of the whole graph
                 continue

From 4a8d5192dc5a918bc1f5758e2fdfc6cfe1f97af9 Mon Sep 17 00:00:00 2001
From: Ningxin
Date: Wed, 10 Jun 2020 13:54:35 +0000
Subject: [PATCH 11/11] update

---
 src/sdk/pynni/nni/_graph_utils.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py
index cc89cd2b38..445aaebd58 100644
--- a/src/sdk/pynni/nni/_graph_utils.py
+++ b/src/sdk/pynni/nni/_graph_utils.py
@@ -44,6 +44,9 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
             The model user wants to speed up
         dummy_input : pytorch tensor
             The dummy input for ```jit.trace```, users should put it on right device before pass in
+        traced_model : torch._C.torch.jit.TopLevelTracedModule
+            An already traced model. If traced_model is not None, then TorchGraph will build the graph
+            based on this traced model and won't trace the model again.
         """
         assert torch.__version__ >= '1.3.1'
         # check if the input is legal
@@ -211,8 +214,6 @@ def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None
     def add_nodes(self, node_cpps):
         for node_cpp in node_cpps:
             nodepy = NodePyOP(node_cpp)
-            # TODO may be a bug here, need confirmation with chengmin~
-            # nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '')
             nodepy.name = node_cpp.scopeName() + '_' + node_cpp.kind()
             self.nodes.append(nodepy)

@@ -221,8 +222,8 @@ def sub_node_names(self):

     def __repr__(self):
         return 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'.format(
-            self.name, self.type, self.op_type, self.sub_node_names(
-            ), self.inputs, self.outputs, self.auxiliary
+            self.name, self.type, self.op_type, self.sub_node_names(),
+            self.inputs, self.outputs, self.auxiliary
         )
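
Note (editorial illustration, not part of the patch series): with the full series
applied, a module reused N times in ``forward`` yields N NodePyGroups whose unique
names are the scope name plus a '.<use_count>' suffix (e.g. 'relu', 'relu.1',
'relu.2'), and find_predecessors/find_successors operate on those unique names. The
sketch below is a minimal usage example under that assumption; the model class `Net`
is illustrative only, while TorchModuleGraph, name_to_node, op_type and
find_successors are the names these patches introduce or change.

    import torch
    import torch.nn as nn
    from nni._graph_utils import TorchModuleGraph

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 10)
            self.fc2 = nn.Linear(10, 10)
            self.relu = nn.ReLU()  # reused twice in forward

        def forward(self, x):
            return self.relu(self.fc2(self.relu(self.fc1(x))))

    traced = torch.jit.trace(Net(), torch.randn(1, 10))
    # PATCH 09 allows passing an already traced model directly
    graph = TorchModuleGraph(traced_model=traced)

    # name_to_node is keyed by unique_name after PATCH 06, so each use of
    # self.relu appears as its own node group instead of being collapsed.
    for unique_name, group in graph.name_to_node.items():
        print(unique_name, group.op_type, '->', graph.find_successors(unique_name))

Keying the index by unique_name rather than scope name is what makes the traversal in
the PATCH 09 test well defined: each use of the shared relu has exactly one successor.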