Skip to content

Commit

Permalink
[BugFix][TVMScript] Use operator is when recognizing TIR Module (ap…
Browse files Browse the repository at this point in the history
…ache#10175)

* [BugFix][TVMScript] Use operator `is` when recognizing TIR module

* Test
  • Loading branch information
MasterJH5574 authored and ylc committed Feb 16, 2022
1 parent 6355818 commit de8a0b7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ def from_source(
elif inspect.isfunction(input_func):
_, start_line = inspect.getsourcelines(input_func)
env: Dict[str, Any] = input_func.__globals__
namespace = [key for key in env.keys() if env[key] == tir]
namespace = [key for key in env.keys() if env[key] is tir]
parser = TVMScriptParser(start_line, namespace)
result = to_ast(input_func, TVMDiagnosticCtx(), parser)
return result
Expand Down
49 changes: 49 additions & 0 deletions tests/python/unittest/test_tvmscript_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy

import tvm
from tvm.script import tir as T


# This numpy array is used to test the comparison between the global objects and the
# `tvm.script.tir` submodule.
np_array = numpy.array([0, 1, 2, 3])


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])

for i, j, k in T.grid(128, 128, 128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


def test_multi_element_array_in_outmost_namespace():
func = matmul
rt_func = tvm.script.from_source(func.script(show_meta=True))
tvm.ir.assert_structural_equal(func, rt_func)


if __name__ == "__main__":
test_multi_element_array_in_outmost_namespace()

0 comments on commit de8a0b7

Please sign in to comment.