Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.7.x] Backport PRs of numpy features (#18653)
Browse files Browse the repository at this point in the history
* add zero grad for npi_unique (#18080)

* fix np.clip scalar input case (#17788)

* fix true_divide (#18393)

Co-authored-by: Hao Jin <[email protected]>
Co-authored-by: Xi Wang <[email protected]>
  • Loading branch information
3 people authored Jul 2, 2020
1 parent 9c06894 commit 802e5af
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 9 deletions.
5 changes: 5 additions & 0 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6174,6 +6174,11 @@ def clip(a, a_min, a_max, out=None):
>>> np.clip(a, 3, 6, out=a)
array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32)
"""
from numbers import Number
if isinstance(a, Number):
# In case input is a scalar, the computation would fall back to native numpy.
# The value returned would be a python scalar.
return _np.clip(a, a_min, a_max, out=None)
return _mx_nd_np.clip(a, a_min, a_max, out=out)


Expand Down
16 changes: 8 additions & 8 deletions src/operator/numpy/np_true_divide-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs,
// Case when types of the 2 input tensors are different
if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
// both lhs and rhs are float types, output type is the more precise one
LOG(ERROR) << "not implemented yet...";
LOG(FATAL) << "not implemented yet...";
} else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
// one is float type, the other is integer type, the output type should be the same as float
CHECK_EQ(out.type_flag_,
Expand Down Expand Up @@ -150,14 +150,14 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs,
}
} else {
// lhs is integer type, rhs is integer type, output type should be float
LOG(ERROR) << "not implemented yet...";
LOG(FATAL) << "not implemented yet...";
}
#else
// Windows case: using temp space for casting the type
// Case when types of the 2 input tensors are different
if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
// both lhs and rhs are float types, output type is the more precise one
LOG(ERROR) << "not implemented yet...";
LOG(FATAL) << "not implemented yet...";
} else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
// lhs is float type, rhs is integer type, the output type should be the same as lhs
CHECK_EQ(out.type_flag_,
Expand Down Expand Up @@ -187,7 +187,7 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs,
}
} else {
// lhs is integer type, rhs is integer type, output type should be float
LOG(ERROR) << "not implemented yet...";
LOG(FATAL) << "not implemented yet...";
}
#endif
}
Expand Down Expand Up @@ -241,7 +241,7 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs,
} else {
if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
// lhs and rhs have different float types, the output is the more precise one
LOG(ERROR) << "not implemented yet...";
LOG(FATAL) << "not implemented yet...";
} else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
// one of lhs and rhs is float, the output is the same type as the float one
if (common::is_float(lhs.type_flag_)) {
Expand Down Expand Up @@ -269,7 +269,7 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs,
}
} else {
// lhs and rhs have different integer types, the output is float type
LOG(ERROR) << "not implemented yet...";
LOG(FATAL) << "not implemented yet...";
}
}
});
Expand Down Expand Up @@ -302,7 +302,7 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs,
} else {
if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
// lhs and rhs have different float types, the output is the more precise one
LOG(ERROR) << "not implemented yet...";
LOG(FATAL) << "not implemented yet...";
} else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
// one of lhs and rhs is float, the output is the same type as the float one
TBlob temp_tblob;
Expand Down Expand Up @@ -333,7 +333,7 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs,
}
} else {
// lhs and rhs have different integer types, the output is float type
LOG(ERROR) << "not implemented yet...";
LOG(FATAL) << "not implemented yet...";
}
}
#endif
Expand Down
1 change: 1 addition & 0 deletions src/operator/numpy/np_unique_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ NNVM_REGISTER_OP(_npi_unique)
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "The input array")
.add_arguments(NumpyUniqueParam::__FIELDS__());

Expand Down
12 changes: 11 additions & 1 deletion tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3644,6 +3644,16 @@ def __init__(self, a_min=None, a_max=None):

def hybrid_forward(self, F, x):
return x.clip(self._a_min, self._a_max)

# Test scalar case
for _, a_min, a_max, throw_exception in workloads:
a = _np.random.uniform() # A scalar
if throw_exception:
# No need to test the exception case here.
continue
mx_ret = np.clip(a, a_min, a_max)
np_ret = _np.clip(a, a_min, a_max)
assert_almost_equal(mx_ret, np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False)

for shape, a_min, a_max, throw_exception in workloads:
for dtype in dtypes:
Expand Down Expand Up @@ -6549,7 +6559,7 @@ def hybrid_forward(self, F, a):
((5, 3, 4), True, True, True, 1),
]
for dtype in ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64']:
for hybridize in [False]:
for hybridize in [False, True]:
for config in configs:
test_unique = TestUnique(*config[1:])
if hybridize:
Expand Down

0 comments on commit 802e5af

Please sign in to comment.