diff --git a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py
index c594aa2ff8..363eb714da 100644
--- a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py
+++ b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py
@@ -229,42 +229,39 @@ def _extract_leaf_modules(self, graph):
         list
             a list of scope name of all the leaf modules
         """
-        pieces = [] # each element is a dict
+        class SNode:
+            def __init__(self, name):
+                self.sname = name
+                self.childs = {}
+
+        root = None
         for node in graph.nodes():
             scope_name = node.scopeName()
             if scope_name == '':
                 continue
             segs = scope_name.split('/')
-            segs_len = len(segs)
-            # increase the length of `pieces` if not enough
-            for _ in range(segs_len - len(pieces)):
-                pieces.append({})
-            # process internal segments of the scope name
-            # 'L' means leaf segment
-            # 'I' means internal segment
-            # internal segment can replace leaf segment at the same position of `pieces`
-            for i, seg in enumerate(segs[:-1]):
-                seg_name_dict = pieces[i]
-                if seg in seg_name_dict:
-                    if seg_name_dict[seg][0] == 'L':
-                        seg_name_dict[seg] = ('I', node)
-                else:
-                    seg_name_dict[seg] = ('I', node)
-            # process the leaf segment of the scope name
-            last_segs_dict = pieces[len(segs) - 1]
-            if not segs[-1] in last_segs_dict:
-                last_segs_dict[segs[-1]] = ('L', node)
-        # traverse `pieces` to obtain all the leaf modules which are labeled with 'L'
-        leaf_modules = []
-        for piece in pieces:
-            for _, value in piece.items():
-                if value[0] == 'L':
-                    assert value[1].scopeName() not in leaf_modules
-                    # if this is a leaf module, the last segment of its scope name
-                    # must be in pattern `xxx[xxx]`
-                    if value[1].scopeName()[-1] == ']':
-                        leaf_modules.append(value[1].scopeName())
-        return leaf_modules
+            if root is None:
+                root = SNode(segs[0])
+            curr = root
+            for seg in segs[1:]:
+                if not seg in curr.childs:
+                    curr.childs[seg] = SNode(seg)
+                curr = curr.childs[seg]
+
+        leaf_nodes = []
+        def traverse_tree(node, scope_name):
+            if scope_name == '':
+                sn = node.sname
+            else:
+                sn = scope_name + '/' + node.sname
+            if not node.childs:
+                if node.sname[-1] == ']':
+                    leaf_nodes.append(sn)
+            else:
+                for key in node.childs:
+                    traverse_tree(node.childs[key], sn)
+        traverse_tree(root, '')
+        return leaf_nodes
 
     def _build_graph(self):
         """