Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[BUG] finding leaf modules #2241

Merged
merged 15 commits into from
Mar 27, 2020
59 changes: 28 additions & 31 deletions src/sdk/pynni/nni/compression/speedup/torch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,42 +229,39 @@ def _extract_leaf_modules(self, graph):
list
a list of scope name of all the leaf modules
"""
pieces = [] # each element is a dict
class SNode:
def __init__(self, name):
self.sname = name
self.childs = {}
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved

root = None
for node in graph.nodes():
scope_name = node.scopeName()
if scope_name == '':
continue
segs = scope_name.split('/')
segs_len = len(segs)
# increase the length of `pieces` if not enough
for _ in range(segs_len - len(pieces)):
pieces.append({})
# process internal segments of the scope name
# 'L' means leaf segment
# 'I' means internal segment
# internal segment can replace leaf segment at the same position of `pieces`
for i, seg in enumerate(segs[:-1]):
seg_name_dict = pieces[i]
if seg in seg_name_dict:
if seg_name_dict[seg][0] == 'L':
seg_name_dict[seg] = ('I', node)
else:
seg_name_dict[seg] = ('I', node)
# process the leaf segment of the scope name
last_segs_dict = pieces[len(segs) - 1]
if not segs[-1] in last_segs_dict:
last_segs_dict[segs[-1]] = ('L', node)
# traverse `pieces` to obtain all the leaf modules which are labeled with 'L'
leaf_modules = []
for piece in pieces:
for _, value in piece.items():
if value[0] == 'L':
assert value[1].scopeName() not in leaf_modules
# if this is a leaf module, the last segment of its scope name
# must be in pattern `xxx[xxx]`
if value[1].scopeName()[-1] == ']':
leaf_modules.append(value[1].scopeName())
return leaf_modules
if root is None:
root = SNode(segs[0])
curr = root
for seg in segs[1:]:
if not seg in curr.childs:
curr.childs[seg] = SNode(seg)
curr = curr.childs[seg]

leaf_nodes = []
def traverse_tree(node, scope_name):
if scope_name == '':
sn = node.sname
else:
sn = scope_name + '/' + node.sname
if not node.childs:
if node.sname[-1] == ']':
leaf_nodes.append(sn)
else:
for key in node.childs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for child in node.children.values():
    traverse_tree(child, sn)

traverse_tree(node.childs[key], sn)
traverse_tree(root, '')
return leaf_nodes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm right, what this is doing is actually: finding all leaf nodes that ends with ]. If so, there is a simpler way to do it:

all_names = [node.scopeName() for node in graph.nodes()]
all_names = list(filter(lambda x: x, all_names))  # filter out non-empty strings
all_names.sort()
leaf_nodes = []
for i, name in enumerate(all_names):
    if (i + 1 >= len(all_names) or not all_names[i + 1].startswith(name)) and name.endswith("]"):
        leaf_nodes.append(name)
return leaf_nodes

Since I didn't get enough context, I might be wrong...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I test both your functions, they can generate same results.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ultmaster too complex, I am not sure whether the logic is correct or not...


def _build_graph(self):
"""
Expand Down