diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index d673e0eb139f..ad7ae5034780 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -163,9 +163,10 @@ def findsource(obj): name = tokens[1].split(":")[0].split("(")[0] + "" elif tokens[0] == "class": name = tokens[1].split(":")[0].split("(")[0] + # pop scope if we are less indented + while scope_stack and indent_info[scope_stack[-1]] >= indent: + scope_stack.pop() if name: - while scope_stack and indent_info[scope_stack[-1]] >= indent: - scope_stack.pop() scope_stack.append(name) indent_info[name] = indent if scope_stack == qual_names: diff --git a/tests/python/unittest/test_tvmscript_parser_source.py b/tests/python/unittest/test_tvmscript_parser_source.py index f5dc17fdfe56..359583c1aa06 100644 --- a/tests/python/unittest/test_tvmscript_parser_source.py +++ b/tests/python/unittest/test_tvmscript_parser_source.py @@ -82,5 +82,20 @@ def test_source_ast(): assert isinstance(for_block, doc.With) and len(for_block.body) == 2 +def test_nesting_parsing(): + class dummy: + pass + + for i in range(1): + + @tvm.script.ir_module + class Module: + @T.prim_func + def impl( + A: T.Buffer[(12, 196, 64), "float32"], + ) -> None: + T.evaluate(0) + + if __name__ == "__main__": tvm.testing.main()