Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
break down interoperability test
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jul 2, 2020
1 parent ce0d1ba commit 6b9636d
Showing 1 changed file with 27 additions and 43 deletions.
70 changes: 27 additions & 43 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -3221,25 +3221,33 @@ def _check_interoperability_helper(op_name, *args, **kwargs):
_np.testing.assert_equal(out, expected_out)


def check_interoperability(op_list):
for name in op_list:
if name in _TVM_OPS and not is_op_runnable():
continue
if name in ['shares_memory', 'may_share_memory', 'empty_like',
'__version__', 'dtype', '_NoValue']: # skip list
continue
if name in ['delete']: # https://github.com/apache/incubator-mxnet/issues/18600
continue
if name in ['full_like', 'zeros_like', 'ones_like'] and \
StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
continue
print('Dispatch test:', name)
workloads = OpArgMngr.get_workloads(name)
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
'added for checking interoperability with ' \
'the official NumPy.'.format(name)
for workload in workloads:
_check_interoperability_helper(name, *workload['args'], **workload['kwargs'])
@with_seed()
@use_np
@with_array_function_protocol
@with_array_ufunc_protocol
@pytest.mark.parametrize('name',
_NUMPY_ARRAY_FUNCTION_LIST \
+_NUMPY_ARRAY_UFUNC_LIST \
+np.fallback.__all__ \
+['linalg.{}'.format(op_name) for op_name in np.fallback_linalg.__all__])
def test_interoperability(name):
if name in _TVM_OPS and not is_op_runnable():
return
if name in ['shares_memory', 'may_share_memory', 'empty_like',
'__version__', 'dtype', '_NoValue']: # skip list
return
if name in ['delete']: # https://github.com/apache/incubator-mxnet/issues/18600
continue
if name in ['full_like', 'zeros_like', 'ones_like'] and \
StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
return
print('Dispatch test:', name)
workloads = OpArgMngr.get_workloads(name)
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
'added for checking interoperability with ' \
'the official NumPy.'.format(name)
for workload in workloads:
_check_interoperability_helper(name, *workload['args'], **workload['kwargs'])


@with_seed()
Expand All @@ -3254,27 +3262,3 @@ def test_np_memory_array_function():
assert op(data_mx[0,:,:,:], data_mx[1,:,:,:]) == op(data_np[0,:,:,:], data_np[1,:,:,:])
assert op(data_mx[0,0,0,2:5], data_mx[0,0,0,4:7]) == op(data_np[0,0,0,2:5], data_np[0,0,0,4:7])
assert op(data_mx, np.ones((5, 0))) == op(data_np, _np.ones((5, 0)))


@with_seed()
@use_np
@with_array_function_protocol
@pytest.mark.serial
def test_np_array_function_protocol():
check_interoperability(_NUMPY_ARRAY_FUNCTION_LIST)


@with_seed()
@use_np
@with_array_ufunc_protocol
@pytest.mark.serial
def test_np_array_ufunc_protocol():
check_interoperability(_NUMPY_ARRAY_UFUNC_LIST)


@with_seed()
@use_np
@pytest.mark.serial
def test_np_fallback_ops():
op_list = np.fallback.__all__ + ['linalg.{}'.format(op_name) for op_name in np.fallback_linalg.__all__]
check_interoperability(op_list)

0 comments on commit 6b9636d

Please sign in to comment.