Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
skip if vulkan is not enabled
Browse files Browse the repository at this point in the history
masahi committed Mar 9, 2021
1 parent dc3b3eb commit 98d618b
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/python/unittest/test_target_codegen_spirv.py
Original file line number Diff line number Diff line change
@@ -73,10 +73,13 @@ def do_copy(A, B, n):


def test_pushconstants():
if not tvm.testing.device_enabled("vulkan"):
return

def check_mod(mod, x_np, res_np):
tgt = "vulkan"
ctx = tvm.context("vulkan", 0)
ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=tgt)
target = "vulkan"
ctx = tvm.context(target, 0)
ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
res = ex.evaluate()(x_np).asnumpy()
tvm.testing.assert_allclose(res, res_np, atol=1e-5)

@@ -92,7 +95,7 @@ def check_mod(mod, x_np, res_np):

# One 64 bit and one 32 bit constants
dtype = "int32"
x = relay.var("x", shape=(5,), dtype=dtype)
x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.cumsum(x))
x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)

0 comments on commit 98d618b

Please sign in to comment.