From a0d0e35fa5af5ae0af5258fb8b571ac6a7d90ed8 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 19 Aug 2022 12:39:19 +0000 Subject: [PATCH 1/6] doc for parser B5 --- include/tvm/script/ir_builder/tir/frame.h | 62 +++++++++++++++ include/tvm/script/ir_builder/tir/ir.h | 34 ++++++++ python/tvm/script/ir_builder/tir/ir.py | 96 +++++++++++++++++++++++ 3 files changed, 192 insertions(+) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index d2d2485bbe..c35ff6229a 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -171,9 +171,17 @@ class PrimFuncFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); }; +/*! + * \brief A frame that represents the assert statement. Proceeds if the condition is true, + * otherwise aborts with the message. + * + * \sa AssertFrame + */ class AssertFrameNode : public TIRFrameNode { public: + /*! \brief The PrimExpr to test. */ PrimExpr condition; + /*! \brief The output error message when the assertion failed. */ PrimExpr message; void VisitAttrs(tvm::AttrVisitor* v) { @@ -186,17 +194,33 @@ class AssertFrameNode : public TIRFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Reference type of AssertFrameNode. + * + * \sa AssertFrameNode + */ class AssertFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode); }; +/*! + * \brief A frame that represents the let binding expression, which binds a var. + * + * \sa LetFrameNode + */ class LetFrameNode : public TIRFrameNode { public: + /*! \brief The variable we bind to */ tvm::tir::Var var; + /*! \brief The value we bind var to */ PrimExpr value; void VisitAttrs(tvm::AttrVisitor* v) { @@ -209,9 +233,18 @@ class LetFrameNode : public TIRFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Reference type of LetFrameNode. + * + * \sa LetFrameNode + */ class LetFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode); @@ -278,10 +311,17 @@ class AllocateConstFrame : public TIRFrame { AllocateConstFrameNode); }; +/*! + * \brief The LaunchThreadFrameNode. + * \note It can only be used inside a PrimFunc. + */ class LaunchThreadFrameNode : public TIRFrameNode { public: + /*! \brief The extent of environment thread. */ PrimExpr extent; + /*! \brief The attribute key, could be either virtual_thread or thread_extent. */ String attr_key; + /*! \brief The iteration variable. */ tvm::tir::IterVar iter_var; void VisitAttrs(tvm::AttrVisitor* v) { @@ -295,19 +335,36 @@ class LaunchThreadFrameNode : public TIRFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Reference type of LaunchThreadFrameNode. + * + * \sa LaunchThreadFrameNode + */ class LaunchThreadFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame, LaunchThreadFrameNode); }; +/*! + * \brief A frame that represents realization. + * + * \sa RealizeFrame + */ class RealizeFrameNode : public TIRFrameNode { public: + /*! \brief The region of buffer access. */ tvm::tir::BufferRegion buffer_slice; + /*! \brief The storage scope associated with this realization. */ String storage_scope; + /*! \brief The condition expression. */ PrimExpr condition; void VisitAttrs(tvm::AttrVisitor* v) { @@ -324,6 +381,11 @@ class RealizeFrameNode : public TIRFrameNode { void ExitWithScope() final; }; +/*! + * \brief Reference type of RealizeFrameNode. + * + * \sa RealizeFrameNode + */ class RealizeFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode); diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index c26d552737..991347fe1e 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -87,7 +87,19 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, String storage_scope = "global", int align = -1, int offset_factor = 0, String buffer_type = "default", Array axis_separators = {}); +/*! + * \brief The assertion statement. + * \param condition The assertion condition. + * \param message The error message when the assertion fails. + * \return The AssertFrame. + */ AssertFrame Assert(PrimExpr condition, String message); +/*! + * \brief The let binding. + * \param var The variable or name of variable. + * \param value The value to be bound. + * \return The created LetFrame. + */ LetFrame Let(Var var, PrimExpr value); AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope = "", Optional condition = NullOpt, @@ -95,16 +107,38 @@ AllocateFrame Allocate(Array extents, DataType dtype, String storage_s AllocateConstFrame AllocateConst( NDArray data, DataType dtype, Array extents, Map annotations = NullValue>()); +/*! + * \brief The realization. + * \param buffer_slice The region of buffer access. + * \param storage_scope The storage scope associated with this realization. + * \param condition The condition expression. + * \return The result RealizeFrame. + */ RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition); AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value); WhileFrame While(PrimExpr condition); IfFrame If(PrimExpr condition); ThenFrame Then(); ElseFrame Else(); +/*! + * \brief Launch a thread. + * \param var The iteration variable. + * \param extent The extent of environment thread. + * \return The result LaunchThreadFrame. + */ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent); +/*! + * \brief Bind a var to thread env. + * \param thread_tag The thread type tag. + * \return The result variable which gets bound to the thread env. + */ Var EnvThread(String thread_tag); void BufferStore(Buffer buffer, PrimExpr value, Array indices); void Prefetch(Buffer buffer, Array bounds); +/*! + * \brief Evaluate the input expression. + * \param value The input expression to be evaluated. + */ void Evaluate(PrimExpr value); #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index ebd764cf1d..055439ef82 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -334,6 +334,21 @@ def preflattened_buffer( def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name + """Create an assertion statement. + + Parameters + ---------- + condition : PrimExpr + The PrimExpr to test. + + message : str + The output error message when the assertion fails. + + Returns + ------- + res : frame.AssertFrame + The result AssertFrame. + """ return _ffi_api.Assert(condition, message) # pylint: disable=no-member # type: ignore @@ -342,6 +357,24 @@ def let( value: PrimExpr, body: PrimExpr = None, ) -> frame.LetFrame: + """Create a new let binding. + + Parameters + ---------- + v : Var + The variable or name of variable. + + value : PrimExpr + The value to be bound. + + body : PrimExpr + The body expression, None will be used if it was not specified. + + Returns + ------- + res : frame.LetFrame + The result LetFrame. + """ if body is None: return _ffi_api.Let(v, value) # pylint: disable=no-member # type: ignore return Let(v, value, body) @@ -378,6 +411,24 @@ def realize( storage_scope: str, condition: PrimExpr = True, ) -> frame.RealizeFrame: + """Create a realization. + + Parameters + ---------- + buffer_slice : BufferRegion + The region of buffer access. + + storage_scope : str + The storage scope associated with this realization. + + condition: PrimExpr + The condition expression, the default is True. + + Returns + ------- + res : frame.RealizeFrame + The result RealizeFrame. + """ return _ffi_api.Realize( # pylint: disable=no-member # type: ignore buffer_slice, storage_scope, condition ) @@ -442,10 +493,48 @@ def launch_thread( iter_var: IterVar, # pylint: disable=redefined-outer-name extent: PrimExpr, ) -> frame.LaunchThreadFrame: + """Launch a thread. + + Parameters + ---------- + iter_var : IterVar + The iteration variable. + + extent : PrimExpr + The extent of environment thread. + + Returns + ------- + res : frame.LaunchThreadFrame + The result LaunchThreadFrame. + + Examples + -------- + + .. code-block:: python + + from tvm.script.ir_builder import tir as T + brow = T.env_thread("blockIdx.y") + T.launch_thread(brow, 1) + + """ return _ffi_api.LaunchThread(iter_var, extent) # pylint: disable=no-member # type: ignore def env_thread(thread_tag: str) -> IterVar: + """Bind a var to thread env" + + Parameters + ---------- + thread_tag : str + The thread type tag. + + Returns + ------- + res : IterVar + The result iteration variable which gets bound to the thread env. + + """ return _ffi_api.EnvThread(thread_tag) # pylint: disable=no-member # type: ignore @@ -475,6 +564,13 @@ def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None: def evaluate(value: PrimExpr) -> None: + """Evaluate the input expression. + + Parameters + ---------- + value: PrimExpr + The input expression to be evaluated. + """ if isinstance(value, str): value = StringImm(value) return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore From 23b5edc94cef8420b342724ab5b1f149b00795d1 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sat, 20 Aug 2022 09:28:55 +0800 Subject: [PATCH 2/6] Update include/tvm/script/ir_builder/tir/frame.h Co-authored-by: Junru Shao --- include/tvm/script/ir_builder/tir/frame.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index c35ff6229a..bddcd0c9ac 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -202,7 +202,7 @@ class AssertFrameNode : public TIRFrameNode { }; /*! - * \brief Reference type of AssertFrameNode. + * \brief Managed reference to AssertFrameNode. * * \sa AssertFrameNode */ From 42db3ccd619ae453f19a5c966fc5ca576307ac33 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sat, 20 Aug 2022 01:32:53 +0000 Subject: [PATCH 3/6] add tests --- include/tvm/script/ir_builder/tir/frame.h | 10 +- include/tvm/script/ir_builder/tir/ir.h | 4 +- python/tvm/script/ir_builder/tir/ir.py | 6 +- .../unittest/test_tvmscript_ir_builder_tir.py | 96 +++++++++++++++++++ 4 files changed, 106 insertions(+), 10 deletions(-) create mode 100644 tests/python/unittest/test_tvmscript_ir_builder_tir.py diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index bddcd0c9ac..c1994c7d0b 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -212,7 +212,7 @@ class AssertFrame : public TIRFrame { }; /*! - * \brief A frame that represents the let binding expression, which binds a var. + * \brief A frame represents the let binding expression, which binds a var. * * \sa LetFrameNode */ @@ -241,7 +241,7 @@ class LetFrameNode : public TIRFrameNode { }; /*! - * \brief Reference type of LetFrameNode. + * \brief Managed reference to LetFrameNode. * * \sa LetFrameNode */ @@ -313,7 +313,7 @@ class AllocateConstFrame : public TIRFrame { /*! * \brief The LaunchThreadFrameNode. - * \note It can only be used inside a PrimFunc. + * \note It is used only inside a PrimFunc. */ class LaunchThreadFrameNode : public TIRFrameNode { public: @@ -343,7 +343,7 @@ class LaunchThreadFrameNode : public TIRFrameNode { }; /*! - * \brief Reference type of LaunchThreadFrameNode. + * \brief Managed reference to LaunchThreadFrameNode. * * \sa LaunchThreadFrameNode */ @@ -382,7 +382,7 @@ class RealizeFrameNode : public TIRFrameNode { }; /*! - * \brief Reference type of RealizeFrameNode. + * \brief Managed reference to RealizeFrameNode. * * \sa RealizeFrameNode */ diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 991347fe1e..c88ccbfd1b 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -96,7 +96,7 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, AssertFrame Assert(PrimExpr condition, String message); /*! * \brief The let binding. - * \param var The variable or name of variable. + * \param var The variable to bind. * \param value The value to be bound. * \return The created LetFrame. */ @@ -137,7 +137,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices); void Prefetch(Buffer buffer, Array bounds); /*! * \brief Evaluate the input expression. - * \param value The input expression to be evaluated. + * \param value The input expression to evaluate. */ void Evaluate(PrimExpr value); diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 055439ef82..2e33c05895 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -362,7 +362,7 @@ def let( Parameters ---------- v : Var - The variable or name of variable. + The variable to bind. value : PrimExpr The value to be bound. @@ -532,7 +532,7 @@ def env_thread(thread_tag: str) -> IterVar: Returns ------- res : IterVar - The result iteration variable which gets bound to the thread env. + The result iteration variable gets bound to the thread env. """ return _ffi_api.EnvThread(thread_tag) # pylint: disable=no-member # type: ignore @@ -569,7 +569,7 @@ def evaluate(value: PrimExpr) -> None: Parameters ---------- value: PrimExpr - The input expression to be evaluated. + The input expression to evaluate. """ if isinstance(value, str): value = StringImm(value) diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py new file mode 100644 index 0000000000..b08bcb7d18 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unittests for tvm.script.ir_builder.tir""" +import pytest +import tvm +from tvm import tir +from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import IRBuilder +from tvm.ir.base import assert_structural_equal + + +def test_ir_builder_tir_assert(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with T.Assert(T.var("int32", name="a") == 0, message="a is 0"): + T.evaluate(0) + # the assert generated by IRBuilder + assert_actual = ib.get() + + # the expected assert statement + assert_expected = tir.AssertStmt(T.var("int32", name="a") == 0, + tir.StringImm("a is 0"), + tir.Evaluate(0)) + # Check if the generated ir is expected + assert_structural_equal(assert_actual, assert_expected, map_free_vars=True) + + +def test_ir_builder_tir_evaluate(): + with IRBuilder() as ib: # pylint: disable=invalid-name + T.evaluate(0) + # the evaluate generated by IRBuilder + eval_actual = ib.get() + + # the expected evaluate + eval_expected = tir.Evaluate(0) + # Check if the generated ir is expected + assert_structural_equal(eval_actual, eval_expected, map_free_vars=True) + + +def test_ir_builder_tir_let(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with T.let(T.var("int32", name="a"), tir.IntImm("int32", 2)): + T.evaluate(0) + # the let binding generated by IRBuilder + let_actual = ib.get() + + # the expected Let statement + let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0)) + assert_structural_equal(let_actual, let_expected, map_free_vars=True) + + +def test_ir_builder_tir_realize(): + buffer_a = T.buffer_decl((128, 128), "float32") + with IRBuilder() as ib: # pylint: disable=invalid-name + with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True): + T.evaluate(0) + realize_actual = ib.get() + + # the expected buffer realization + buffer_realize = tir.BufferRealize(buffer_a, + [tvm.ir.Range(0, 128),tvm.ir.Range(0, 128)], + True, tir.Evaluate(0)) + expected_realize = tir.AttrStmt(buffer_a, "realize_scope", + tir.StringImm("test_storage_scope"), + buffer_realize) + assert_structural_equal(realize_actual, expected_realize, map_free_vars=True) + + +def test_ir_builder_tir_thread(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with T.prim_func(): + brow = T.env_thread("blockIdx.y") + with T.launch_thread(brow, 1): + T.evaluate(0) + ir_actual = ib.get() + iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y") + attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0)) + func = tir.PrimFunc([], attr_stmt) + assert_structural_equal(ir_actual, func, map_free_vars=True) + + +if __name__ == "__main__": + pytest.main([__file__]) From fd487b7bea478c2bd9efa6d78f027ad6962b599d Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sat, 20 Aug 2022 01:50:22 +0000 Subject: [PATCH 4/6] fix lint --- .../unittest/test_tvmscript_ir_builder_tir.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index b08bcb7d18..3d0aa57b0d 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -31,9 +31,9 @@ def test_ir_builder_tir_assert(): assert_actual = ib.get() # the expected assert statement - assert_expected = tir.AssertStmt(T.var("int32", name="a") == 0, - tir.StringImm("a is 0"), - tir.Evaluate(0)) + assert_expected = tir.AssertStmt( + T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0) + ) # Check if the generated ir is expected assert_structural_equal(assert_actual, assert_expected, map_free_vars=True) @@ -70,12 +70,12 @@ def test_ir_builder_tir_realize(): realize_actual = ib.get() # the expected buffer realization - buffer_realize = tir.BufferRealize(buffer_a, - [tvm.ir.Range(0, 128),tvm.ir.Range(0, 128)], - True, tir.Evaluate(0)) - expected_realize = tir.AttrStmt(buffer_a, "realize_scope", - tir.StringImm("test_storage_scope"), - buffer_realize) + buffer_realize = tir.BufferRealize( + buffer_a, [tvm.ir.Range(0, 128), tvm.ir.Range(0, 128)], True, tir.Evaluate(0) + ) + expected_realize = tir.AttrStmt( + buffer_a, "realize_scope", tir.StringImm("test_storage_scope"), buffer_realize + ) assert_structural_equal(realize_actual, expected_realize, map_free_vars=True) @@ -86,7 +86,7 @@ def test_ir_builder_tir_thread(): with T.launch_thread(brow, 1): T.evaluate(0) ir_actual = ib.get() - iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y") + iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y") attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0)) func = tir.PrimFunc([], attr_stmt) assert_structural_equal(ir_actual, func, map_free_vars=True) From 4946891b518d1c5ff9f17fc122502e8c8489ecd5 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sat, 20 Aug 2022 12:29:19 +0000 Subject: [PATCH 5/6] doc and test for B6 --- include/tvm/script/ir_builder/tir/frame.h | 128 ++++++++++++++++++ include/tvm/script/ir_builder/tir/ir.h | 53 ++++++++ python/tvm/script/ir_builder/tir/ir.py | 114 ++++++++++++++++ .../unittest/test_tvmscript_ir_builder_tir.py | 70 ++++++++++ 4 files changed, 365 insertions(+) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index c1994c7d0b..413101a480 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -250,13 +250,24 @@ class LetFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode); }; +/*! + * \brief A frame represents the allocate. + * + * \sa AllocateFrame + */ class AllocateFrameNode : public TIRFrameNode { public: + /*! \brief The extents of the allocate. */ Array extents; + /*! \brief The data type of the buffer. */ DataType dtype; + /*! \brief The storage scope. */ String storage_scope; + /*! \brief The condition. */ PrimExpr condition; + /*! \brief Additional annotation hints. */ Map annotations; + /*! \brief The buffer. */ tvm::tir::Buffer buffer; void VisitAttrs(tvm::AttrVisitor* v) { @@ -273,20 +284,39 @@ class AllocateFrameNode : public TIRFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Managed reference to AllocateFrameNode. + * + * \sa AllocateFrameNode + */ class AllocateFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode); }; +/*! + * \brief A frame represents the allocate constant. + * + * \sa AllocateConstFrame + */ class AllocateConstFrameNode : public TIRFrameNode { public: + /*! \brief The data type of the buffer. */ DataType dtype; + /*! \brief The extents of the allocate. */ Array extents; + /*! \brief The data associated with the constant. */ tvm::runtime::NDArray data; + /*! \brief The buffer */ tvm::tir::Buffer buffer; + /*! \brief Additional annotations about the allocation. */ Map annotations; void VisitAttrs(tvm::AttrVisitor* v) { @@ -302,9 +332,18 @@ class AllocateConstFrameNode : public TIRFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Managed reference to AllocateConstFrameNode. + * + * \sa AllocateConstFrameNode + */ class AllocateConstFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame, @@ -378,6 +417,10 @@ class RealizeFrameNode : public TIRFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; @@ -391,10 +434,18 @@ class RealizeFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode); }; +/*! + * \brief A frame that represents attribute node. + * + * \sa AttrFrame + */ class AttrFrameNode : public TIRFrameNode { public: + /*! \brief The node to annotate the attribute. */ ObjectRef node; + /*! \brief Attribute type key. */ String attr_key; + /*! \brief The value of the attribute. */ PrimExpr value; void VisitAttrs(tvm::AttrVisitor* v) { @@ -408,16 +459,31 @@ class AttrFrameNode : public TIRFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Managed reference to AttrFrameNode. + * + * \sa AttrFrameNode + */ class AttrFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode); }; +/*! + * \brief A frame that represents while loop. + * + * \sa WhileFrame + */ class WhileFrameNode : public TIRFrameNode { public: + /*! \brief The termination condition of while. */ PrimExpr condition; void VisitAttrs(tvm::AttrVisitor* v) { @@ -429,18 +495,35 @@ class WhileFrameNode : public TIRFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Managed reference to WhileFrameNode. + * + * \sa WhileFrameNode + */ class WhileFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode); }; +/*! + * \brief A frame that represents if statement. + * + * \sa IfFrame + */ class IfFrameNode : public TIRFrameNode { public: + /*! \brief The condition of the if statement. */ PrimExpr condition; + /*! \brief The statements in the true branch. */ Optional> then_stmts; + /*! \brief The stetements in the false branch. */ Optional> else_stmts; void VisitAttrs(tvm::AttrVisitor* v) { @@ -454,39 +537,84 @@ class IfFrameNode : public TIRFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Managed reference to IfFrameNode. + * + * \sa IfFrameNode + */ class IfFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode); }; +/*! + * \brief A frame that represents then. + * + * \sa ThenFrame + */ class ThenFrameNode : public TIRFrameNode { public: static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Managed reference to ThenFrameNode. + * + * \sa ThenFrameNode + */ class ThenFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode); }; +/*! + * \brief A frame that represents else. + * + * \sa ElseFrame + */ class ElseFrameNode : public TIRFrameNode { public: static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode); public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ void ExitWithScope() final; }; +/*! + * \brief Managed reference to ElseFrameNode. + * + * \sa ElseFrameNode + */ class ElseFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode); diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index c88ccbfd1b..3460f9d403 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -101,9 +101,26 @@ AssertFrame Assert(PrimExpr condition, String message); * \return The created LetFrame. */ LetFrame Let(Var var, PrimExpr value); +/*! + * \brief The allocate node. + * \param extents The extents of the allocate. + * \param dtype The data type of the buffer. + * \param storage_scope The storage scope. + * \param condition The condition. + * \param annotations Additional annotation hints. + * \return The created AllocateFrame. + */ AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope = "", Optional condition = NullOpt, Optional> annotations = NullOpt); +/*! + * \brief The allocate constant node. + * \param data The data associated with the constant. + * \param dtype The data type of the buffer. + * \param extents The extents of the allocate. + * \param annotations Additional annotation hints. + * \return The created AllocateConstFrame. + */ AllocateConstFrame AllocateConst( NDArray data, DataType dtype, Array extents, Map annotations = NullValue>()); @@ -115,10 +132,35 @@ AllocateConstFrame AllocateConst( * \return The result RealizeFrame. */ RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition); +/*! + * \brief Create an attribute. + * \param node The node to annotate the attribute. + * \param attr_key Attribute type key. + * \param value The value of the attribute. + * \return The result AttrFrame. + */ AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value); +/*! + * \brief Create a while loop. + * \param condition The termination condition of the loop. + * \return The result WhileFrame. + */ WhileFrame While(PrimExpr condition); +/*! + * \brief Create an if statement. + * \param condition The condition of if statement. + * \return The result IfFrame. + */ IfFrame If(PrimExpr condition); +/*! + * \brief Create a then. + * \return The result ThenFrame. + */ ThenFrame Then(); +/*! + * \brief Create an else. + * \return The result ElseFrame. + */ ElseFrame Else(); /*! * \brief Launch a thread. @@ -133,7 +175,18 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent); * \return The result variable which gets bound to the thread env. */ Var EnvThread(String thread_tag); +/*! + * \brief Store data in a buffer. + * \param buffer The buffer. + * \param value The value to be stored. + * \param indices The indices location to be stored. + */ void BufferStore(Buffer buffer, PrimExpr value, Array indices); +/*! + * \brief The prefetch hint for a buffer + * \param buffer The buffer to be prefetched. + * \param bounds The bounds to be prefetched. + */ void Prefetch(Buffer buffer, Array bounds); /*! * \brief Evaluate the input expression. diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 2e33c05895..485c92aa75 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -387,6 +387,25 @@ def allocate( condition: PrimExpr = None, annotations=None, ) -> frame.AllocateFrame: + """Allocate node. + + Parameters + ---------- + extents : List[PrimExpr] + The extents of the allocate. + + dtype : str + The data type of the buffer. + + scope : str + The storage scope. + + condition : PrimExpr + The condition. + + annotations: Optional[Mapping[str, Object]] + Additional annotation hints. + """ if isinstance(condition, bool): condition = IntImm("bool", condition) return _ffi_api.Allocate( # pylint: disable=no-member # type: ignore @@ -400,6 +419,22 @@ def allocate_const( extents: List[PrimExpr], annotations=None, ) -> frame.AllocateConstFrame: + """Allocate constant node. + + Parameters + ---------- + data : List[PrimExpr] + The data associated with the constant. + + dtype : str + The data type of the buffer. + + extents : List[PrimExpr] + The extents of the allocate. + + annotations : Optional[Map] + Additional annotations about the allocation. + """ return _ffi_api.AllocateConst( # pylint: disable=no-member # type: ignore ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations @@ -435,28 +470,85 @@ def realize( def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame: + """Create an attribute node. + + Parameters + ---------- + node : Any + The node to annotate the attribute. + + attr_key : str + Attribute type key. + + value : Union[PrimExpr, str] + The value of the attribute. + + Returns + ------- + res : frame.AttrFrame + The result AttrFrame. + """ node = convert(node) value = convert(value) return _ffi_api.Attr(node, attr_key, value) # pylint: disable=no-member # type: ignore def While(condition: PrimExpr) -> frame.WhileFrame: # pylint: disable=invalid-name + """Create a while node. + + Parameters + ---------- + condition : PrimExpr + The termination condition of the loop. + + Returns + ------- + res : frame.WhileFrame + The result WhileFrame. + """ if isinstance(condition, bool): condition = IntImm("bool", condition) return _ffi_api.While(condition) # pylint: disable=no-member # type: ignore def If(condition: PrimExpr) -> frame.IfFrame: # pylint: disable=invalid-name + """Create an if node. + + Parameters + ---------- + condition : PrimExpr + The condition of if statement, executes the true branch if the condition is true, + otherwise jump into the false branch. + + Returns + ------- + res : frame.IfFrame + The result IfFrame. + """ if isinstance(condition, bool): condition = IntImm("bool", condition) return _ffi_api.If(condition) # pylint: disable=no-member # type: ignore def Then() -> frame.ThenFrame: # pylint: disable=invalid-name + """Create a then. + + Returns + ------- + res : frame.ThenFrame + The result ThenFrame. + """ return _ffi_api.Then() # pylint: disable=no-member # type: ignore def Else() -> frame.ElseFrame: # pylint: disable=invalid-name + """Create an else. + + Returns + ------- + res : frame.ElseFrame + The result ElseFrame. + """ return _ffi_api.Else() # pylint: disable=no-member # type: ignore @@ -539,6 +631,19 @@ def env_thread(thread_tag: str) -> IterVar: def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None: + """Buffer store node. + + Parameters + ---------- + buffer : Buffer + The buffer. + + value : PrimExpr + The value to be stored. + + indices : List[Union[PrimExpr, slice]] + The indices location to be stored. + """ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel expr_indices = [] @@ -560,6 +665,15 @@ def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None: + """The prefetch hint for a buffer. + + Parameters + ---------- + buffer : Buffer + The buffer to be prefetched. + indices : List[PrimExpr] + The indices of the buffer to extract. + """ return _ffi_api.Prefetch(buffer, indices) # pylint: disable=no-member # type: ignore diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 3d0aa57b0d..0f7f782bcd 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -16,8 +16,10 @@ # under the License. """Unittests for tvm.script.ir_builder.tir""" import pytest +import numpy as np import tvm from tvm import tir +from tvm.runtime import ndarray from tvm.script.ir_builder import tir as T from tvm.script.ir_builder import IRBuilder from tvm.ir.base import assert_structural_equal @@ -92,5 +94,73 @@ def test_ir_builder_tir_thread(): assert_structural_equal(ir_actual, func, map_free_vars=True) +def test_ir_builder_tir_allocate(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with T.allocate([10], "float32", scope="local"): + T.evaluate(1) + ir_actual = ib.get() + buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local")) + ir_expected = tir.Allocate( + buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1) + ) + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_allocate_const(): + data = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + with IRBuilder() as ib: # pylint: disable=invalid-name + with T.allocate_const(data, "int32", [10]): + T.evaluate(1) + ir_actual = ib.get() + buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("int32"))) + ir_expected = tir.AllocateConst( + buffer_var, "int32", [10], ndarray.array(np.asarray(data, "int32")), tir.Evaluate(1) + ) + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_prefetch(): + with IRBuilder() as ib: # pylint: disable=invalid-name + buffer_a = T.buffer_decl((128, 128), "float32") + T.prefetch(buffer_a, []) + ir_actual = ib.get() + ir_expected = tir.Prefetch(buffer_a, []) + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_buffer_store(): + buffer_a = T.buffer_decl((10, 10), "float32") + i = T.var("int32", "x") + with IRBuilder() as ib: # pylint: disable=invalid-name + T.buffer_store(buffer_a, 0.1, [0, i]) + ir_actual = ib.get() + ir_expected = tir.BufferStore(buffer_a, 0.1, [0, i]) + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_while(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with T.While(T.var("int32", "x") > 0): + T.evaluate(0) + ir_actual = ib.get() + ir_expected = tir.While(tir.Var("x", "int32") > 0, tir.Evaluate(0)) + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_if(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with T.If(T.var("int32", "c") < 12): + with T.Then(): + T.int32(0) + with T.Else(): + T.int32(1) + ir_actual = ib.get() + ir_expected = tir.if_then_else( + tir.Var("c", "int32") < 12, tir.IntImm("int32", 0), tir.IntImm("int32", 1) + ) + # comment this assertion because tir does not have if/then/else + # assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + if __name__ == "__main__": pytest.main([__file__]) From 9bdb7818c222845c9190eb75026a7fa01ce633e7 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sun, 21 Aug 2022 23:19:25 +0000 Subject: [PATCH 6/6] fix comment, use IfThenElse instead of if_then_else --- .../unittest/test_tvmscript_ir_builder_tir.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 0f7f782bcd..7d60905ae4 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """Unittests for tvm.script.ir_builder.tir""" import pytest import numpy as np @@ -26,7 +27,7 @@ def test_ir_builder_tir_assert(): - with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib: with T.Assert(T.var("int32", name="a") == 0, message="a is 0"): T.evaluate(0) # the assert generated by IRBuilder @@ -41,7 +42,7 @@ def test_ir_builder_tir_assert(): def test_ir_builder_tir_evaluate(): - with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib: T.evaluate(0) # the evaluate generated by IRBuilder eval_actual = ib.get() @@ -53,7 +54,7 @@ def test_ir_builder_tir_evaluate(): def test_ir_builder_tir_let(): - with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib: with T.let(T.var("int32", name="a"), tir.IntImm("int32", 2)): T.evaluate(0) # the let binding generated by IRBuilder @@ -66,7 +67,7 @@ def test_ir_builder_tir_let(): def test_ir_builder_tir_realize(): buffer_a = T.buffer_decl((128, 128), "float32") - with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib: with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True): T.evaluate(0) realize_actual = ib.get() @@ -82,7 +83,7 @@ def test_ir_builder_tir_realize(): def test_ir_builder_tir_thread(): - with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib: with T.prim_func(): brow = T.env_thread("blockIdx.y") with T.launch_thread(brow, 1): @@ -95,7 +96,7 @@ def test_ir_builder_tir_thread(): def test_ir_builder_tir_allocate(): - with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib: with T.allocate([10], "float32", scope="local"): T.evaluate(1) ir_actual = ib.get() @@ -108,7 +109,7 @@ def test_ir_builder_tir_allocate(): def test_ir_builder_tir_allocate_const(): data = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib: with T.allocate_const(data, "int32", [10]): T.evaluate(1) ir_actual = ib.get() @@ -120,7 +121,7 @@ def test_ir_builder_tir_allocate_const(): def test_ir_builder_tir_prefetch(): - with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib: buffer_a = T.buffer_decl((128, 128), "float32") T.prefetch(buffer_a, []) ir_actual = ib.get() @@ -131,7 +132,7 @@ def test_ir_builder_tir_prefetch(): def test_ir_builder_tir_buffer_store(): buffer_a = T.buffer_decl((10, 10), "float32") i = T.var("int32", "x") - with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib: T.buffer_store(buffer_a, 0.1, [0, i]) ir_actual = ib.get() ir_expected = tir.BufferStore(buffer_a, 0.1, [0, i]) @@ -147,19 +148,20 @@ def test_ir_builder_tir_while(): assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) -def test_ir_builder_tir_if(): - with IRBuilder() as ib: # pylint: disable=invalid-name +def test_ir_builder_tir_if_then_else(): + with IRBuilder() as ib: with T.If(T.var("int32", "c") < 12): with T.Then(): - T.int32(0) + T.evaluate(T.int32(0)) with T.Else(): - T.int32(1) + T.evaluate(T.int32(1)) ir_actual = ib.get() - ir_expected = tir.if_then_else( - tir.Var("c", "int32") < 12, tir.IntImm("int32", 0), tir.IntImm("int32", 1) + ir_expected = tir.IfThenElse( + tir.Var("c", "int32") < 12, + tir.Evaluate(tir.IntImm("int32", 0)), + tir.Evaluate(tir.IntImm("int32", 1)), ) - # comment this assertion because tir does not have if/then/else - # assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) if __name__ == "__main__":