Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix test error
Browse files Browse the repository at this point in the history
Ubuntu committed Jun 16, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent c251837 commit 14e4b9e
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
@@ -1505,9 +1505,14 @@ def check_index_update_forward(mx_ret, a, ind, val, ind_ndim, ind_num, eps):
else:
expect_tmp = val
mx_tmp = mx_ret[t_ind]
if _np.allclose(expect_tmp, mx_tmp, rtol=eps, atol=eps):
mx_ret[t_ind] = 0
a[t_ind] = 0
close_pos = _np.where(_np.isclose(expect_tmp, mx_tmp, rtol=eps, atol=eps))
if a[t_ind].ndim == 0:
if close_pos[0].size == 1:
mx_ret[t_ind] = 0
a[t_ind] = 0
else:
mx_ret[t_ind][close_pos] = 0
a[t_ind][close_pos] = 0
assert_almost_equal(mx_ret, a, rtol=eps, atol=eps)

def index_update_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num, grad_req_a, grad_req_val):
@@ -1585,8 +1590,6 @@ def index_update_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num, grad_re
itertools.product([True, False], grad_req, grad_req, dtypes, ['int32', 'int64']):
for a_shape, ind, val_shape ,ind_ndim, ind_num in configs:
eps = 1e-3
if sys.platform.startswith('linux'):
eps = 1e-2
atype = dtype
valtype = dtype
test_index_update = TestIndexUpdate()

0 comments on commit 14e4b9e

Please sign in to comment.