diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index a9b4d9e471b2c..f85faca13f7ae 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -29,7 +29,7 @@ def check_trace(spaces: List[Schedule], expected: List[List[str]]): for space in spaces: trace = Trace(space.trace.insts, {}) trace = trace.simplified(remove_postproc=True) - str_trace = "\n".join(str(trace).strip().splitlines()) + str_trace = "\n".join(t[2:] for t in str(trace).strip().splitlines()[2:] if t != " pass") actual_traces.add(str_trace) assert str_trace in expected_traces, "\n" + str_trace assert len(expected_traces) == len(actual_traces) diff --git a/python/tvm/script/highlight.py b/python/tvm/script/highlight.py index 5a9c69a0ff20f..dc45b5a3f1cd4 100644 --- a/python/tvm/script/highlight.py +++ b/python/tvm/script/highlight.py @@ -17,20 +17,20 @@ """Highlight printed TVM script. """ -from typing import Union, Optional -import warnings import sys +import warnings +from typing import Optional, Union from tvm.ir import IRModule from tvm.tir import PrimFunc -def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) -> None: +def cprint(printable: Union[IRModule, PrimFunc, str], style: Optional[str] = None) -> None: """ Print highlighted TVM script string with Pygments Parameters ---------- - printable : Union[IRModule, PrimFunc] + printable : Union[IRModule, PrimFunc, str] The TVM script to be printed style : str, optional Printing style, auto-detected if None. @@ -44,16 +44,17 @@ def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) -> installing the Pygment library. Other Pygment styles can be found in https://pygments.org/styles/ """ - + if isinstance(printable, (IRModule, PrimFunc)): + printable = printable.script() try: # pylint: disable=import-outside-toplevel import pygments + from packaging import version from pygments import highlight + from pygments.formatters import HtmlFormatter, Terminal256Formatter from pygments.lexers.python import Python3Lexer - from pygments.formatters import Terminal256Formatter, HtmlFormatter from pygments.style import Style - from pygments.token import Keyword, Name, Comment, String, Number, Operator - from packaging import version + from pygments.token import Comment, Keyword, Name, Number, Operator, String if version.parse(pygments.__version__) < version.parse("2.4.0"): raise ImportError("Required Pygments version >= 2.4.0 but got " + pygments.__version__) @@ -68,7 +69,7 @@ def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) -> + install_cmd, category=UserWarning, ) - print(printable.script()) + print(printable) else: class JupyterLight(Style): @@ -136,11 +137,14 @@ class AnsiTerminalDefault(Style): style = AnsiTerminalDefault if is_in_notebook: # print with HTML display - from IPython.display import display, HTML # pylint: disable=import-outside-toplevel + from IPython.display import ( # pylint: disable=import-outside-toplevel + HTML, + display, + ) formatter = HtmlFormatter(style=JupyterLight) formatter.noclasses = True # inline styles - html = highlight(printable.script(), Python3Lexer(), formatter) + html = highlight(printable, Python3Lexer(), formatter) display(HTML(html)) else: - print(highlight(printable.script(), Python3Lexer(), Terminal256Formatter(style=style))) + print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=style))) diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index 18bcca373dbb6..da599081df3bd 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -258,3 +258,17 @@ def apply_json_to_schedule(json_obj: JSON_TYPE, sch: "Schedule") -> None: The TensorIR schedule """ _ffi_api.TraceApplyJSONToSchedule(json_obj, sch) # type: ignore # pylint: disable=no-member + + def show(self, style: Optional[str] = None) -> None: + """A sugar for print highlighted trace. + + Parameters + ---------- + style : str, optional + Pygments styles extended by "light" (default) and "dark", by default "light" + """ + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) + + cprint(str(self), style=style) diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index a3648a7174c0a..e3f3585595268 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -476,6 +476,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { const auto* self = obj.as(); ICHECK_NOTNULL(self); + p->stream << "# from tvm import tir" << std::endl; + p->stream << "def apply_trace(sch: tir.Schedule) -> None:" << std::endl; Array repr = self->AsPython(/*remove_postproc=*/false); bool is_first = true; for (const String& line : repr) { @@ -484,7 +486,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } else { p->stream << std::endl; } - p->stream << line; + p->stream << " " << line; + } + if (is_first) { + p->stream << " pass"; } }); diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 97a49602fb264..b40ba2869d1c5 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -322,18 +322,22 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: def correct_trace(a, b, c, d): return "\n".join( [ - 'b0 = sch.get_block(name="A", func_name="main")', - 'b1 = sch.get_block(name="B", func_name="main")', - 'b2 = sch.get_block(name="C", func_name="main")', - "sch.compute_inline(block=b1)", - "l3, l4 = sch.get_loops(block=b2)", - "l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)", - "l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)", - "sch.reorder(l5, l7, l6, l8)", - "l9, l10 = sch.get_loops(block=b0)", - "l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)", - "l13, l14 = sch.split(loop=l10, factors=" + str(d) + ", preserve_unit_iters=True)", - "sch.reorder(l11, l13, l12, l14)", + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + ' b0 = sch.get_block(name="A", func_name="main")', + ' b1 = sch.get_block(name="B", func_name="main")', + ' b2 = sch.get_block(name="C", func_name="main")', + " sch.compute_inline(block=b1)", + " l3, l4 = sch.get_loops(block=b2)", + " l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)", + " l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)", + " sch.reorder(l5, l7, l6, l8)", + " l9, l10 = sch.get_loops(block=b0)", + " l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)", + " l13, l14 = sch.split(loop=l10, factors=" + + str(d) + + ", preserve_unit_iters=True)", + " sch.reorder(l11, l13, l12, l14)", ] ) diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index 684fe22faccf0..8a5155bcba431 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -163,8 +163,10 @@ def test_trace_construct_1(): trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) assert str(trace) == "\n".join( ( - 'b0 = sch.get_block(name="block", func_name="main")', - "l1, l2 = sch.get_loops(block=b0)", + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + ' b0 = sch.get_block(name="block", func_name="main")', + " l1, l2 = sch.get_loops(block=b0)", ) ) assert len(trace.insts) == 2 @@ -182,9 +184,11 @@ def test_trace_construct_append_1(): trace.append(inst=_make_get_block("block2", BlockRV())) assert str(trace) == "\n".join( ( - 'b0 = sch.get_block(name="block", func_name="main")', - "l1, l2 = sch.get_loops(block=b0)", - 'b3 = sch.get_block(name="block2", func_name="main")', + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + ' b0 = sch.get_block(name="block", func_name="main")', + " l1, l2 = sch.get_loops(block=b0)", + ' b3 = sch.get_block(name="block2", func_name="main")', ) ) @@ -193,14 +197,32 @@ def test_trace_construct_pop_1(): trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) last_inst = trace.insts[-1] assert trace.pop().same_as(last_inst) - assert str(trace) == 'b0 = sch.get_block(name="block", func_name="main")' + assert str(trace) == "\n".join( + ( + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + ' b0 = sch.get_block(name="block", func_name="main")', + ) + ) def test_trace_construct_pop_2(): trace = Trace([], {}) - assert str(trace) == "" + assert str(trace) == "\n".join( + ( + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + " pass", + ) + ) assert trace.pop() is None - assert str(trace) == "" + assert str(trace) == "\n".join( + ( + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + " pass", + ) + ) def test_trace_apply_to_schedule(): @@ -226,18 +248,22 @@ def test_trace_simplified_1(): trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) assert str(trace) == "\n".join( ( - 'b0 = sch.get_block(name="B", func_name="main")', - "sch.compute_inline(block=b0)", - 'b1 = sch.get_block(name="C", func_name="main")', - "sch.enter_postproc()", - "sch.compute_inline(block=b1)", + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + ' b0 = sch.get_block(name="B", func_name="main")', + " sch.compute_inline(block=b0)", + ' b1 = sch.get_block(name="C", func_name="main")', + " sch.enter_postproc()", + " sch.compute_inline(block=b1)", ) ) trace = trace.simplified(remove_postproc=True) assert str(trace) == "\n".join( ( - 'b0 = sch.get_block(name="B", func_name="main")', - "sch.compute_inline(block=b0)", + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + ' b0 = sch.get_block(name="B", func_name="main")', + " sch.compute_inline(block=b0)", ) ) @@ -246,21 +272,26 @@ def test_trace_simplified_2(): trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) assert str(trace) == "\n".join( ( - 'b0 = sch.get_block(name="B", func_name="main")', - "sch.compute_inline(block=b0)", - 'b1 = sch.get_block(name="C", func_name="main")', - "sch.enter_postproc()", - "sch.compute_inline(block=b1)", + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + ' b0 = sch.get_block(name="B", func_name="main")', + " sch.compute_inline(block=b0)", + ' b1 = sch.get_block(name="C", func_name="main")', + " sch.enter_postproc()", + " sch.compute_inline(block=b1)", ) ) trace = trace.simplified(remove_postproc=False) + print(trace.show()) assert str(trace) == "\n".join( ( - 'b0 = sch.get_block(name="B", func_name="main")', - "sch.compute_inline(block=b0)", - 'b1 = sch.get_block(name="C", func_name="main")', - "sch.enter_postproc()", - "sch.compute_inline(block=b1)", + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + ' b0 = sch.get_block(name="B", func_name="main")', + " sch.compute_inline(block=b0)", + ' b1 = sch.get_block(name="C", func_name="main")', + " sch.enter_postproc()", + " sch.compute_inline(block=b1)", ) ) @@ -269,9 +300,11 @@ def test_trace_simplified_3(): trace = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV()).simplified(remove_postproc=False) assert str(trace) == "\n".join( ( - 'b0 = sch.get_block(name="B", func_name="main")', - "l1, = sch.get_loops(block=b0)", - "l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)", + "# from tvm import tir", + "def apply_trace(sch: tir.Schedule) -> None:", + ' b0 = sch.get_block(name="B", func_name="main")', + " l1, = sch.get_loops(block=b0)", + " l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)", ) ) @@ -335,4 +368,5 @@ def test_apply_annotation_from_json(): if __name__ == "__main__": - tvm.testing.main() + test_trace_simplified_2() + # tvm.testing.main()