Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UX] highlight tvm script #12197

Merged
merged 9 commits into from
Jul 29, 2022
1 change: 1 addition & 0 deletions python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
(
"Base requirements needed to install tvm",
[
"Pygments",
ganler marked this conversation as resolved.
Show resolved Hide resolved
"attrs",
"cloudpickle",
"decorator",
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,19 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str:
self, tir_prefix, show_meta
) # type: ignore

def show(self, style: str = "light") -> None:
"""
A sugar for print highlighted TVM script.
Parameters
----------
style : str, optional
Pygments styles extended by "light" (default) and "dark", by default "light"
"""
from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel

# Use deferred import to avoid circular import while keeping cprint under tvm/script
cprint(self, style=style)

def get_attr(self, attr_key):
"""Get the IRModule attribute.

Expand Down
86 changes: 86 additions & 0 deletions python/tvm/script/highlight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Highlight printed TVM script.
"""

from typing import Union

from pygments import highlight
from pygments.lexers import Python3Lexer
from pygments.formatters import Terminal256Formatter
from pygments.style import Style
from pygments.token import Keyword, Name, Comment, String, Number, Operator

from tvm.ir import IRModule
from tvm.tir import PrimFunc


class VSCDark(Style):
"""A VSCode-Dark-like Pygments style configuration"""

styles = {
Keyword: "bold #c586c0",
Keyword.Namespace: "#4ec9b0",
Keyword.Type: "#82aaff",
Name.Function: "bold #dcdcaa",
Name.Class: "bold #569cd6",
Name.Decorator: "italic #fe4ef3",
String: "#ce9178",
Number: "#b5cea8",
Operator: "#bbbbbb",
Operator.Word: "#569cd6",
Comment: "italic #6a9956",
}


class JupyterLight(Style):
"""A Jupyter-Notebook-like Pygments style configuration"""

styles = {
Keyword: "bold #008000",
Keyword.Type: "nobold #008000",
Name.Function: "#0000FF",
Name.Class: "bold #0000FF",
Name.Decorator: "#AA22FF",
String: "#BA2121",
Number: "#008000",
Operator: "bold #AA22FF",
Operator.Word: "bold #008000",
Comment: "italic #007979",
}


def cprint(printable: Union[IRModule, PrimFunc], style="light") -> None:
"""
Print highlighted TVM script string with Pygments
Parameters
----------
printable : Union[IRModule, PrimFunc]
The TVM script to be printed
style : str, optional
Style of the printed script
Notes
-----
The style parameter follows the Pygments style names or Style objects. Two
built-in styles are extended: "light" (default) and "dark". Other styles
can be found in https://pygments.org/styles/
"""
if style == "light":
style = JupyterLight
elif style == "dark":
style = VSCDark
print(highlight(printable.script(), Python3Lexer(), Terminal256Formatter(style=style)))
13 changes: 13 additions & 0 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,19 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str:
self, tir_prefix, show_meta
) # type: ignore

def show(self, style: str = "light") -> None:
"""
A sugar for print highlighted TVM script.
Parameters
----------
style : str, optional
Pygments styles extended by "light" (default) and "dark", by default "light"
"""
from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel

# Use deferred import to avoid circular import while keeping cprint under tvm/script
cprint(self, style=style)


@tvm._ffi.register_object("tir.TensorIntrin")
class TensorIntrin(Object):
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
import pytest

import tvm
from tvm.script import tir as T
from tvm.script.printer.doc_printer import to_python_script
from tvm.script.printer.doc import LiteralDoc

Expand Down Expand Up @@ -51,3 +53,27 @@ def format_script(s: str) -> str:
)
def test_print_literal_doc(doc, expected):
assert to_python_script(doc).rstrip("\n") == format_script(expected)


def test_highlight_script():
@tvm.script.ir_module
class Module:
@T.prim_func
def main( # type: ignore
a: T.handle,
b: T.handle,
c: T.handle,
) -> None: # pylint: disable=no-self-argument
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, [16, 128, 128])
B = T.match_buffer(b, [16, 128, 128])
C = T.match_buffer(c, [16, 128, 128])
for n, i, j, k in T.grid(16, 128, 128, 128):
with T.block("matmul"):
vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
with T.init():
C[vn, vi, vj] = 0.0 # type: ignore
C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]

Module.show()
Module["main"].show()