From 91645d80d5d3105c5b8392d554f8a51ce256ce7c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 24 Jun 2022 01:54:45 -0700 Subject: [PATCH] Allow nest parsing (#50) --- python/tvm/script/builder/__init__.py | 2 +- python/tvm/script/builder/frame.py | 8 --- python/tvm/script/builder/ir/__init__.py | 19 ++++++ python/tvm/script/builder/ir/_ffi_api.py | 20 ++++++ python/tvm/script/builder/ir/ir.py | 45 +++++++++++++ .../tvm/script/builder/tir/prim_func_frame.py | 22 +++++- python/tvm/script/builder/tir/stmt.py | 12 ++-- python/tvm/script/parse/__init__.py | 2 +- python/tvm/script/parse/entry.py | 2 +- python/tvm/script/parse/ir/__init__.py | 17 +++++ python/tvm/script/parse/ir/ir.py | 28 ++++++++ python/tvm/script/parse/parser.py | 40 +++++++---- src/script/builder/ir/ir.cc | 5 +- src/script/builder/ir/ir.h | 3 +- src/script/builder/tir/prim_func_frame.cc | 2 +- tests/python/tvmscript/test_parse_basic.py | 67 ++++++++++++++----- 16 files changed, 238 insertions(+), 56 deletions(-) create mode 100644 python/tvm/script/builder/ir/__init__.py create mode 100644 python/tvm/script/builder/ir/_ffi_api.py create mode 100644 python/tvm/script/builder/ir/ir.py create mode 100644 python/tvm/script/parse/ir/__init__.py create mode 100644 python/tvm/script/parse/ir/ir.py diff --git a/python/tvm/script/builder/__init__.py b/python/tvm/script/builder/__init__.py index 087b3955afe7..9c161a42855d 100644 --- a/python/tvm/script/builder/__init__.py +++ b/python/tvm/script/builder/__init__.py @@ -17,4 +17,4 @@ # pylint: disable=unused-import """Namespace for the TVMScript Builder API.""" from .builder import Builder, def_, def_many -from .frame import Frame, IRModuleFrame +from .frame import Frame diff --git a/python/tvm/script/builder/frame.py b/python/tvm/script/builder/frame.py index faf0c4271d7d..6493b8fa11af 100644 --- a/python/tvm/script/builder/frame.py +++ b/python/tvm/script/builder/frame.py @@ -29,11 +29,3 @@ def __enter__(self) -> "Frame": def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument _ffi_api.FrameExit(self) # pylint: disable=no-member # type: ignore - - -@_register_object("script.builder.IRModuleFrame") -class IRModuleFrame(Frame): - def __init__(self) -> None: - self.__init_handle_by_constructor__( - _ffi_api.IRModuleFrame # pylint: disable=no-member # type: ignore - ) diff --git a/python/tvm/script/builder/ir/__init__.py b/python/tvm/script/builder/ir/__init__.py new file mode 100644 index 000000000000..d81e4b2a8cf2 --- /dev/null +++ b/python/tvm/script/builder/ir/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVMScript IR""" + +from .ir import IRModuleFrame, ir_module diff --git a/python/tvm/script/builder/ir/_ffi_api.py b/python/tvm/script/builder/ir/_ffi_api.py new file mode 100644 index 000000000000..4a4cf40cf456 --- /dev/null +++ b/python/tvm/script/builder/ir/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.builder.ir""" +import tvm._ffi + +tvm._ffi._init_api("script.builder.ir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/builder/ir/ir.py b/python/tvm/script/builder/ir/ir.py new file mode 100644 index 000000000000..c6f1b8c3eb5f --- /dev/null +++ b/python/tvm/script/builder/ir/ir.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVMScript IR""" + +import inspect +from typing import Optional, Type, Union + +from tvm._ffi import register_object as _register_object +from tvm.ir import IRModule + +from ..frame import Frame +from . import _ffi_api + + +@_register_object("script.builder.ir.IRModuleFrame") +class IRModuleFrame(Frame): + ... + + +def ir_module(f: Optional[Type] = None) -> Union[IRModuleFrame, IRModule]: + if f is not None: + from tvm.script.parse import parse # pylint: disable=import-outside-toplevel + + if not inspect.isclass(f): + raise TypeError(f"Expect a class, but got: {f}") + + return parse(f) + return _ffi_api.IRModule() # pylint: disable=no-member # type: ignore + + +setattr(ir_module, "dispatch_token", "ir") diff --git a/python/tvm/script/builder/tir/prim_func_frame.py b/python/tvm/script/builder/tir/prim_func_frame.py index 53ac43eb1ab3..3d4209efe59c 100644 --- a/python/tvm/script/builder/tir/prim_func_frame.py +++ b/python/tvm/script/builder/tir/prim_func_frame.py @@ -15,11 +15,12 @@ # specific language governing permissions and limitations # under the License. """TVM Script TIR Prim Func Frame""" +import inspect from typing import Any, Callable, Dict, Optional, Union from tvm._ffi import register_object as _register_object from tvm.ir import Type -from tvm.tir.buffer import Buffer +from tvm.tir import Buffer, PrimFunc from tvm.tir.expr import Var from . import _ffi_api @@ -31,12 +32,27 @@ class PrimFuncFrame(TIRFrame): ... -def prim_func(f: Optional[Callable] = None) -> PrimFuncFrame: +def _is_defined_in_class(frames): + if len(frames) > 2: + maybe_class_frame = frames[2] + statement_list = maybe_class_frame[4] + first_statement = statement_list[0] + if first_statement.strip().startswith("class "): + return True + return False + + +def prim_func(f: Optional[Callable] = None) -> Union[PrimFuncFrame, PrimFunc, Callable]: if f is not None: from tvm.script.parse import parse # pylint: disable=import-outside-toplevel + if not inspect.isfunction(f): + raise TypeError(f"Expect a function, but got: {f}") + + if _is_defined_in_class(inspect.stack()): + return f return parse(f) - return _ffi_api.PrimFuncFrame() # pylint: disable=no-member # type: ignore + return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore setattr(prim_func, "dispatch_token", "tir") diff --git a/python/tvm/script/builder/tir/stmt.py b/python/tvm/script/builder/tir/stmt.py index 5d3ca58b7b9f..450eacb891a2 100644 --- a/python/tvm/script/builder/tir/stmt.py +++ b/python/tvm/script/builder/tir/stmt.py @@ -14,17 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TVM Script TIR For Frame""" -import numpy as np +"""TVMScript TIR statements""" from typing import List, Union +import numpy as np from tvm._ffi import register_object as _register_object -from tvm.tir import Buffer, IterVar, PrimExpr, Var, BufferRegion, Stmt, StringImm -from tvm.ir import Type, Range -from tvm.runtime import ndarray as nd, Object +from tvm.runtime import Object +from tvm.runtime import ndarray as nd +from tvm.tir import Buffer, BufferRegion, IterVar, PrimExpr, StringImm, Var -from . import _ffi_api from .. import _ffi_api as _base_ffi_api +from . import _ffi_api from .base import TIRFrame diff --git a/python/tvm/script/parse/__init__.py b/python/tvm/script/parse/__init__.py index 8844bbebe0d6..bae1e056558e 100644 --- a/python/tvm/script/parse/__init__.py +++ b/python/tvm/script/parse/__init__.py @@ -15,5 +15,5 @@ # specific language governing permissions and limitations # under the Licens. """The parser""" -from . import dispatch, doc, parser, tir +from . import dispatch, doc, parser, tir, ir from .entry import parse diff --git a/python/tvm/script/parse/entry.py b/python/tvm/script/parse/entry.py index a6a3e114ce85..1fefd02b07ea 100644 --- a/python/tvm/script/parse/entry.py +++ b/python/tvm/script/parse/entry.py @@ -16,7 +16,7 @@ # under the License. """The entry point of TVM parser.""" import inspect -from typing import Any, Dict, Optional, Union +from typing import Any, Union from ..builder import Builder from . import doc diff --git a/python/tvm/script/parse/ir/__init__.py b/python/tvm/script/parse/ir/__init__.py new file mode 100644 index 000000000000..d2e245918fd0 --- /dev/null +++ b/python/tvm/script/parse/ir/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from . import ir diff --git a/python/tvm/script/parse/ir/ir.py b/python/tvm/script/parse/ir/ir.py new file mode 100644 index 000000000000..44c473913ff5 --- /dev/null +++ b/python/tvm/script/parse/ir/ir.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from ...builder import Frame +from ...builder import ir as I +from .. import dispatch, doc +from ..parser import Parser + + +@dispatch.register(token="ir", type_name="ClassDef") +def visit_class_def(self: Parser, node: doc.ClassDef) -> None: + with self.var_table.with_frame(): + with I.ir_module(): + with self.with_dispatch_token("ir"): + self.visit_body(node.body) diff --git a/python/tvm/script/parse/parser.py b/python/tvm/script/parse/parser.py index e9939672d4bf..21c9ec21e157 100644 --- a/python/tvm/script/parse/parser.py +++ b/python/tvm/script/parse/parser.py @@ -32,20 +32,6 @@ def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod: return lambda self, node: self.generic_visit(node) -def _handle_function(self: "Parser", node: doc.FunctionDef) -> None: - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if hasattr(decorator, "dispatch_token"): - token = decorator.dispatch_token - func = dispatch.get(token=token, type_name="FunctionDef", default=None) - if func is not None: - func(self, node) - return - self.report_error(node, "The parser does not understand the decorator") - - class Parser(doc.NodeVisitor): """The TVMScript parser""" @@ -91,6 +77,9 @@ def report_error(self, node: doc.AST, msg: str) -> None: # pylint: disable=no-s def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name _handle_function(self, node) + def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name + _handle_class(self, node) + def visit_body(self, node: List[doc.stmt]) -> Any: for stmt in node: self.visit(stmt) @@ -106,3 +95,26 @@ def visit_With(self, node: doc.With) -> Any: # pylint: disable=invalid-name def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name _dispatch(self, "Assign")(self, node) + + +def _handle_function(self: Parser, node: doc.FunctionDef) -> None: + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if hasattr(decorator, "dispatch_token"): + token = decorator.dispatch_token + func = dispatch.get(token=token, type_name="FunctionDef", default=None) + if func is not None: + func(self, node) + return + self.report_error(node, "The parser does not understand the decorator") + + +def _handle_class(self: Parser, node: doc.ClassDef) -> None: + # TODO: assume IRModule + func = dispatch.get(token="ir", type_name="ClassDef", default=None) + if func is not None: + func(self, node) + return + self.report_error(node, "The parser does not understand the decorator") diff --git a/src/script/builder/ir/ir.cc b/src/script/builder/ir/ir.cc index 38d8eb1098bc..8f8abb2ef049 100644 --- a/src/script/builder/ir/ir.cc +++ b/src/script/builder/ir/ir.cc @@ -25,11 +25,11 @@ namespace script { namespace builder { namespace ir { -IRModuleFrame::IRModuleFrame() { +IRModuleFrame IRModule() { ObjectPtr n = make_object(); n->global_vars.clear(); n->functions.clear(); - data_ = std::move(n); + return IRModuleFrame(n); } void IRModuleFrameNode::ExitWithScope() { @@ -45,6 +45,7 @@ void IRModuleFrameNode::ExitWithScope() { } TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); +TVM_REGISTER_GLOBAL("script.builder.ir.IRModule").set_body_typed(IRModule); } // namespace ir } // namespace builder diff --git a/src/script/builder/ir/ir.h b/src/script/builder/ir/ir.h index 890a06e3c76a..0fb84a245499 100644 --- a/src/script/builder/ir/ir.h +++ b/src/script/builder/ir/ir.h @@ -46,11 +46,10 @@ class IRModuleFrameNode : public FrameNode { class IRModuleFrame : public Frame { public: - IRModuleFrame(); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, Frame, IRModuleFrameNode); }; -IRModuleFrame ir_module(); +IRModuleFrame IRModule(); } // namespace ir } // namespace builder diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index 595b037765c2..0cbdfe6da933 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -191,7 +191,7 @@ void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array s }; TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); -TVM_REGISTER_GLOBAL("script.builder.tir.PrimFuncFrame").set_body_typed(PrimFunc); +TVM_REGISTER_GLOBAL("script.builder.tir.PrimFunc").set_body_typed(PrimFunc); TVM_REGISTER_GLOBAL("script.builder.tir.Arg") .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { using namespace tvm::tir; diff --git a/tests/python/tvmscript/test_parse_basic.py b/tests/python/tvmscript/test_parse_basic.py index 8a9ecf4656dc..4809d7925ed4 100644 --- a/tests/python/tvmscript/test_parse_basic.py +++ b/tests/python/tvmscript/test_parse_basic.py @@ -1,27 +1,60 @@ +import inspect + +from tvm.script.builder import ir as I from tvm.script.builder import tir as T -# pylint: disable=unused-argument,unused-variable,invalid-name -@T.prim_func -def elementwise( - A: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore - B: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore -) -> None: - for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128): - with T.block("inner_block"): - # vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - vi = T.axis.S(128, i + 1) - vj = T.axis.S(128, j + 20) - vk = T.axis.R(128, k - i) +def test_parse_elementwise(): + # pylint: disable=unused-argument,unused-variable,invalid-name + @T.prim_func + def elementwise( + A: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore + B: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore + ) -> None: + for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128): + with T.block("inner_block"): + # vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + vi = T.axis.S(128, i + 1) + vj = T.axis.S(128, j + 20) + vk = T.axis.R(128, k - i) + + # pylint: enable=unused-argument,unused-variable,invalid-name + result = elementwise + # print(result.script()) -# pylint: enable=unused-argument,unused-variable,invalid-name +def test_parse_skip(): + class Skip: + @T.prim_func + def f(): # type: ignore + ... -def main(): - result = elementwise - print(result.script()) + assert inspect.isfunction(Skip.f) + + +def test_parse_class(): + # pylint: disable=unused-argument,unused-variable,invalid-name + @I.ir_module + class C: + @T.prim_func + def elementwise( + A: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore + B: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore + ) -> None: + for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128): + with T.block("inner_block"): + # vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + vi = T.axis.S(128, i + 1) + vj = T.axis.S(128, j + 20) + vk = T.axis.R(128, k - i) + + # pylint: enable=unused-argument,unused-variable,invalid-name + + # print(C.script()) if __name__ == "__main__": - main() + test_parse_elementwise() + test_parse_skip() + test_parse_class()