Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
support multiply between floats and bool
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Oct 25, 2019
1 parent a7d18c8 commit a85214c
Show file tree
Hide file tree
Showing 8 changed files with 313 additions and 7 deletions.
36 changes: 36 additions & 0 deletions src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,42 @@ inline bool is_float(const int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

inline int more_precise_type(const int type1, const int type2) {
if (type1 == type2) return type1;
if (is_float(type1) && is_float(type2)) {
if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
return mshadow::kFloat64;
}
if (type1 == mshadow::kFloat32 || type2 == mshadow::kFloat32) {
return mshadow::kFloat32;
}
return mshadow::kFloat16;
} else if (is_float(type1) || is_float(type2)) {
return is_float(type1) ? type1 : type2;
}
if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) {
return mshadow::kInt64;
}
if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) {
return mshadow::kInt32;
}
CHECK(!((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)))
<< "1 is UInt8 and 1 is Int8 should not get here";
if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) {
return mshadow::kUint8;
}
return mshadow::kInt8;
}

inline int np_binary_out_type(const int type1, const int type2) {
if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
return mshadow::kInt32;
}
return more_precise_type(type1, type2);
}

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
8 changes: 8 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ MXNET_BINARY_MATH_OP_NC(right, b);

MXNET_BINARY_MATH_OP_NC(mul, a * b);

struct mixed_mul {
template<typename DType,
typename std::enable_if<!std::is_pointer<DType>::value, int>::type = 0>
MSHADOW_XINLINE static DType Map(bool a, DType b) {
return static_cast<DType>(a) * b;
}
};

MXNET_BINARY_MATH_OP_NC(div, a / b);

MXNET_BINARY_MATH_OP_NC(plus, a + b);
Expand Down
59 changes: 54 additions & 5 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
* \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator.
*/

#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../tensor/elemwise_binary_scalar_op.h"
#include "./np_elemwise_broadcast_op.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -55,6 +54,38 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")

bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const int ltype = in_attrs->at(0);
const int rtype = in_attrs->at(1);
if (ltype != -1 && rtype != -1 && (ltype != rtype)) {
// Only when both input types are known and not the same, we enter the mixed-precision mode
TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_type(ltype, rtype));
} else {
return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs);
}
return true;
}

#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
.set_num_outputs(1) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
Expand All @@ -64,9 +95,27 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});

NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", MixedBinaryBackwardUseIn<cpu, mshadow_op::right,
mshadow_op::left>);

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
Expand Down
10 changes: 9 additions & 1 deletion src/operator/numpy/np_elemwise_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \file np_elemwise_broadcast_op.cu
* \brief GPU Implementation of basic functions for elementwise binary broadcast operator.
*/
#include "./np_elemwise_broadcast_op.h"
#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../tensor/elemwise_binary_scalar_op.h"

Expand All @@ -35,7 +36,14 @@ NNVM_REGISTER_OP(_npi_subtract)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>);

NNVM_REGISTER_OP(_npi_multiply)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>);
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>);

NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
.set_attr<FCompute>("FCompute<gpu>", MixedBinaryBackwardUseIn<gpu, mshadow_op::right,
mshadow_op::left>);

NNVM_REGISTER_OP(_npi_mod)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
Expand Down
148 changes: 148 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_elemwise_binary_op.h
* \brief
*/
#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_

#include <vector>

#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../tensor/elemwise_binary_scalar_op.h"

namespace mxnet {
namespace op {

template<typename xpu, typename LOP, typename ROP>
void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);

const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];

CHECK((lhs.type_flag_ == mshadow::kBool) || (rhs.type_flag_ == mshadow::kBool))
<< "now supports bool with another type only";

Stream<xpu> *s = ctx.get_stream<xpu>();

MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size())
+ DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
if (size != 0) {
if (lhs.type_flag_ == kBool) {
Kernel<mxnet_op::op_with_req<LOP, Req>, xpu>::Launch(
s, size, out.dptr<DType>(), lhs.dptr<bool>(), rhs.dptr<DType>());
} else {
Kernel<mxnet_op::op_with_req<ROP, Req>, xpu>::Launch(
s, size, out.dptr<DType>(), rhs.dptr<bool>(), lhs.dptr<DType>());
}
}
});
});
}

template<typename xpu, typename OP, typename LOP, typename ROP>
void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);

const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];

if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
&new_lshape, &new_rshape, &new_oshape);


if (lhs.type_flag_ == rhs.type_flag_) {
BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
return;
}

CHECK((lhs.type_flag_ == mshadow::kBool) || (rhs.type_flag_ == mshadow::kBool))
<< "now supports bool with another type only";


if (!ndim) {
MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
if (lhs.type_flag_ == mshadow::kBool) {
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, LOP>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
lhs.dptr<bool>(), rhs.dptr<DType>(), out.dptr<DType>());
} else {
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, ROP>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[0], rstride, lstride, oshape,
rhs.dptr<bool>(), lhs.dptr<DType>(), out.dptr<DType>());
}
});
});
}
}

template<typename xpu, typename LOP, typename ROP>
void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);

const TBlob& lhs = inputs[1];
const TBlob& rhs = inputs[2];
if (lhs.type_flag_ == rhs.type_flag_) {
BinaryBroadcastBackwardUseIn<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
return;
}

LOG(ERROR) << "Binary operation with mixed input data types does not support backward yet...";
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
2 changes: 1 addition & 1 deletion src/operator/tensor/elemwise_binary_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
Expand Down
2 changes: 2 additions & 0 deletions src/operator/tensor/elemwise_binary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,13 @@ class ElemwiseBinaryOp : public OpBase {
return a1.var() == a2.var();
}

public:
/*! \brief Minimum of three */
static MSHADOW_XINLINE size_t minthree(const size_t a, const size_t b, const size_t c) {
return a < b ? (a < c ? a : c) : (b < c ? b : c);
}

private:
template<typename xpu, typename LOP, typename ROP, typename DType>
static void BackwardUseNone_(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,6 +1650,61 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
check_binary_func(func, lshape, rshape, low, high, lgrads, rgrads, dtypes)


@with_seed()
@use_np
def test_np_mixed_precision_binary_funcs():
def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype):
class TestMixedBinary(HybridBlock):
def __init__(self, func):
super(TestMixedBinary, self).__init__()
self._func = func

def hybrid_forward(self, F, a, b, *args, **kwargs):
return getattr(F.np, self._func)(a, b)

np_func = getattr(_np, func)
mx_func = TestMixedBinary(func)
np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype)
np_test_x2 = _np.random.uniform(low, high, rshape).astype(rtype)
mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ltype)
mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rtype)
for hybridize in [True, False]:
if hybridize:
mx_func.hybridize()
np_out = np_func(np_test_x1, np_test_x2)
with mx.autograd.record():
y = mx_func(mx_test_x1, mx_test_x2)
assert y.shape == np_out.shape
assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=1e-3, atol=1e-5,
use_broadcast=False, equal_nan=True)

np_out = getattr(_np, func)(np_test_x1, np_test_x2)
mx_out = getattr(mx.np, func)(mx_test_x1, mx_test_x2)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=1e-3, atol=1e-5,
use_broadcast=False, equal_nan=True)

funcs = {
'multiply': (-1.0, 1.0),
}
shape_pairs = [((3, 2), (3, 2)),
((3, 2), (3, 1)),
((3, 1), (3, 0)),
((0, 2), (1, 2)),
((2, 3, 4), (3, 1)),
((2, 3), ()),
((), (2, 3))]
type_pairs = [(np.bool, np.float16),
(np.bool, np.float32),
(np.bool, np.float64)]
for func, func_data in funcs.items():
low, high = func_data
for lshape, rshape in shape_pairs:
for type1, type2 in type_pairs:
check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1)


@with_seed()
@use_np
def test_npx_relu():
Expand Down

0 comments on commit a85214c

Please sign in to comment.