diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h
index f3763bbd6e67..320b13eebf2d 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.h
+++ b/cpp-package/include/mxnet-cpp/optimizer.h
@@ -146,6 +146,20 @@ class SGDOptimizer : public Optimizer {
   AtomicSymbolCreator mom_update_handle_;
 };
 
+class SignumOptimizer : public Optimizer {
+ public:
+  explicit SignumOptimizer(unsigned begin_num_update = 0);
+  std::string GetType() const override;
+  void Update(int index, NDArray weight, NDArray grad) override;
+ private:
+  virtual ~SignumOptimizer();
+  void CreateState_(int index, NDArray weight) override;
+  std::map<int, NDArray*> states_;
+  AtomicSymbolCreator update_handle_;
+  AtomicSymbolCreator mom_update_handle_;
+};
+
+
 class RMSPropOptimizer : public Optimizer {
  public:
  explicit RMSPropOptimizer(unsigned begin_num_update = 0);
diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp
index cb8442dc9ceb..e3d47d1161c6 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.hpp
+++ b/cpp-package/include/mxnet-cpp/optimizer.hpp
@@ -131,6 +131,7 @@ inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
   MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
   MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
   MXNETCPP_REGISTER_OPTIMIZER(adadelta, AdaDeltaOptimizer);
+  MXNETCPP_REGISTER_OPTIMIZER(signum, SignumOptimizer);
   auto it = cmap().find(name);
   if (it == cmap().end())
     return nullptr;
@@ -200,6 +201,69 @@ inline void SGDOptimizer::CreateState_(int index, NDArray weight) {
   }
 }
 
+// implementing Signum optimizer
+
+inline SignumOptimizer::SignumOptimizer(unsigned begin_num_update)
+  : Optimizer(begin_num_update) {
+  update_handle_ = op_map()->GetSymbolCreator("signsgd_update");
+  mom_update_handle_ = op_map()->GetSymbolCreator("signum_update");
+}
+
+inline std::string SignumOptimizer::GetType() const {
+  return "signum";
+}
+
+inline SignumOptimizer::~SignumOptimizer() {
+  for (auto &it : states_) {
+    delete it.second;
+  }
+}
+
+inline void SignumOptimizer::Update(int index, NDArray weight, NDArray grad) {
+  if (states_.count(index) == 0) {
+    CreateState_(index, weight);
+  }
+
+  params_["lr"] = std::to_string(GetLR_(index));
+  params_["wd"] = std::to_string(GetWD_(index));
+  UpdateCount_(index);
+  auto keys = GetParamKeys_();
+  auto values = GetParamValues_();
+  CHECK_EQ(keys.size(), values.size());
+
+  NDArrayHandle inputs[3];
+  inputs[0] = weight.GetHandle();
+  inputs[1] = grad.GetHandle();
+
+  int num_outputs = 1;
+  NDArrayHandle output = weight.GetHandle();
+  NDArrayHandle *outputs = &output;
+
+  if (states_[index] == nullptr) {
+    // no momentum state was created, so fall back to the plain signSGD update
+    MXImperativeInvoke(update_handle_, 2, inputs,
+                       &num_outputs, &outputs,
+                       keys.size(), keys.data(), values.data());
+  } else {
+    inputs[2] = states_[index]->GetHandle();
+    MXImperativeInvoke(mom_update_handle_, 3, inputs,
+                       &num_outputs, &outputs,
+                       keys.size(), keys.data(), values.data());
+  }
+}
+
+inline void SignumOptimizer::CreateState_(int index, NDArray weight) {
+  if (params_.count("momentum") == 0) {
+    states_[index] = nullptr;
+  } else {
+    states_[index] = new NDArray(weight.GetShape(), weight.GetContext());
+    *states_[index] = 0;
+  }
+}
+
+// finish implementing Signum
+
+
 inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
     : Optimizer(begin_num_update) {
   update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index feff87e0baab..4285aecef1f5 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -25,7 +25,8 @@ from .base import py_str
 from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
 from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
-                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update)
+                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
+                      signsgd_update, signum_update)
 from .ndarray import _internal
 from .ndarray import op
 from .ndarray import sparse
@@ -534,6 +535,67 @@ def update_multi_precision(self, index, weight, grad, state):
         self._update_impl(index, weight, grad, state,
                           multi_precision=use_multi_precision)
 
+@register
+class Signum(Optimizer):
+    """The Signum optimizer that takes the sign of gradient or momentum.
+
+    The optimizer updates the weight by::
+
+        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
+        state = momentum * state + (1 - momentum) * rescaled_grad
+        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+    For details of the update algorithm see
+    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    wd_lh : float, optional
+       The amount of decoupled weight decay regularization, see details in the original paper at:\
+       https://arxiv.org/abs/1711.05101
+    """
+    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
+        super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.momentum = momentum
+        self.wd_lh = wd_lh
+
+    def create_state(self, index, weight):
+        momentum = None
+        if self.momentum != 0.0:
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def _update_impl(self, index, weight, grad, state):
+        assert(isinstance(weight, NDArray))
+        assert(isinstance(grad, NDArray))
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        kwargs = {'rescale_grad': self.rescale_grad}
+        if self.momentum > 0:
+            kwargs['momentum'] = self.momentum
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+        if self.wd_lh:
+            kwargs['wd_lh'] = self.wd_lh
+
+        if state is not None:
+            signum_update(weight, grad, state, out=weight,
+                          lr=lr, wd=wd, **kwargs)
+        else:
+            signsgd_update(weight, grad, out=weight,
+                           lr=lr, wd=wd, **kwargs)
+
+    def update(self, index, weight, grad, state):
+        self._update_impl(index, weight, grad, state)
 
 @register
 class FTML(Optimizer):
@@ -702,8 +764,7 @@ def update(self, index, weight, grad, state):
         if self.clip_gradient is not None:
             grad = clip(grad, -self.clip_gradient, self.clip_gradient)
         weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
-                                                            shape=weight.shape,
-                                                            ctx=weight.context)
+                                                            weight.shape, weight.context)
 
 
 @register  # pylint: disable=invalid-name
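Because the new class is registered with ``@register`` above, it can be selected by its lowercase name anywhere MXNet accepts an optimizer name, for example through ``mx.optimizer.create('signum', ...)`` or a Gluon ``Trainer``. The sketch below is not part of the patch: the network, data shapes, and hyper-parameter values are invented for illustration, and it assumes a build that includes this change.

    import mxnet as mx
    from mxnet import autograd, gluon

    # Tiny throwaway model; the point is only that 'signum' resolves to the new class.
    net = gluon.nn.Dense(1)
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(), 'signum',
                            {'learning_rate': 0.01, 'momentum': 0.9, 'wd_lh': 0.0})

    x = mx.nd.random.uniform(shape=(8, 4))
    y = mx.nd.random.uniform(shape=(8, 1))
    loss_fn = gluon.loss.L2Loss()
    with autograd.record():
        loss = loss_fn(net(x), y)
    loss.backward()
    trainer.step(batch_size=8)  # every parameter receives one signum_update step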
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 33b7dd5fe5a8..c2564db0f079 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -66,6 +66,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
   }
 };
 
+
 struct SGDKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
@@ -228,6 +229,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
   }
 };
 
+
 struct SGDMomKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
@@ -1281,6 +1283,146 @@ inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+
+// Implementation for signSGD and Signum
+
+struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {
+  float lr;
+  float wd;
+  float rescale_grad;
+  float clip_gradient;
+  DMLC_DECLARE_PARAMETER(SignSGDParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient]. "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+  }
+};
+
+
+struct SignSGDKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
+    const DType* grad_data, const DType param_clip_gradient,
+    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
+    const OpReqType req) {
+
+    // param_clip_gradient has no effect for SignSGD
+    KERNEL_ASSIGN(out_data[i], req,
+             (1.f-param_lr*param_wd)*weight_data[i]
+               - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0)));
+  }
+};
+
+template<typename xpu>
+inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs,
+                          const OpContext &ctx,
+                          const std::vector<TBlob> &inputs,
+                          const std::vector<OpReqType> &req,
+                          const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
+      grad.dptr_, static_cast<DType>(param.clip_gradient),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.rescale_grad), req[0]);
+  });
+}
+
+
+struct SignumParam : public dmlc::Parameter<SignumParam> {
+  float lr;
+  float momentum;
+  float wd;
+  float rescale_grad;
+  float clip_gradient;
+  float wd_lh;  // the amount of decoupled weight decay, following Loshchilov and Hutter
+  DMLC_DECLARE_PARAMETER(SignumParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(momentum)
+    .set_default(0.0f)
+    .describe("The decay rate of momentum estimates at each epoch.");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient]. "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(wd_lh)
+    .set_default(0.0f)
+    .describe("The amount of weight decay that does not go into the gradient/momentum "
+              "calculations and is instead applied directly to the weights "
+              "(decoupled weight decay).");
+  }
+};
+
+struct SignumKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
+    const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
+    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
+    const DType param_wd_lh, const OpReqType req) {
+    if (param_clip_gradient >= 0.0f) {
+      mom_data[i] = param_momentum*mom_data[i]
+              - (1-param_momentum)*param_wd*weight_data[i]
+              - (1-param_momentum)
+                *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
+    } else {
+      mom_data[i] = param_momentum*mom_data[i]
+                - (1-param_momentum)*param_wd*weight_data[i]
+                - (1-param_momentum)*param_rescale_grad*grad_data[i];
+    }
+    KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i]
+      + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0)));
+  }
+};
+
+template<typename xpu>
+inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  SignumParam param = nnvm::get<SignumParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
+      grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]);
+  });
+}
+
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index dda809255dce..8760fe94a526 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -36,6 +36,67 @@ DMLC_REGISTER_PARAMETER(AdamParam);
 DMLC_REGISTER_PARAMETER(RMSPropParam);
 DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
+DMLC_REGISTER_PARAMETER(SignSGDParam);
+DMLC_REGISTER_PARAMETER(SignumParam);
+
+NNVM_REGISTER_OP(signsgd_update)
+.describe(R"code(Update function for SignSGD optimizer.
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ W_t = W_{t-1} - \eta_t \text{sign}(g_t)
+
+It updates the weights using::
+
+ weight = weight - learning_rate * sign(gradient)
+
+.. note::
+   - sparse ndarray not supported for this optimizer yet.
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignSGDParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FCompute>("FCompute", SignSGDUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_arguments(SignSGDParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(signum_update)
+.describe(R"code(SIGN momentUM (Signum) optimizer.
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ m_t = \beta m_{t-1} + (1 - \beta) g_t\\
+ W_t = W_{t-1} - \eta_t \text{sign}(m_t)
+
+It updates the weights using::
+
+ state = momentum * state + (1 - momentum) * gradient
+ weight = weight - learning_rate * sign(state)
+
+Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
+
+.. note::
+   - sparse ndarray not supported for this optimizer yet.
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignumParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2};
+  })
+.set_attr<FCompute>("FCompute", SignumUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mom", "NDArray-or-Symbol", "Momentum")
+.add_arguments(SignumParam::__FIELDS__());
+
 
 template<>
 void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 9512e92a80ec..891f24fe7935 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -94,6 +94,13 @@ void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
   });
 }
 
+
+NNVM_REGISTER_OP(signsgd_update)
+.set_attr<FCompute>("FCompute", SignSGDUpdate<gpu>);
+
+NNVM_REGISTER_OP(signum_update)
+.set_attr<FCompute>("FCompute", SignumUpdate<gpu>);
+
 NNVM_REGISTER_OP(sgd_update)
 .set_attr<FCompute>("FCompute", SGDUpdate<gpu>)
 .set_attr<FComputeEx>("FComputeEx", SGDUpdateEx<gpu>);
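With the CPU and GPU registrations in place, the two operators are exposed under ``mxnet.ndarray``, which is exactly how ``python/mxnet/optimizer.py`` imports them above. The direct-call sketch below is not part of the patch and uses arbitrary array values; it assumes a build that includes these operators. Per element, ``signsgd_update`` computes ``weight = (1 - lr*wd)*weight - lr*sign(grad)``, while ``signum_update`` first folds the weight-decay term and the rescaled (optionally clipped) gradient into the momentum buffer with a negative sign and then applies ``weight = (1 - lr*wd_lh)*weight + lr*sign(mom)``.

    import mxnet as mx

    weight = mx.nd.ones((3, 4))
    grad = mx.nd.ones((3, 4)) * 0.1
    mom = mx.nd.zeros((3, 4))

    # Momentum-free update: weight <- (1 - lr*wd)*weight - lr*sign(grad)
    mx.nd.signsgd_update(weight, grad, out=weight, lr=0.01, wd=0.0)

    # Signum update: mom <- momentum*mom - (1-momentum)*(wd*weight + rescale_grad*grad),
    # then weight <- (1 - lr*wd_lh)*weight + lr*sign(mom); mom is mutated in place
    # because input 2 is marked mutable via FMutateInputs.
    mx.nd.signum_update(weight, grad, mom, out=weight,
                        lr=0.01, wd=0.0, momentum=0.9, wd_lh=0.0)
    print(weight.asnumpy())  # started at 1.0, now 1.0 - 0.01 - 0.01 = 0.98

    # The same calls run unchanged on GPU arrays (e.g. created with ctx=mx.gpu(0))
    # if the patch is built with CUDA; that path is not exercised here.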
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index ae248b0d0bc7..2d22391879ce 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -524,6 +524,87 @@ def test_adam():
                 compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape, dtype,
                                   w_stype='row_sparse', g_stype='row_sparse')
 
+
+# Signum
+class PySignum(mx.optimizer.Optimizer):
+    """The python reference of the Signum optimizer.
+
+    The optimizer updates the weight by::
+
+        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
+        state = momentum * state + (1 - momentum) * rescaled_grad
+        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+    For details of the update algorithm see
+    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    wd_lh : float, optional
+       The amount of decoupled weight decay regularization.
+    """
+    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
+        super(PySignum, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.momentum = momentum
+        self.wd_lh = wd_lh
+
+    def create_state(self, index, weight):
+        momentum = None
+        if self.momentum != 0.0:
+            momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def update(self, index, weight, grad, state):
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        if state is not None:
+            mom = state
+            if self.clip_gradient is not None:
+                mom[:] = (self.momentum*mom - (1 - self.momentum)*(wd*weight +
+                          mx.nd.clip(grad*self.rescale_grad, -self.clip_gradient, self.clip_gradient)))
+            else:
+                mom[:] = (self.momentum*mom - (1 - self.momentum)*wd*weight
+                          - (1 - self.momentum)*self.rescale_grad*grad)
+            weight[:] = (1 - lr*self.wd_lh)*weight + lr*mx.nd.sign(mom)
+        else:
+            weight[:] = (1 - lr*(wd + self.wd_lh))*weight - lr*mx.nd.sign(grad)
+
+
+def test_signum():
+    mx.random.seed(0)
+    opt1 = PySignum
+    opt2 = mx.optimizer.Signum
+    shape = (3, 4, 5)
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    wd_lh_options = [{}, {'wd_lh': 0.015}, {'wd_lh': 0.0}]
+    mom_options = [{}, {'momentum': 0.9}]
+    lr_options = [{'learning_rate': 0.05}, {'learning_rate': 0.01}]
+    for dtype in [np.float32, np.float64]:
+        for cg_option in cg_options:
+            for rg_option in rg_options:
+                for wd_option in wd_options:
+                    for mp_option in wd_lh_options:
+                        for lr_option in lr_options:
+                            for mom_option in mom_options:
+                                kwarg = {}
+                                kwarg.update(cg_option)
+                                kwarg.update(rg_option)
+                                kwarg.update(wd_option)
+                                kwarg.update(mp_option)
+                                kwarg.update(lr_option)
+                                kwarg.update(mom_option)
+                                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+
+
 # RMSProp
 class PyRMSProp(mx.optimizer.Optimizer):
     """RMSProp optimizer of Tieleman & Hinton, 2012,