Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
handle negative values for shape
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Jan 2, 2020
1 parent b9c96a7 commit 5486321
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
16 changes: 7 additions & 9 deletions src/operator/contrib/dynamic_shape_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
#ifndef MXNET_OPERATOR_CONTRIB_DYNAMIC_SHAPE_OPS_INL_H_
#define MXNET_OPERATOR_CONTRIB_DYNAMIC_SHAPE_OPS_INL_H_

#include <vector>
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include <vector>
#include "../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {
Expand All @@ -43,18 +44,20 @@ inline void DynamicReshapeForward(const nnvm::NodeAttrs& attrs,
const NDArray &idx = inputs[1];
size_t idx_size = idx.shape()[0];
mxnet::TShape shapevalue = mxnet::TShape(idx_size, 0);
std::vector<int> shapev(idx_size, 0);
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
DType* idx_dptr = idx.data().dptr<DType>();
for (size_t i = 0; i < idx_size; i++) {
shapevalue[i] = idx_dptr[i];
shapev[i] = idx_dptr[i];
}
});
shapevalue = InferReshapeShape(mxnet::Tuple<int>(shapev), inputs[0].shape(), false);
const_cast<NDArray &>(out).Init(shapevalue);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();

MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, xpu>::Launch(
s, inputs[0].data().Size(), outputs[0].data().dptr<DType>(),
s, inputs[0].data().Size(), outputs[0].data().dptr<DType>(),
inputs[0].data().dptr<DType>());
});
}
Expand All @@ -67,16 +70,11 @@ inline void DynamicReshapeBackward(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 2U);
const NDArray& ograd = inputs[0];
const NDArray& igrad_shape = outputs[1];
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();

mxnet::TShape opshape = ograd.shape();
const_cast<NDArray &>(igrad_shape).Init(opshape);

MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, xpu>::Launch(
s, inputs[0].data().Size(), outputs[0].data().dptr<DType>(),
s, inputs[0].data().Size(), outputs[0].data().dptr<DType>(),
inputs[0].data().dptr<DType>());
});
}
Expand Down
22 changes: 15 additions & 7 deletions tests/python/unittest/test_contrib_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def test_modulated_deformable_convolution():

@with_seed()
def test_dynamic_reshape():
def dynamic_reshape_testcases(src_shape, dst_shape):
def dynamic_reshape_testcases(src_shape, shape_arg, dst_shape):
data = mx.sym.Variable('data')
shape = mx.sym.Variable('shape')
net = mx.sym.contrib.dynamic_reshape(data, shape)
Expand All @@ -457,7 +457,7 @@ def dynamic_reshape_testcases(src_shape, dst_shape):
grad_npy = np.random.rand(*dst_shape)
args = {
'data': mx.nd.array(dat_npy),
'shape': mx.nd.array(dst_shape)
'shape': mx.nd.array(shape_arg)
}
args_grad = {
'data': mx.nd.empty(src_shape),
Expand All @@ -470,11 +470,19 @@ def dynamic_reshape_testcases(src_shape, dst_shape):
assert np.square(exe.grad_dict['data'].asnumpy() - grad_npy.reshape(src_shape)).mean() < 1E-7

test_cases = [
[(2, 3, 5, 5), (2, 75)],
[(2, 3, 5, 5), (2, 3, 25)],
[(64,), (16, 4)],
[(64, 1, 2, 3), (16, 4, 1, 2, 3)],
[(2, 4, 5, 3), (30, 2, 2, 1)]]
[(2, 3, 5, 5), (0, -1), (2, 75)],
[(2, 3, 5, 5), (0, 0, -1), (2, 3, 25)],
[(5, 3, 4, 5), (0, -1, 0), (5, 15, 4)],
[(2, 3, 5, 4), (-1, 0, 0), (8, 3, 5)],
[(2, 3, 5, 5), (0, 0, 0, 0), (2, 3, 5, 5)],
[(2, 4, 5, 3), (-1, 2, 2, 1), (30, 2, 2, 1)],
[(2, 3, 5, 6), (-2,), (2, 3, 5, 6)],
[(2, 3, 5, 6), (6, 1, -2), (6, 1, 5, 6)],
[(2, 3, 5, 6), (-3, -3), (6, 30)],
[(2, 3, 5, 6), (-3, -1), (6, 30)],
[(64,), (-4, 16, 4), (16, 4)],
[(64,), (-4, 16, -1), (16, 4)],
[(64, 1, 2, 3), (-4, 16, -1, -2), (16, 4, 1, 2, 3)]]

for test_case in test_cases:
dynamic_reshape_testcases(*test_case)
Expand Down

0 comments on commit 5486321

Please sign in to comment.