diff --git a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py index c594aa2ff8..363eb714da 100644 --- a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py @@ -229,42 +229,39 @@ def _extract_leaf_modules(self, graph): list a list of scope name of all the leaf modules """ - pieces = [] # each element is a dict + class SNode: + def __init__(self, name): + self.sname = name + self.childs = {} + + root = None for node in graph.nodes(): scope_name = node.scopeName() if scope_name == '': continue segs = scope_name.split('/') - segs_len = len(segs) - # increase the length of `pieces` if not enough - for _ in range(segs_len - len(pieces)): - pieces.append({}) - # process internal segments of the scope name - # 'L' means leaf segment - # 'I' means internal segment - # internal segment can replace leaf segment at the same position of `pieces` - for i, seg in enumerate(segs[:-1]): - seg_name_dict = pieces[i] - if seg in seg_name_dict: - if seg_name_dict[seg][0] == 'L': - seg_name_dict[seg] = ('I', node) - else: - seg_name_dict[seg] = ('I', node) - # process the leaf segment of the scope name - last_segs_dict = pieces[len(segs) - 1] - if not segs[-1] in last_segs_dict: - last_segs_dict[segs[-1]] = ('L', node) - # traverse `pieces` to obtain all the leaf modules which are labeled with 'L' - leaf_modules = [] - for piece in pieces: - for _, value in piece.items(): - if value[0] == 'L': - assert value[1].scopeName() not in leaf_modules - # if this is a leaf module, the last segment of its scope name - # must be in pattern `xxx[xxx]` - if value[1].scopeName()[-1] == ']': - leaf_modules.append(value[1].scopeName()) - return leaf_modules + if root is None: + root = SNode(segs[0]) + curr = root + for seg in segs[1:]: + if not seg in curr.childs: + curr.childs[seg] = SNode(seg) + curr = curr.childs[seg] + + leaf_nodes = [] + def traverse_tree(node, scope_name): + if scope_name == '': + sn = node.sname + else: + sn = scope_name + '/' + node.sname + if not node.childs: + if node.sname[-1] == ']': + leaf_nodes.append(sn) + else: + for key in node.childs: + traverse_tree(node.childs[key], sn) + traverse_tree(root, '') + return leaf_nodes def _build_graph(self): """