Skip to content

Commit

Permalink
add more thrust scan test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 18, 2021
1 parent 1e61ba9 commit 7a0403d
Showing 1 changed file with 46 additions and 7 deletions.
53 changes: 46 additions & 7 deletions tests/python/contrib/test_thrust.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import tvm.testing
from tvm import te
from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available
from tvm.topi.cuda.scan import exclusive_scan, schedule_scan
from tvm.topi.cuda.scan import exclusive_scan, scan_thrust, schedule_scan
import numpy as np


Expand Down Expand Up @@ -54,31 +54,70 @@ def test_stable_sort_by_key():
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)


def test_scan():
def test_exclusive_scan():
    """Compare ``exclusive_scan`` (scan and reduction outputs) with NumPy.

    For each input shape the scan is built and scheduled under a CUDA
    target, run on GPU 0, and checked against a NumPy reference computed
    along the last axis. The test returns early, after printing a notice,
    when ``is_thrust_available()`` reports that thrust is not enabled.
    """
    if not is_thrust_available():
        print("skip because thrust is not enabled...")
        return

    for ishape in [(1,), (10, 10)]:
        values = te.placeholder(ishape, name="values", dtype="int32")

        with tvm.target.Target("cuda"):
            scan, reduction = exclusive_scan(values, return_reduction=True)
            s = schedule_scan([scan, reduction])

        ctx = tvm.gpu(0)
        f = tvm.build(s, [values, scan, reduction], "cuda")

        values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
        values_np_out = np.zeros(values_np.shape, np.int32)

        # The reference reduction below sums over the last axis, so the
        # output buffer has the input shape with that axis dropped:
        # () for a 1-D input, (ishape[0],) for a 2-D input.
        reduction_shape = ishape[:-1]
        reduction_np_out = np.zeros(reduction_shape, np.int32)

        values_in = tvm.nd.array(values_np, ctx)
        values_out = tvm.nd.array(values_np_out, ctx)
        reduction_out = tvm.nd.array(reduction_np_out, ctx)
        f(values_in, values_out, reduction_out)

        # Exclusive scan == inclusive cumulative sum minus the element itself.
        ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np
        tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
        ref_reduction_out = np.sum(values_np, axis=-1)
        tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5)


def test_inclusive_scan():
    """Compare ``scan_thrust(..., exclusive=False)`` with ``np.cumsum``.

    The input is int32 while the requested output dtype is int64, so the
    output buffer and the NumPy reference are both created with
    ``out_dtype``. The test returns early, after printing a notice, when
    ``is_thrust_available()`` reports that thrust is not enabled.
    """
    if not is_thrust_available():
        print("skip because thrust is not enabled...")
        return

    out_dtype = "int64"

    for ishape in [(10,), (10, 10)]:
        values = te.placeholder(ishape, name="values", dtype="int32")

        with tvm.target.Target("cuda"):
            scan = scan_thrust(values, out_dtype, exclusive=False)
            # A plain schedule over the scan op is used here rather than
            # schedule_scan.
            s = tvm.te.create_schedule([scan.op])

        ctx = tvm.gpu(0)
        f = tvm.build(s, [values, scan], "cuda")

        values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
        values_np_out = np.zeros(values_np.shape, out_dtype)
        values_in = tvm.nd.array(values_np, ctx)
        values_out = tvm.nd.array(values_np_out, ctx)
        f(values_in, values_out)

        # Inclusive scan along the last axis, accumulated in out_dtype.
        ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype)
        tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)


# Run every test in this file when it is executed directly as a script.
if __name__ == "__main__":
    test_stable_sort_by_key()
    test_exclusive_scan()
    test_inclusive_scan()

0 comments on commit 7a0403d

Please sign in to comment.