diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index 422301a6a483..fa8eee9d2989 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -176,11 +176,9 @@ class SELU(HybridBlock):
     """
     def __init__(self, **kwargs):
         super(SELU, self).__init__(**kwargs)
-        self._scale = 1.0507009873554804934193349852946
-        self._alpha = 1.6732632423543772848170429916717
 
     def hybrid_forward(self, F, x):
-        return self._scale * F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))
+        return F.LeakyReLU(x, act_type='selu', name='fwd')
 
 
 class Swish(HybridBlock):
diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index 20aabc8ae32f..1c4f48b32ed2 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -47,7 +47,7 @@ namespace op {
 namespace leakyrelu {
 enum LeakyReLUOpInputs {kData, kGamma};
 enum LeakyReLUOpOutputs {kOut, kMask};
-enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU};
+enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kSELU};
 enum LeakyReLUOpResource {kRandom};
 }  // namespace leakyrelu
 
@@ -63,6 +63,7 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
     .add_enum("leaky", leakyrelu::kLeakyReLU)
     .add_enum("prelu", leakyrelu::kPReLU)
     .add_enum("elu", leakyrelu::kELU)
+    .add_enum("selu", leakyrelu::kSELU)
     .describe("Activation function to be applied.");
     DMLC_DECLARE_FIELD(slope).set_default(0.25f)
     .describe("Init slope for the activation. (For leaky and elu only)");
@@ -182,6 +183,13 @@ class LeakyReLUOp : public Operator {
         });
         break;
       }
+      case leakyrelu::kSELU: {
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::selu, Req>, xpu>::Launch(
+            s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_);
+        });
+        break;
+      }
       default:
         LOG(FATAL) << "Not implmented";
     }
@@ -270,6 +278,15 @@ class LeakyReLUOp : public Operator {
         });
         break;
       }
+      case leakyrelu::kSELU: {
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<
+            mxnet_op::backward_grad_tuned<mshadow_op::selu_grad>, Req>, xpu>::Launch(
+            s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
+            output.dptr_);
+        });
+        break;
+      }
       default:
         LOG(FATAL) << "Not implmented";
     }
diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc
index 99b6ba362f75..4bb24237b8ed 100644
--- a/src/operator/leaky_relu.cc
+++ b/src/operator/leaky_relu.cc
@@ -54,6 +54,8 @@ when the input is negative and has a slope of one when input is positive.
 The following modified ReLU Activation functions are supported:
 
 - *elu*: Exponential Linear Unit. `y = x > 0 ? x : slope * (exp(x)-1)`
+- *selu*: Scaled Exponential Linear Unit. `y = lambda * (x > 0 ? x : alpha * (exp(x) - 1))` where
+  *lambda = 1.0507009873554804934193349852946* and *alpha = 1.6732632423543772848170429916717*.
 - *leaky*: Leaky ReLU. `y = x > 0 ? x : slope * x`
 - *prelu*: Parametric ReLU. This is same as *leaky* except that `slope` is learnt during training.
 - *rrelu*: Randomized ReLU. same as *leaky* but the `slope` is uniformly and randomly chosen from
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 7a2032df7580..339719375fdd 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -42,8 +42,12 @@ namespace mshadow_op {
 #ifdef __CUDA_ARCH__
 __constant__ const float PI = 3.14159265358979323846;
+__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
+__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
 #else
 const float PI = 3.14159265358979323846;
+const float SELU_ALPHA = 1.6732632423543772848170429916717;
+const float SELU_LAMBDA = 1.0507009873554804934193349852946;
 using std::isnan;
 #endif
 
 using std::enable_if;
@@ -126,6 +130,12 @@ MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));
 
 MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
 
+MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
+                       (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));
+
+MXNET_UNARY_MATH_OP_NC(selu_grad,
+                       DType(SELU_LAMBDA) * (a > DType(0) ? DType(1) : DType(SELU_ALPHA + a)));
+
 MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);
 
 MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 0953cbaf519b..cf5412f98246 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -217,6 +217,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::selu);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::selu_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu);  // NOLINT()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 22caaca8a0bc..54eb0fd94a22 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -819,6 +819,37 @@ def fprelu_grad(x, y, gamma):
     check_symbolic_backward(y, [xa, gam_full], [np.ones(shape), np.ones(gam_full.shape)],
                             [g_xa_full, g_gam_full], rtol=rtol, atol=atol, dtype=dtype)
 
+@with_seed()
+def test_selu():
+    alpha = 1.6732632423543772848170429916717
+    lamb = 1.0507009873554804934193349852946
+    def fselu(x):
+        neg_indices = x < 0
+        out = x.copy()
+        out[neg_indices] = alpha * np.expm1(out[neg_indices])
+        return out * lamb
+    def fselu_grad(grad, x, y):
+        neg_indices = x < 0
+        out = np.ones(x.shape).astype(x.dtype)
+        out[neg_indices] = y[neg_indices] + alpha
+        return out * lamb
+
+    shape = (3, 4)
+    x = mx.sym.Variable("x")
+    y = mx.sym.LeakyReLU(data=x, act_type="selu")
+    for dtype in [np.float16, np.float32, np.float64]:
+        xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
+        eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
+        if dtype is np.float16:
+            xa /= 10.0
+        xa[abs(xa) < eps] = 0.01
+        ya = fselu(xa)
+        ga = fselu_grad(np.ones(shape).astype(dtype), xa, ya)
+        check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
+        check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
+        check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)
+
+
 @with_seed()
 def test_sigmoid():
     def fsigmoid(a):
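As a quick sanity check outside the unit test, the sketch below compares the new `act_type='selu'` against the closed-form expression quoted in the `leaky_relu.cc` docstring. It assumes an MXNet build that includes this patch; the NumPy helper `selu_np` is a hypothetical reference, not part of the patch.

```python
import numpy as np
import mxnet as mx

# Constants added to mshadow_op.h by this patch (SELU_LAMBDA / SELU_ALPHA).
LAMB = 1.0507009873554804934193349852946
ALPHA = 1.6732632423543772848170429916717

def selu_np(x):
    # y = lambda * (x > 0 ? x : alpha * (exp(x) - 1)), as documented in leaky_relu.cc
    return LAMB * np.where(x > 0, x, ALPHA * np.expm1(x))

x = np.random.uniform(-3.0, 3.0, size=(2, 3, 4)).astype(np.float32)
y = mx.nd.LeakyReLU(mx.nd.array(x), act_type='selu')
np.testing.assert_allclose(y.asnumpy(), selu_np(x), rtol=1e-5, atol=1e-6)
```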
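Similarly, a minimal Gluon-level sketch (again assuming a build with this patch): after the change, `nn.SELU` emits a single fused `LeakyReLU(act_type='selu')` node instead of composing `where`/`exp`, so its output should match the explicit formula the block used previously.

```python
import numpy as np
import mxnet as mx
from mxnet.gluon import nn

blk = nn.SELU()
blk.hybridize()  # hybrid_forward now lowers to one F.LeakyReLU(act_type='selu') node

x = mx.nd.random.uniform(-2.0, 2.0, shape=(8, 16))
y_fused = blk(x)

# Reference: the expression the block computed before this change.
scale = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
y_ref = scale * mx.nd.where(x > 0, x, alpha * (mx.nd.exp(x) - 1.0))

# Slightly looser tolerances: exp(x) - 1.0 is less accurate near zero than the
# kernel's expm1-based formula.
np.testing.assert_allclose(y_fused.asnumpy(), y_ref.asnumpy(), rtol=1e-4, atol=1e-5)
```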