Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[UX] Adopt changes from tvm-main and render code with IPython.display (
Browse files Browse the repository at this point in the history
…#192)

Render code with IPython.display.HTML if possible to fix the ansi-escape 24-bit rendering issue in Colab.
  • Loading branch information
ganler authored Aug 3, 2022
1 parent 660b437 commit dbc0a39
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 55 deletions.
1 change: 1 addition & 0 deletions docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set -o pipefail

# install libraries for python package on ubuntu
pip3 install --upgrade \
"Pygments>=2.4.0" \
attrs \
cloudpickle \
cython \
Expand Down
1 change: 0 additions & 1 deletion python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
(
"Base requirements needed to install tvm",
[
"Pygments",
"attrs",
"cloudpickle",
"decorator",
Expand Down
166 changes: 112 additions & 54 deletions python/tvm/script/highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,72 +17,130 @@
"""Highlight printed TVM script.
"""

from typing import Union

from pygments import highlight
from pygments.lexers import Python3Lexer
from pygments.formatters import Terminal256Formatter
from pygments.style import Style
from pygments.token import Keyword, Name, Comment, String, Number, Operator
from typing import Union, Optional
import warnings
import sys

from tvm.ir import IRModule
from tvm.tir import PrimFunc


class VSCDark(Style):
"""A VSCode-Dark-like Pygments style configuration"""

styles = {
Keyword: "bold #c586c0",
Keyword.Namespace: "#4ec9b0",
Keyword.Type: "#82aaff",
Name.Function: "bold #dcdcaa",
Name.Class: "bold #569cd6",
Name.Decorator: "italic #fe4ef3",
String: "#ce9178",
Number: "#b5cea8",
Operator: "#bbbbbb",
Operator.Word: "#569cd6",
Comment: "italic #6a9956",
}


class JupyterLight(Style):
"""A Jupyter-Notebook-like Pygments style configuration"""

styles = {
Keyword: "bold #008000",
Keyword.Type: "nobold #008000",
Name.Function: "#0000FF",
Name.Class: "bold #0000FF",
Name.Decorator: "#AA22FF",
String: "#BA2121",
Number: "#008000",
Operator: "bold #AA22FF",
Operator.Word: "bold #008000",
Comment: "italic #007979",
}


def cprint(printable: Union[IRModule, PrimFunc], style="light") -> None:
def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) -> None:
"""
Print highlighted TVM script string with Pygments
Parameters
----------
printable : Union[IRModule, PrimFunc]
The TVM script to be printed
style : str, optional
Style of the printed script
Printing style, auto-detected if None.
Notes
-----
The style parameter follows the Pygments style names or Style objects. Two
built-in styles are extended: "light" (default) and "dark". Other styles
can be found in https://pygments.org/styles/
The style parameter follows the Pygments style names or Style objects. Three
built-in styles are extended: "light", "dark" and "ansi". By default, "light"
will be used for notebook environment and terminal style will be "ansi" for
better style consistency. As an fallback when the optional Pygment library is
not installed, plain text will be printed with a one-time warning to suggest
installing the Pygment library. Other Pygment styles can be found in
https://pygments.org/styles/
"""
if style == "light":
style = JupyterLight
elif style == "dark":
style = VSCDark
print(highlight(printable.script(), Python3Lexer(), Terminal256Formatter(style=style)))

try:
# pylint: disable=import-outside-toplevel
import pygments
from pygments import highlight
from pygments.lexers.python import Python3Lexer
from pygments.formatters import Terminal256Formatter, HtmlFormatter
from pygments.style import Style
from pygments.token import Keyword, Name, Comment, String, Number, Operator
from packaging import version

if version.parse(pygments.__version__) < version.parse("2.4.0"):
raise ImportError("Required Pygments version >= 2.4.0 but got " + pygments.__version__)
except ImportError as err:
with warnings.catch_warnings():
warnings.simplefilter("once", UserWarning)
install_cmd = sys.executable + ' -m pip install "Pygments>=2.4.0" --upgrade --user'
warnings.warn(
str(err)
+ "\n"
+ "To print highlighted TVM script, please install Pygments:\n"
+ install_cmd,
category=UserWarning,
)
print(printable.script())
else:

class JupyterLight(Style):
"""A Jupyter-Notebook-like Pygments style configuration (aka. "light")"""

background_color = ""
styles = {
Keyword: "bold #008000",
Keyword.Type: "nobold #008000",
Name.Function: "#1E90FF",
Name.Class: "bold #1E90FF",
Name.Decorator: "#AA22FF",
String: "#BA2121",
Number: "#008000",
Operator: "bold #AA22FF",
Operator.Word: "bold #008000",
Comment: "italic #007979",
}

class VSCDark(Style):
"""A VSCode-Dark-like Pygments style configuration (aka. "dark")"""

background_color = ""
styles = {
Keyword: "bold #c586c0",
Keyword.Type: "#82aaff",
Keyword.Namespace: "#4ec9b0",
Name.Class: "bold #569cd6",
Name.Function: "bold #dcdcaa",
Name.Decorator: "italic #fe4ef3",
String: "#ce9178",
Number: "#b5cea8",
Operator: "#bbbbbb",
Operator.Word: "#569cd6",
Comment: "italic #6a9956",
}

class AnsiTerminalDefault(Style):
"""The default style for terminal display with ANSI colors (aka. "ansi")"""

background_color = ""
styles = {
Keyword: "bold ansigreen",
Keyword.Type: "nobold ansigreen",
Name.Class: "bold ansiblue",
Name.Function: "bold ansiblue",
Name.Decorator: "italic ansibrightmagenta",
String: "ansiyellow",
Number: "ansibrightgreen",
Operator: "bold ansimagenta",
Operator.Word: "bold ansigreen",
Comment: "italic ansibrightblack",
}

is_in_notebook = "ipykernel" in sys.modules # in notebook env (support html display).

if style is None:
# choose style automatically according to the environment:
style = JupyterLight if is_in_notebook else AnsiTerminalDefault
elif style == "light":
style = JupyterLight
elif style == "dark":
style = VSCDark
elif style == "ansi":
style = AnsiTerminalDefault

if is_in_notebook: # print with HTML display
from IPython.display import display, HTML # pylint: disable=import-outside-toplevel

formatter = HtmlFormatter(style=JupyterLight)
formatter.noclasses = True # inline styles
html = highlight(printable.script(), Python3Lexer(), formatter)
display(HTML(html))
else:
print(highlight(printable.script(), Python3Lexer(), Terminal256Formatter(style=style)))
55 changes: 55 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_highlight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest

import tvm
from tvm.script import tir as T, relax as R


def test_highlight_script():
@tvm.script.ir_module
class Module:
@T.prim_func
def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
T.func_attr({"global_symbol": "tir_matmul"})
k = T.var("int32")
A = T.match_buffer(x, (32, 32))
B = T.match_buffer(y, (32, 32))
C = T.match_buffer(z, (32, 32))

for (i0, j0, k0) in T.grid(32, 32, 32):
with T.block():
i, j, k = T.axis.remap("SSR", [i0, j0, k0])
with T.init():
C[i, j] = 0.0
C[i, j] += A[i, k] * B[j, k]

@R.function
def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor:
with R.dataflow():
lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32")
R.output(lv0)
return lv0

Module.show()
Module["main"].show()
Module["tir_matmul"].show()
Module["main"].show(style="light")
Module["main"].show(style="dark")
Module["main"].show(style="ansi")

0 comments on commit dbc0a39

Please sign in to comment.