Skip to content

Commit

Permalink
[Parser] Core Parser (apache#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Jun 8, 2022
1 parent 232a51d commit 4557682
Show file tree
Hide file tree
Showing 14 changed files with 589 additions and 2 deletions.
2 changes: 0 additions & 2 deletions python/tvm/script/builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,5 @@
# under the License.
# pylint: disable=unused-import
"""Namespace for the TVMScript Builder API."""


from .builder import Builder, def_, def_many
from .frame import Frame, IRModuleFrame
4 changes: 4 additions & 0 deletions python/tvm/script/builder/tir/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ def reduce(dom, binding, dtype="int32") -> IterVar:

def remap(kinds, bindings, dtype="int32") -> IterVar:
return _ffi_api.AxisRemap(kinds, bindings, dtype) # pylint: disable=no-member # type: ignore


S = spatial
R = reduce
4 changes: 4 additions & 0 deletions python/tvm/script/builder/tir/prim_func_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var

from ..builder import Builder
from . import _ffi_api
from .base import TIRFrame

Expand All @@ -36,3 +37,6 @@ def prim_func(name) -> PrimFuncFrame:

def arg(name, obj) -> Union[Var, Buffer]:
return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore


setattr(prim_func, "dispatch_token", "tir")
19 changes: 19 additions & 0 deletions python/tvm/script/parse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the Licens.
"""The parser"""
from . import dispatch, parser, tir
from .entry import parse
67 changes: 67 additions & 0 deletions python/tvm/script/parse/dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The dispatcher"""

import ast
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple

if TYPE_CHECKING:
from .parser import Parser


ParseMethod = Callable[
["Parser", ast.AST],
None,
]


class DispatchTable:
"""Dispatch table for parse methods"""

_instance: Optional["DispatchTable"] = None
table: Dict[Tuple[str, str], ParseMethod]

def __init__(self):
self.table = {}


DispatchTable._instance = DispatchTable() # pylint: disable=protected-access


def register(
token: str,
type_name: str,
):
"""Register a method for a dispatch token and type name"""

def f(method: ParseMethod):
DispatchTable._instance.table[ # pylint: disable=protected-access
(token, type_name)
] = method

return f


def get(
token: str,
type_name: str,
default: Optional[ParseMethod] = None,
) -> Optional[ParseMethod]:
return DispatchTable._instance.table.get( # pylint: disable=protected-access
(token, type_name),
default,
)
84 changes: 84 additions & 0 deletions python/tvm/script/parse/entry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The entry point of TVM parser."""
import ast
import inspect
from typing import Any, Dict, Optional, Union

from ..builder import Builder
from .parser import Parser


class SourceCode:
source_name: str
start_line: int
start_column: int
source: str
full_source: str

def __init__(self, program: Union[str, ast.AST]):
if isinstance(program, str):
self.source_name = "<str>"
self.start_line = 1
self.start_column = 0
self.source = program
self.full_source = program
else:
self.source_name = inspect.getsourcefile(program) # type: ignore
lines, self.start_line = inspect.getsourcelines(program) # type: ignore

if lines:
self.start_column = len(lines[0]) - len(lines[0].lstrip())
else:
self.start_column = 0
if self.start_column and lines:
self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
else:
self.source = ""
try:
# It will cause a problem when running in Jupyter Notebook.
# `mod` will be <module '__main__'>, which is a built-in module
# and `getsource` will throw a TypeError
mod = inspect.getmodule(program)
if mod:
self.full_source = inspect.getsource(mod)
else:
self.full_source = self.source
except TypeError:
# It's a work around for Jupyter problem.
# Since `findsource` is an internal API of inspect, we just use it
# as a fallback method.
src, _ = inspect.findsource(program) # type: ignore
self.full_source = "".join(src)

def as_ast(self) -> ast.AST:
return ast.parse(self.source)


def parse(
program: Union[ast.AST, Any, str],
extra_vars: Optional[Dict[str, Any]] = None,
):
program_ast = SourceCode(program).as_ast()
parser = Parser()
with Builder() as builder:
with parser.var_table.with_frame():
if extra_vars:
for k, v in extra_vars.items():
parser.var_table.add(k, v)
parser.visit(program_ast)
return builder.get()
61 changes: 61 additions & 0 deletions python/tvm/script/parse/evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""AST Evaluation"""
import ast
from typing import Any, Dict, Optional, Union


def eval_expr(
node: Union[ast.expr, ast.Expression],
dict_globals: Optional[Dict[str, Any]],
) -> Any:
if isinstance(node, ast.expr):
node = ast.Expression(body=node)
assert isinstance(node, ast.Expression)
if dict_globals is None:
dict_globals = {}
node = ast.fix_missing_locations(node)
exe = compile(node, filename="<ast>", mode="eval")
return eval(exe, dict_globals) # pylint: disable=eval-used


def eval_assign(
target: ast.expr,
source: Any,
) -> Dict[str, Any]:
assert isinstance(target, ast.expr)
RHS_VAR_NAME = "__tvm_rhs_var__" # pylint: disable=invalid-name
rhs_var_name = RHS_VAR_NAME
dict_locals = {rhs_var_name: source}
mod = ast.fix_missing_locations(
ast.Module(
body=[
ast.Assign(
targets=[target],
value=ast.Name(
id=rhs_var_name,
ctx=ast.Load(),
),
)
],
type_ignores=[],
)
)
exe = compile(mod, filename="<ast>", mode="exec")
exec(exe, {}, dict_locals) # pylint: disable=exec-used
del dict_locals[rhs_var_name]
return dict_locals
109 changes: 109 additions & 0 deletions python/tvm/script/parse/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The core parser"""
import ast
from typing import Any, Dict, List, Optional, Union

from ..builder import def_
from . import dispatch
from .evaluator import eval_assign, eval_expr
from .utils import deferred
from .var_table import VarTable


def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
for token in [self.dispatch_tokens[-1], "default"]:
func = dispatch.get(token=token, type_name=type_name, default=None)
if func is not None:
return func
return lambda self, node: self.generic_visit(node)


def _handle_function(self: "Parser", node: ast.FunctionDef) -> None:
if not node.decorator_list:
self.report_error(node, "Function must be decorated")
# TODO: only the last decorator is parsed
decorator = self.eval_expr(node.decorator_list[-1])
if hasattr(decorator, "dispatch_token"):
token = decorator.dispatch_token
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
if func is not None:
func(self, node)
return
self.report_error(node, "The parser does not understand the decorator")


class Parser(ast.NodeVisitor):
"""The TVMScript parser"""

dispatch_tokens: List[str]
var_table: VarTable

def __init__(self) -> None:
self.dispatch_tokens = ["default"]
self.var_table = VarTable()

def with_dispatch_token(self, token: str):
def pop_token():
self.dispatch_tokens.pop()

self.dispatch_tokens.append(token)
return deferred(pop_token)

def eval_expr(
self,
node: Union[ast.Expression, ast.expr],
extra_vars: Optional[Dict[str, Any]] = None,
) -> Any:
var_values = self.var_table.get()
if extra_vars is not None:
for k, v in extra_vars.items():
var_values[k] = v
return eval_expr(node, var_values)

def eval_assign(
self,
target: ast.expr,
source: Any,
) -> Dict[str, Any]:
var_values = eval_assign(target, source)
for k, v in var_values.items():
def_(k, v)
self.var_table.add(k, v)
return var_values

def report_error(self, node: ast.AST, msg: str) -> None: # pylint: disable=no-self-use
raise SyntaxError(f"At {node.lineno}:{node.col_offset}: {msg}")

def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # pylint: disable=invalid-name
_handle_function(self, node)

def visit_body(self, node: List[ast.stmt]) -> Any:
for stmt in node:
self.visit(stmt)

def visit_arguments(self, node: ast.arguments) -> Any:
_dispatch(self, "arguments")(self, node)

def visit_For(self, node: ast.For) -> Any: # pylint: disable=invalid-name
_dispatch(self, "For")(self, node)

def visit_With(self, node: ast.With) -> Any: # pylint: disable=invalid-name
_dispatch(self, "With")(self, node)

def visit_Assign(self, node: ast.Assign) -> Any: # pylint: disable=invalid-name
_dispatch(self, "Assign")(self, node)
17 changes: 17 additions & 0 deletions python/tvm/script/parse/tir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from . import tir
Loading

0 comments on commit 4557682

Please sign in to comment.