Skip to content

Commit

Permalink
Modify more tests (apache#10)
Browse files Browse the repository at this point in the history
* Modify tests for bitserial_conv2d, bitserial_dense, bitserial_conv2d_rasp and bnn

* Minor fix

* More minor fix
  • Loading branch information
kevinthesun authored and icemelon committed Feb 12, 2020
1 parent 89247d0 commit 955a411
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 39 deletions.
2 changes: 0 additions & 2 deletions topi/python/topi/generic/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,6 @@ def schedule_adaptive_pool(outs):
return _default_schedule(outs, False)


@tvm.target.override_native_generic_func("schedule_binarize_pack")
def schedule_binarize_pack(outs):
"""Schedule for binarize_pack
Expand Down Expand Up @@ -566,7 +565,6 @@ def schedule_bitpack(outs):
return _default_schedule(outs, False)


@tvm.target.override_native_generic_func("schedule_binary_dense")
def schedule_binary_dense(outs):
"""Schedule for binary_dense
Expand Down
2 changes: 0 additions & 2 deletions topi/python/topi/x86/binarize_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
"""Schedule for binarization and bit-packing."""
from __future__ import absolute_import as _abs
import tvm
from .. import generic


@generic.schedule_binarize_pack.register(["cpu"])
def schedule_binarize_pack(outs):
"""Schedule for binarize_pack.
Expand Down
2 changes: 0 additions & 2 deletions topi/python/topi/x86/binary_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic


@generic.schedule_binary_dense.register(["cpu"])
def schedule_binary_dense(outs):
"""Schedule for binary_dense.
Expand Down
12 changes: 6 additions & 6 deletions topi/tests/python/test_topi_bitserial_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name='W')
B = topi.nn.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nchw([B])
B = topi.x86.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits,
input_dtype, out_dtype, unipolar)
s = topi.x86.schedule_bitserial_conv2d_nchw([B])

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
Expand Down Expand Up @@ -73,9 +73,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name='W')
B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
B = topi.x86.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
input_dtype, out_dtype, unipolar)
s = topi.x86.schedule_bitserial_conv2d_nhwc([B])

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
Expand Down
6 changes: 3 additions & 3 deletions topi/tests/python/test_topi_bitserial_conv2d_rasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
with tvm.target.create(device):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
pack_dtype='uint8', out_dtype='int16', unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
B = topi.arm_cpu.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
'uint8', out_dtype, unipolar)
s = topi.arm_cpu.schedule_bitserial_conv2d_nhwc([B])

func = tvm.build(s, [A, W, B], device)

Expand Down
52 changes: 31 additions & 21 deletions topi/tests/python/test_topi_bitserial_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,29 @@
# specific language governing permissions and limitations
# under the License.
"""Test code for bitserial_dense operator"""
import os
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize

_bitserial_dense_implement = {
"generic": (topi.nn.bitserial_dense, topi.generic.schedule_bitserial_dense),
"cpu": (topi.x86.bitserial_dense, topi.x86.schedule_bitserial_dense),
"arm_cpu": (topi.arm_cpu.bitserial_dense, topi.arm_cpu.schedule_bitserial_dense),
}

def generate_quantized_np(shape, bits, out_dtype):
    """Return a random integer ndarray of *shape* with values in [0, 2**bits),
    cast to *out_dtype* (simulates quantized activations/weights for tests)."""
    upper = 2 ** bits  # exclusive upper bound for a `bits`-wide unsigned value
    return np.random.randint(0, upper, size=shape).astype(out_dtype)

def verify_bitserial_dense(batch, in_dim, out_dim, activation_bits, weight_bits, unipolar):
input_dtype = 'uint32'
out_dtype = 'int16'

with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_dim), dtype=input_dtype, name='A')
B = tvm.placeholder((out_dim, in_dim), dtype=input_dtype, name='B')
C = topi.nn.bitserial_dense(A, B, activation_bits, weight_bits, out_dtype=out_dtype,
unipolar=unipolar)
s = topi.generic.schedule_bitserial_dense([C])

a_shape = get_const_tuple(A.shape)
b_shape = get_const_tuple(B.shape)

@memoize("topi.tests.test_topi_bitseral_dense")
def get_ref_data():
def get_ref_data(a_shape, b_shape, input_dtype):
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
b_np = generate_quantized_np(get_const_tuple(b_shape), weight_bits, input_dtype)
if unipolar:
Expand All @@ -53,15 +48,30 @@ def get_ref_data():
else:
c_np = np.dot(a_np, b_np.T)
return a_np, b_np, c_np
a_np, b_np, c_np = get_ref_data()

ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
func = tvm.build(s, [A, B, C], "llvm")
func(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for target in ["llvm", "llvm -device=arm_cpu"]:
if "arm_cpu" in target and 'arm' not in os.uname()[4]:
print ("Skipped running code, not an arm device")
continue
input_dtype = 'uint8' if "arm_cpu" in target else "uint32"
A = tvm.placeholder((batch, in_dim), dtype=input_dtype, name='A')
B = tvm.placeholder((out_dim, in_dim), dtype=input_dtype, name='B')
fcompute, fschedule = topi.testing.dispatch(target, _bitserial_dense_implement)
C = fcompute(A, B, activation_bits, weight_bits,
input_dtype, out_dtype, unipolar)
s = fschedule([C])

a_shape = get_const_tuple(A.shape)
b_shape = get_const_tuple(B.shape)
a_np, b_np, c_np = get_ref_data(a_shape, b_shape, input_dtype)

ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
func = tvm.build(s, [A, B, C], target)
func(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

def test_bitserial_dense():
verify_bitserial_dense(1, 1024, 1000, 1, 1, True)
Expand Down
6 changes: 3 additions & 3 deletions topi/tests/python/test_topi_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def verify_binary_dense(batch, in_dim, out_dim):
bnn_C = topi.nn.binary_dense(bnn_A1, bnn_B1)
# schedule
with tvm.target.create('llvm'):
s1 = topi.generic.schedule_binarize_pack(bnn_A)
s2 = topi.generic.schedule_binarize_pack(bnn_B)
s3 = topi.generic.schedule_binary_dense(bnn_C)
s1 = topi.x86.schedule_binarize_pack(bnn_A)
s2 = topi.x86.schedule_binarize_pack(bnn_B)
s3 = topi.x86.schedule_binary_dense(bnn_C)

dtype = A.dtype
@memoize("topi.tests.test_topi_binary_dense")
Expand Down

0 comments on commit 955a411

Please sign in to comment.