forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TIR] End-to-end tests for PrimFunc-to-PrimFunc subroutines
The functionality tested in this commit was added across several recent PRs, each of which tested their features in isolation. This PR adds unit tests to validate the end-to-end behavior of TIR subroutine calls. PRs building up to this point: - TVMScript - apache#14889 - apache#14915 - apache#14919 - apache#14941 - Functionality improvements of existing TIR passes - apache#14913 - apache#14914 - apache#14918 - apache#14951 - Changes to the TIR lowering flow - apache#14942 - apache#14985 - Codegen updates - apache#14958 - apache#14901 - Compatibility updates/fixes - apache#14892 - apache#14950 - apache#14943 - apache#14944 - apache#14945 - apache#14952 - apache#14982 - apache#14949
- Loading branch information
1 parent
465f2bb
commit 5c02044
Showing
1 changed file
with
275 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,275 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=missing-function-docstring,missing-module-docstring | ||
|
||
import pytest | ||
import numpy as np | ||
|
||
import tvm | ||
import tvm.testing | ||
|
||
from tvm.script import tir as T, ir as I | ||
|
||
|
||
@tvm.testing.parametrize_targets("llvm") | ||
def test_call_noop(target, dev): | ||
"""TIR functions on the CPU may call other functions | ||
The simplest test case, where the subroutine is a no-op. | ||
""" | ||
|
||
@I.ir_module | ||
class module: | ||
@T.prim_func | ||
def subroutine(): | ||
T.evaluate(0) | ||
|
||
@T.prim_func | ||
def main(A: T.Buffer(1, "float32")): | ||
T.func_attr({"global_symbol": "main"}) | ||
module.subroutine() | ||
A[0] = 42.0 | ||
|
||
built = tvm.build(module, target=target) | ||
|
||
arr = tvm.nd.empty([1], dtype="float32", device=dev) | ||
built(arr) | ||
|
||
assert arr.numpy()[0] == 42.0 | ||
|
||
|
||
@tvm.testing.parametrize_targets("llvm") | ||
def test_call_noop_defined_below(target, dev): | ||
"""Calling a subroutine does not depend on the definition order | ||
All GlobalVar instances are in-scope for subroutine calls. | ||
""" | ||
|
||
@I.ir_module | ||
class module: | ||
@T.prim_func | ||
def main(A: T.Buffer(1, "float32")): | ||
T.func_attr({"global_symbol": "main"}) | ||
module.subroutine() | ||
A[0] = 42.0 | ||
|
||
@T.prim_func | ||
def subroutine(): | ||
T.evaluate(0) | ||
|
||
built = tvm.build(module, target=target) | ||
|
||
arr = tvm.nd.empty([1], dtype="float32", device=dev) | ||
built(arr) | ||
|
||
assert arr.numpy()[0] == 42.0 | ||
|
||
|
||
@tvm.testing.parametrize_targets("llvm") | ||
def test_subroutine_call_with_pointer_param(target, dev): | ||
"""TIR functions on the CPU may call other functions | ||
Buffers may be exposed to subroutines through data pointers. | ||
""" | ||
|
||
@I.ir_module | ||
class module: | ||
@T.prim_func | ||
def main(A: T.Buffer(2, "float32")): | ||
T.func_attr({"global_symbol": "main"}) | ||
module.subroutine(A.data) | ||
module.subroutine(T.address_of(A[1])) | ||
|
||
@T.prim_func | ||
def subroutine(A_data: T.handle("float32")): | ||
A = T.decl_buffer(shape=[1], dtype="float32", data=A_data) | ||
A[0] = 42.0 | ||
|
||
built = tvm.build(module, target=target) | ||
|
||
arr = tvm.nd.empty([2], dtype="float32", device=dev) | ||
built(arr) | ||
|
||
assert arr.numpy()[0] == 42.0 | ||
assert arr.numpy()[1] == 42.0 | ||
|
||
|
||
@pytest.mark.xfail(reason="Depends on LLVM version") | ||
@tvm.testing.parametrize_targets("llvm") | ||
def test_failed_subroutine_call_for_incorrect_type(target, dev): | ||
"""Calls into a subroutine must have correct argument types | ||
This currently relies on the `llvm::verifyModule` function during | ||
codegen. In the future, this should be moved to a dedicated check | ||
of TIR validity. | ||
""" | ||
|
||
@I.ir_module | ||
class module: | ||
@T.prim_func | ||
def main(A: T.Buffer(1, "float32")): | ||
T.func_attr({"global_symbol": "main"}) | ||
module.subroutine(A.data) | ||
|
||
@T.prim_func | ||
def subroutine(A_data: T.handle("int32")): | ||
A = T.decl_buffer(shape=[1], dtype="int32", data=A_data) | ||
A[0] = -1 | ||
|
||
lowered = tvm.lower(module) | ||
with pytest.raises(tvm.TVMError): | ||
tvm.build(lowered) | ||
|
||
|
||
@tvm.testing.parametrize_targets("llvm") | ||
def test_subroutine_call_with_scalar_param(target, dev): | ||
"""Subroutines may also accept scalar parameters""" | ||
|
||
@I.ir_module | ||
class module: | ||
@T.prim_func | ||
def main(A: T.Buffer(1, "float32")): | ||
T.func_attr({"global_symbol": "main"}) | ||
module.subroutine(A.data, 42.0) | ||
|
||
@T.prim_func | ||
def subroutine(A_data: T.handle("float32"), val: T.float32): | ||
A = T.decl_buffer([1], "float32", data=A_data) | ||
A[0] = 2 * val | ||
|
||
built = tvm.build(module, target=target) | ||
|
||
arr = tvm.nd.empty([1], dtype="float32", device=dev) | ||
built(arr) | ||
|
||
assert arr.numpy()[0] == 84.0 | ||
|
||
|
||
@tvm.testing.parametrize_targets("llvm") | ||
def test_internal_subroutine_is_not_exposed_externally(target, dev): | ||
"""An internal subroutine may not be called externally | ||
An internal subroutine is any subroutine without a "global_symbol" | ||
attribute. These are not exposed in the runtime::Module and do | ||
not have an externally linkable symbol. | ||
""" | ||
|
||
@I.ir_module | ||
class module: | ||
@T.prim_func | ||
def main(A: T.Buffer(1, "float32")): | ||
T.func_attr({"global_symbol": "main"}) | ||
module.subroutine(A.data, 42.0) | ||
|
||
@T.prim_func | ||
def subroutine(A_data: T.handle("float32"), val: T.float32): | ||
A = T.decl_buffer([1], "float32", data=A_data) | ||
A[0] = 2 * val | ||
|
||
built = tvm.build(module, target=target) | ||
with pytest.raises(AttributeError): | ||
built["subroutine"] | ||
|
||
|
||
@tvm.testing.parametrize_targets("llvm") | ||
def test_call_to_externally_visible_subroutine(target, dev): | ||
"""Subroutines may be exposed externally. | ||
A subroutine may be exposed externally. Externally-exposed | ||
subroutines may be called by an external API, or may be called by | ||
other functions in the same IRModule. | ||
The current implementation lowers internal subroutine calls to | ||
`T.tvm_call_cpacked`. This avoids the overhead of the global | ||
registry lookup used by `T.tvm_call_packed`, but still requires | ||
the overhead of packing/unpacking the `PackedFunc` interface, and | ||
is limited to callers whose target supports the `PackedFunc` | ||
interface. | ||
""" | ||
|
||
@I.ir_module | ||
class module: | ||
@T.prim_func | ||
def main(A: T.Buffer(1, "float32")): | ||
T.func_attr({"global_symbol": "main"}) | ||
module.subroutine(A.data, 42.0) | ||
|
||
@T.prim_func | ||
def subroutine(A_data: T.handle("float32"), val: T.float32): | ||
T.func_attr({"global_symbol": "subroutine"}) | ||
A = T.Buffer([1], "float32", data=A_data) | ||
A[0] = 2 * val | ||
|
||
built = tvm.build(module, target=target) | ||
|
||
arr = tvm.nd.empty([1], dtype="float32", device=dev) | ||
built["main"](arr) | ||
assert arr.numpy()[0] == 84.0 | ||
|
||
arr = np.zeros(shape=[1], dtype="float32") | ||
built["subroutine"](arr.ctypes._data, 100.0) | ||
assert arr[0] == 200.0 | ||
|
||
|
||
is_external_subroutine = tvm.testing.parameter(by_dict={"external": True, "internal": False}) | ||
|
||
|
||
@tvm.testing.parametrize_targets("llvm", "cuda") | ||
def test_call_to_device_subroutine(target, dev, is_external_subroutine): | ||
"""Subroutines may be exposed externally. | ||
This feature is currently limited to host-side subroutine calls of | ||
externally-exposed subroutines. | ||
""" | ||
is_gpu = "gpu" in tvm.target.Target(target).keys | ||
|
||
if is_gpu and not is_external_subroutine: | ||
pytest.xfail(reason="Not yet implemented.") | ||
|
||
if is_external_subroutine: | ||
func_attr = {"global_symbol": "subroutine"} | ||
else: | ||
func_attr = {} | ||
|
||
@I.ir_module | ||
class module: | ||
@T.prim_func | ||
def main(A: T.Buffer(1, "float32")): | ||
T.func_attr({"global_symbol": "main"}) | ||
module.subroutine(A.data, 42.0) | ||
|
||
@T.prim_func | ||
def subroutine(A_data: T.handle("float32"), val: T.float32): | ||
T.func_attr(func_attr) | ||
A = T.Buffer([1], "float32", data=A_data) | ||
iterator = T.meta_var( | ||
T.thread_binding(0, 1, thread="threadIdx.x") if is_gpu else range(1) | ||
) | ||
for i in iterator: | ||
A[0] = 2 * val | ||
|
||
built = tvm.build(module, target=target) | ||
|
||
arr = tvm.nd.empty([1], dtype="float32", device=dev) | ||
built["main"](arr) | ||
assert arr.numpy()[0] == 84.0 | ||
|
||
|
||
if __name__ == "__main__": | ||
tvm.testing.main() |