add boolean support for concatenate (apache#18213)
Tommliu committed May 21, 2020
1 parent 67b5d31 commit ef8c4c0
Showing 2 changed files with 46 additions and 42 deletions.
4 changes: 2 additions & 2 deletions src/operator/numpy/np_matrix_op-inl.h
@@ -1153,7 +1153,7 @@ void NumpyConcatenateForward(const nnvm::NodeAttrs& attrs,
   ConcatParam cparam;
   cparam.num_args = param.num_args;
   cparam.dim = param.axis.has_value() ? param.axis.value() : 0;
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
     ConcatOp<xpu, DType> op;
     op.Init(cparam);
     op.Forward(ctx, data, req, outputs);
@@ -1186,7 +1186,7 @@ void NumpyConcatenateBackward(const nnvm::NodeAttrs& attrs,
   ConcatParam cparam;
   cparam.num_args = param.num_args;
   cparam.dim = param.axis.has_value() ? param.axis.value() : 0;
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
     ConcatOp<xpu, DType> op;
     op.Init(cparam);
     op.Backward(ctx, inputs[0], req, data);
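The operator-side change is a one-macro swap: MSHADOW_TYPE_SWITCH_WITH_BOOL performs the same dtype dispatch as MSHADOW_TYPE_SWITCH but also covers the boolean type flag, so boolean tensors can instantiate ConcatOp on both the forward and backward paths. A minimal sketch of what this enables at the Python level, assuming an MXNet build that includes this commit:

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()  # enable NumPy-compatible array semantics

    # Two boolean arrays of shape (1, 2), using the same np.bool dtype
    # the updated test sweeps over.
    a = np.array([[True, False]], dtype=np.bool)
    b = np.array([[False, True]], dtype=np.bool)

    # Before this change the dtype switch had no boolean branch, so inputs
    # like these could not reach ConcatOp; now they concatenate normally.
    out = np.concatenate([a, b], axis=0)
    assert out.shape == (2, 2)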
84 changes: 44 additions & 40 deletions tests/python/unittest/test_numpy_op.py
@@ -3240,52 +3240,56 @@ def get_new_shape(shape, axis):
             shape_lst[axis] = random.randint(0, 3)
         return tuple(shape_lst)

-    for shape in [(0, 0), (2, 3), (2, 1, 3)]:
-        for hybridize in [True, False]:
-            for axis in [0, 1, None]:
-                for grad_req in ['write', 'add', 'null']:
-                    # test gluon
-                    test_concat = TestConcat(axis=axis)
-                    if hybridize:
-                        test_concat.hybridize()
+    shapes = [(0, 0), (2, 3), (2, 1, 3)]
+    hybridizes = [True, False]
+    axes = [0, 1, None]
+    grad_reqs = ['write', 'add', 'null']
+    dtypes = [np.float32, np.float64, np.bool]
+    combinations = itertools.product(shapes, hybridizes, axes, grad_reqs, dtypes)

-                    grad_req_c = grad_req
-                    grad_req_d = grad_req
-                    if grad_req == 'null':
-                        ide = random.randint(0, 2)
-                        grad_req_c = 'write' if ide == 0 else 'add'
-                        grad_req_d = 'write' if ide == 1 else 'add'
+    for shape, hybridize, axis, grad_req, dtype in combinations:
+        # test gluon
+        test_concat = TestConcat(axis=axis)
+        if hybridize:
+            test_concat.hybridize()

-                    a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
-                    a.attach_grad(grad_req)
-                    b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
-                    b.attach_grad(grad_req)
-                    c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
-                    c.attach_grad(grad_req_c)
-                    d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
-                    d.attach_grad(grad_req_d)
-                    expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
+        grad_req_c = grad_req
+        grad_req_d = grad_req
+        if grad_req == 'null':
+            ide = random.randint(0, 2)
+            grad_req_c = 'write' if ide == 0 else 'add'
+            grad_req_d = 'write' if ide == 1 else 'add'

-                    with mx.autograd.record():
-                        y = test_concat(a, b, c, d)
+        a = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype)
+        a.attach_grad(grad_req)
+        b = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype)
+        b.attach_grad(grad_req)
+        c = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype)
+        c.attach_grad(grad_req_c)
+        d = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype)
+        d.attach_grad(grad_req_d)
+        expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)

-                    assert y.shape == expected_ret.shape
-                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
+        with mx.autograd.record():
+            y = test_concat(a, b, c, d)

-                    y.backward()
-                    if grad_req != 'null':
-                        assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
-                    if grad_req != 'null':
-                        assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
-                    if grad_req_c != 'null':
-                        assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
-                    if grad_req_d != 'null':
-                        assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)
+        assert y.shape == expected_ret.shape
+        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)

-                    # test imperative
-                    mx_out = np.concatenate([a, b, c, d], axis=axis)
-                    np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
-                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+        y.backward()
+        if grad_req != 'null':
+            assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
+        if grad_req != 'null':
+            assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
+        if grad_req_c != 'null':
+            assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
+        if grad_req_d != 'null':
+            assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)

+        # test imperative
+        mx_out = np.concatenate([a, b, c, d], axis=axis)
+        np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
+        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


 @with_seed()
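On the test side, the rewrite flattens four nested parameter loops into a single itertools.product sweep, which is what makes it cheap to add dtype (including np.bool) as a fifth test dimension. The pattern in isolation, with two illustrative parameter axes:

    import itertools

    # Declare each parameter axis once...
    shapes = [(0, 0), (2, 3), (2, 1, 3)]
    dtypes = ['float32', 'float64', 'bool']

    # ...then iterate every combination; each extra axis is one more
    # argument to product() instead of another level of loop nesting.
    for shape, dtype in itertools.product(shapes, dtypes):
        pass  # one test case per (shape, dtype) combination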
