diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 6a2845e0fb24..d0d22b5b372c 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -3221,25 +3221,33 @@ def _check_interoperability_helper(op_name, *args, **kwargs): _np.testing.assert_equal(out, expected_out) -def check_interoperability(op_list): - for name in op_list: - if name in _TVM_OPS and not is_op_runnable(): - continue - if name in ['shares_memory', 'may_share_memory', 'empty_like', - '__version__', 'dtype', '_NoValue']: # skip list - continue - if name in ['delete']: # https://github.com/apache/incubator-mxnet/issues/18600 - continue - if name in ['full_like', 'zeros_like', 'ones_like'] and \ - StrictVersion(platform.python_version()) < StrictVersion('3.0.0'): - continue - print('Dispatch test:', name) - workloads = OpArgMngr.get_workloads(name) - assert workloads is not None, 'Workloads for operator `{}` has not been ' \ - 'added for checking interoperability with ' \ - 'the official NumPy.'.format(name) - for workload in workloads: - _check_interoperability_helper(name, *workload['args'], **workload['kwargs']) +@with_seed() +@use_np +@with_array_function_protocol +@with_array_ufunc_protocol +@pytest.mark.parametrize('name', + _NUMPY_ARRAY_FUNCTION_LIST \ + +_NUMPY_ARRAY_UFUNC_LIST \ + +np.fallback.__all__ \ + +['linalg.{}'.format(op_name) for op_name in np.fallback_linalg.__all__]) +def test_interoperability(name): + if name in _TVM_OPS and not is_op_runnable(): + return + if name in ['shares_memory', 'may_share_memory', 'empty_like', + '__version__', 'dtype', '_NoValue']: # skip list + return + if name in ['delete']: # https://github.com/apache/incubator-mxnet/issues/18600 + continue + if name in ['full_like', 'zeros_like', 'ones_like'] and \ + StrictVersion(platform.python_version()) < StrictVersion('3.0.0'): + return + print('Dispatch test:', name) + workloads = OpArgMngr.get_workloads(name) + assert workloads is not None, 'Workloads for operator `{}` has not been ' \ + 'added for checking interoperability with ' \ + 'the official NumPy.'.format(name) + for workload in workloads: + _check_interoperability_helper(name, *workload['args'], **workload['kwargs']) @with_seed() @@ -3254,27 +3262,3 @@ def test_np_memory_array_function(): assert op(data_mx[0,:,:,:], data_mx[1,:,:,:]) == op(data_np[0,:,:,:], data_np[1,:,:,:]) assert op(data_mx[0,0,0,2:5], data_mx[0,0,0,4:7]) == op(data_np[0,0,0,2:5], data_np[0,0,0,4:7]) assert op(data_mx, np.ones((5, 0))) == op(data_np, _np.ones((5, 0))) - - -@with_seed() -@use_np -@with_array_function_protocol -@pytest.mark.serial -def test_np_array_function_protocol(): - check_interoperability(_NUMPY_ARRAY_FUNCTION_LIST) - - -@with_seed() -@use_np -@with_array_ufunc_protocol -@pytest.mark.serial -def test_np_array_ufunc_protocol(): - check_interoperability(_NUMPY_ARRAY_UFUNC_LIST) - - -@with_seed() -@use_np -@pytest.mark.serial -def test_np_fallback_ops(): - op_list = np.fallback.__all__ + ['linalg.{}'.format(op_name) for op_name in np.fallback_linalg.__all__] - check_interoperability(op_list)