Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Aug 20, 2022
1 parent 23b5edc commit 42db3cc
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 10 deletions.
10 changes: 5 additions & 5 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class AssertFrame : public TIRFrame {
};

/*!
* \brief A frame that represents the let binding expression, which binds a var.
* \brief A frame represents the let binding expression, which binds a var.
*
* \sa LetFrameNode
*/
Expand Down Expand Up @@ -241,7 +241,7 @@ class LetFrameNode : public TIRFrameNode {
};

/*!
* \brief Reference type of LetFrameNode.
* \brief Managed reference to LetFrameNode.
*
* \sa LetFrameNode
*/
Expand Down Expand Up @@ -313,7 +313,7 @@ class AllocateConstFrame : public TIRFrame {

/*!
* \brief The LaunchThreadFrameNode.
* \note It can only be used inside a PrimFunc.
* \note It is used only inside a PrimFunc.
*/
class LaunchThreadFrameNode : public TIRFrameNode {
public:
Expand Down Expand Up @@ -343,7 +343,7 @@ class LaunchThreadFrameNode : public TIRFrameNode {
};

/*!
* \brief Reference type of LaunchThreadFrameNode.
* \brief Managed reference to LaunchThreadFrameNode.
*
* \sa LaunchThreadFrameNode
*/
Expand Down Expand Up @@ -382,7 +382,7 @@ class RealizeFrameNode : public TIRFrameNode {
};

/*!
* \brief Reference type of RealizeFrameNode.
* \brief Managed reference to RealizeFrameNode.
*
* \sa RealizeFrameNode
*/
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
AssertFrame Assert(PrimExpr condition, String message);
/*!
* \brief The let binding.
* \param var The variable or name of variable.
* \param var The variable to bind.
* \param value The value to be bound.
* \return The created LetFrame.
*/
Expand Down Expand Up @@ -137,7 +137,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
void Prefetch(Buffer buffer, Array<Range> bounds);
/*!
* \brief Evaluate the input expression.
* \param value The input expression to be evaluated.
* \param value The input expression to evaluate.
*/
void Evaluate(PrimExpr value);

Expand Down
6 changes: 3 additions & 3 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def let(
Parameters
----------
v : Var
The variable or name of variable.
The variable to bind.
value : PrimExpr
The value to be bound.
Expand Down Expand Up @@ -532,7 +532,7 @@ def env_thread(thread_tag: str) -> IterVar:
Returns
-------
res : IterVar
The result iteration variable which gets bound to the thread env.
The result iteration variable gets bound to the thread env.
"""
return _ffi_api.EnvThread(thread_tag) # pylint: disable=no-member # type: ignore
Expand Down Expand Up @@ -569,7 +569,7 @@ def evaluate(value: PrimExpr) -> None:
Parameters
----------
value: PrimExpr
The input expression to be evaluated.
The input expression to evaluate.
"""
if isinstance(value, str):
value = StringImm(value)
Expand Down
96 changes: 96 additions & 0 deletions tests/python/unittest/test_tvmscript_ir_builder_tir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unittests for tvm.script.ir_builder.tir"""
import pytest
import tvm
from tvm import tir
from tvm.script.ir_builder import tir as T
from tvm.script.ir_builder import IRBuilder
from tvm.ir.base import assert_structural_equal


def test_ir_builder_tir_assert():
with IRBuilder() as ib: # pylint: disable=invalid-name
with T.Assert(T.var("int32", name="a") == 0, message="a is 0"):
T.evaluate(0)
# the assert generated by IRBuilder
assert_actual = ib.get()

# the expected assert statement
assert_expected = tir.AssertStmt(T.var("int32", name="a") == 0,
tir.StringImm("a is 0"),
tir.Evaluate(0))
# Check if the generated ir is expected
assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)


def test_ir_builder_tir_evaluate():
with IRBuilder() as ib: # pylint: disable=invalid-name
T.evaluate(0)
# the evaluate generated by IRBuilder
eval_actual = ib.get()

# the expected evaluate
eval_expected = tir.Evaluate(0)
# Check if the generated ir is expected
assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)


def test_ir_builder_tir_let():
with IRBuilder() as ib: # pylint: disable=invalid-name
with T.let(T.var("int32", name="a"), tir.IntImm("int32", 2)):
T.evaluate(0)
# the let binding generated by IRBuilder
let_actual = ib.get()

# the expected Let statement
let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0))
assert_structural_equal(let_actual, let_expected, map_free_vars=True)


def test_ir_builder_tir_realize():
buffer_a = T.buffer_decl((128, 128), "float32")
with IRBuilder() as ib: # pylint: disable=invalid-name
with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True):
T.evaluate(0)
realize_actual = ib.get()

# the expected buffer realization
buffer_realize = tir.BufferRealize(buffer_a,
[tvm.ir.Range(0, 128),tvm.ir.Range(0, 128)],
True, tir.Evaluate(0))
expected_realize = tir.AttrStmt(buffer_a, "realize_scope",
tir.StringImm("test_storage_scope"),
buffer_realize)
assert_structural_equal(realize_actual, expected_realize, map_free_vars=True)


def test_ir_builder_tir_thread():
with IRBuilder() as ib: # pylint: disable=invalid-name
with T.prim_func():
brow = T.env_thread("blockIdx.y")
with T.launch_thread(brow, 1):
T.evaluate(0)
ir_actual = ib.get()
iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y")
attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0))
func = tir.PrimFunc([], attr_stmt)
assert_structural_equal(ir_actual, func, map_free_vars=True)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 42db3cc

Please sign in to comment.