diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD
index c24cdadcd8..6563c35122 100644
--- a/tensorflow_addons/optimizers/BUILD
+++ b/tensorflow_addons/optimizers/BUILD
@@ -7,6 +7,7 @@ py_library(
     srcs = [
         "__init__.py",
         "conditional_gradient.py",
+        "cyclical_learning_rate.py",
         "lamb.py",
         "lazy_adam.py",
         "lookahead.py",
@@ -110,3 +111,16 @@ py_test(
         ":optimizers",
     ],
 )
+
+py_test(
+    name = "cyclical_learning_rate_test",
+    size = "small",
+    srcs = [
+        "cyclical_learning_rate_test.py",
+    ],
+    main = "cyclical_learning_rate_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":optimizers",
+    ],
+)
diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md
index 4898e20351..d16a53e35d 100644
--- a/tensorflow_addons/optimizers/README.md
+++ b/tensorflow_addons/optimizers/README.md
@@ -4,6 +4,7 @@
 | Submodule | Maintainers | Contact Info |
 |:---------- |:------------- |:--------------|
 | conditional_gradient | Pengyu Kan, Vishnu Lokhande | pkan2@wisc.edu, lokhande@cs.wisc.edu |
+| cyclical_learning_rate | Raphael Meudec | raphael.meudec@gmail.com |
 | lamb | Jing Li, Junjie Ke | jingli@google.com, junjiek@google.com |
 | lazy_adam | Saishruthi Swaminathan | saishruthi.tn@gmail.com |
 | lookahead | Zhao Hanguang | cyberzhg@gmail.com |
@@ -16,6 +17,7 @@
 | Submodule | Optimizer | Reference |
 |:--------- |:---------- |:---------|
 | conditional_gradient | ConditionalGradient | https://arxiv.org/pdf/1803.06453.pdf |
+| cyclical_learning_rate | Cyclical Learning Rate | https://arxiv.org/abs/1506.01186 |
 | lamb | LAMB | https://arxiv.org/abs/1904.00962 |
 | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
 | lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 |
diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py
index 76cbc0d172..ace4d1dce7 100644
--- a/tensorflow_addons/optimizers/__init__.py
+++ b/tensorflow_addons/optimizers/__init__.py
@@ -19,6 +19,14 @@ from __future__ import print_function
 
 from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient
+from tensorflow_addons.optimizers.cyclical_learning_rate import (
+    CyclicalLearningRate)
+from tensorflow_addons.optimizers.cyclical_learning_rate import (
+    TriangularCyclicalLearningRate)
+from tensorflow_addons.optimizers.cyclical_learning_rate import (
+    Triangular2CyclicalLearningRate)
+from tensorflow_addons.optimizers.cyclical_learning_rate import (
+    ExponentialCyclicalLearningRate)
 from tensorflow_addons.optimizers.lamb import LAMB
 from tensorflow_addons.optimizers.lazy_adam import LazyAdam
 from tensorflow_addons.optimizers.lookahead import Lookahead
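Once the `__init__.py` change above lands, all four schedules are importable from the package root. A minimal usage sketch, assuming the standard `tfa` alias for `tensorflow_addons` and a toy one-layer model (not part of this patch):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Triangular policy: the learning rate ramps from 1e-4 up to 1e-2 over
# 2000 steps, then back down over the next 2000 steps, and repeats.
lr_schedule = tfa.optimizers.TriangularCyclicalLearningRate(
    initial_learning_rate=1e-4,
    maximal_learning_rate=1e-2,
    step_size=2000)

# Any schedule can be passed to a Keras optimizer as the learning rate.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation="softmax")])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```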
diff --git a/tensorflow_addons/optimizers/cyclical_learning_rate.py b/tensorflow_addons/optimizers/cyclical_learning_rate.py
new file mode 100644
index 0000000000..852362c0ee
--- /dev/null
+++ b/tensorflow_addons/optimizers/cyclical_learning_rate.py
@@ -0,0 +1,293 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cyclical Learning Rate Schedule policies for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class CyclicalLearningRate(tf.keras.optimizers.schedules.LearningRateSchedule):
+    """A LearningRateSchedule that uses a cyclical schedule."""
+
+    def __init__(
+            self,
+            initial_learning_rate,
+            maximal_learning_rate,
+            step_size,
+            scale_fn,
+            scale_mode="cycle",
+            name=None,
+    ):
+        """Applies a cyclical schedule to the learning rate.
+
+        See Cyclical Learning Rates for Training Neural Networks.
+        https://arxiv.org/abs/1506.01186
+
+        ```python
+        import tensorflow_addons as tfa
+
+        lr_schedule = tfa.optimizers.CyclicalLearningRate(
+            initial_learning_rate=1e-4,
+            maximal_learning_rate=1e-2,
+            step_size=2000,
+            scale_fn=lambda x: 1.,
+            scale_mode="cycle",
+            name="MyCyclicScheduler")
+
+        model.compile(optimizer=tf.keras.optimizers.SGD(
+            learning_rate=lr_schedule),
+            loss='sparse_categorical_crossentropy',
+            metrics=['accuracy'])
+
+        model.fit(data, labels, epochs=5)
+        ```
+
+        You can pass this schedule directly into a
+        `tf.keras.optimizers.Optimizer` as the learning rate.
+
+        Args:
+            initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
+                a Python number. The initial learning rate.
+            maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
+                a Python number. The maximum learning rate.
+            step_size: A scalar `float32` or `float64` `Tensor` or a Python
+                number. The number of iterations in half a cycle.
+            scale_fn: A function that scales the cycle amplitude; it receives
+                the cycle index or the iteration count (depending on
+                `scale_mode`) and returns a multiplicative factor.
+            scale_mode: Either 'cycle' or 'iterations'. Whether `scale_fn` is
+                evaluated on the cycle index or on the iteration count.
+            name: (Optional) Name for the operation.
+        """
+        super(CyclicalLearningRate, self).__init__()
+        self.initial_learning_rate = initial_learning_rate
+        self.maximal_learning_rate = maximal_learning_rate
+        self.step_size = step_size
+        self.scale_fn = scale_fn
+        self.scale_mode = scale_mode
+        self.name = name
+
+    def __call__(self, step):
+        with tf.name_scope(self.name or "CyclicalLearningRate"):
+            initial_learning_rate = tf.convert_to_tensor(
+                self.initial_learning_rate, name="initial_learning_rate")
+            dtype = initial_learning_rate.dtype
+            maximal_learning_rate = tf.cast(self.maximal_learning_rate, dtype)
+            step_size = tf.cast(self.step_size, dtype)
+            # Cast the step so integer step counters divide cleanly below.
+            step_as_dtype = tf.cast(step, dtype)
+            cycle = tf.floor(1 + step_as_dtype / (2 * step_size))
+            x = tf.abs(step_as_dtype / step_size - 2 * cycle + 1)
+
+            # scale_fn sees the cycle index or the iteration, per scale_mode.
+            mode_step = cycle if self.scale_mode == "cycle" else step_as_dtype
+
+            return initial_learning_rate + (
+                maximal_learning_rate - initial_learning_rate) * tf.maximum(
+                    tf.cast(0, dtype), (1 - x)) * self.scale_fn(mode_step)
+
+    def get_config(self):
+        return {
+            "initial_learning_rate": self.initial_learning_rate,
+            "maximal_learning_rate": self.maximal_learning_rate,
+            "step_size": self.step_size,
+            "scale_fn": self.scale_fn,
+            "scale_mode": self.scale_mode,
+        }
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class TriangularCyclicalLearningRate(CyclicalLearningRate):
+    def __init__(
+            self,
+            initial_learning_rate,
+            maximal_learning_rate,
+            step_size,
+            scale_mode="cycle",
+            name="TriangularCyclicalLearningRate",
+    ):
+        """Applies a triangular cyclical schedule to the learning rate.
+
+        See Cyclical Learning Rates for Training Neural Networks.
+        https://arxiv.org/abs/1506.01186
+
+        ```python
+        import tensorflow_addons as tfa
+
+        lr_schedule = tfa.optimizers.TriangularCyclicalLearningRate(
+            initial_learning_rate=1e-4,
+            maximal_learning_rate=1e-2,
+            step_size=2000,
+            scale_mode="cycle",
+            name="MyCyclicScheduler")
+
+        model.compile(optimizer=tf.keras.optimizers.SGD(
+            learning_rate=lr_schedule),
+            loss='sparse_categorical_crossentropy',
+            metrics=['accuracy'])
+
+        model.fit(data, labels, epochs=5)
+        ```
+
+        You can pass this schedule directly into a
+        `tf.keras.optimizers.Optimizer` as the learning rate.
+
+        Args:
+            initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
+                a Python number. The initial learning rate.
+            maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
+                a Python number. The maximum learning rate.
+            step_size: A scalar `float32` or `float64` `Tensor` or a Python
+                number. The number of iterations in half a cycle.
+            scale_mode: Either 'cycle' or 'iterations'. Whether the amplitude
+                scaling is applied per cycle or per iteration.
+            name: (Optional) Name for the operation.
+        """
+        super(TriangularCyclicalLearningRate, self).__init__(
+            initial_learning_rate=initial_learning_rate,
+            maximal_learning_rate=maximal_learning_rate,
+            step_size=step_size,
+            scale_fn=lambda x: 1.,
+            scale_mode=scale_mode,
+            name=name,
+        )
+
+    def get_config(self):
+        # scale_fn is fixed to 1 for this policy, so it is not serialized.
+        return {
+            "initial_learning_rate": self.initial_learning_rate,
+            "maximal_learning_rate": self.maximal_learning_rate,
+            "step_size": self.step_size,
+            "scale_mode": self.scale_mode,
+        }
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class Triangular2CyclicalLearningRate(CyclicalLearningRate):
+    def __init__(
+            self,
+            initial_learning_rate,
+            maximal_learning_rate,
+            step_size,
+            scale_mode="cycle",
+            name="Triangular2CyclicalLearningRate",
+    ):
+        """Applies a triangular2 cyclical schedule to the learning rate.
+
+        See Cyclical Learning Rates for Training Neural Networks.
+        https://arxiv.org/abs/1506.01186
+
+        ```python
+        import tensorflow_addons as tfa
+
+        lr_schedule = tfa.optimizers.Triangular2CyclicalLearningRate(
+            initial_learning_rate=1e-4,
+            maximal_learning_rate=1e-2,
+            step_size=2000,
+            scale_mode="cycle",
+            name="MyCyclicScheduler")
+
+        model.compile(optimizer=tf.keras.optimizers.SGD(
+            learning_rate=lr_schedule),
+            loss='sparse_categorical_crossentropy',
+            metrics=['accuracy'])
+
+        model.fit(data, labels, epochs=5)
+        ```
+
+        You can pass this schedule directly into a
+        `tf.keras.optimizers.Optimizer` as the learning rate.
+
+        Args:
+            initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
+                a Python number. The initial learning rate.
+            maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
+                a Python number. The maximum learning rate.
+            step_size: A scalar `float32` or `float64` `Tensor` or a Python
+                number. The number of iterations in half a cycle.
+            scale_mode: Either 'cycle' or 'iterations'. Whether the amplitude
+                scaling is applied per cycle or per iteration.
+            name: (Optional) Name for the operation.
+        """
+        super(Triangular2CyclicalLearningRate, self).__init__(
+            initial_learning_rate=initial_learning_rate,
+            maximal_learning_rate=maximal_learning_rate,
+            step_size=step_size,
+            scale_fn=lambda x: 1 / (2.**(x - 1)),
+            scale_mode=scale_mode,
+            name=name,
+        )
+
+    def get_config(self):
+        # scale_fn is fixed (halved each cycle), so it is not serialized.
+        return {
+            "initial_learning_rate": self.initial_learning_rate,
+            "maximal_learning_rate": self.maximal_learning_rate,
+            "step_size": self.step_size,
+            "scale_mode": self.scale_mode,
+        }
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class ExponentialCyclicalLearningRate(CyclicalLearningRate):
+    def __init__(
+            self,
+            initial_learning_rate,
+            maximal_learning_rate,
+            step_size,
+            scale_mode="iterations",
+            gamma=1.,
+            name="ExponentialCyclicalLearningRate",
+    ):
+        """Applies an exponential cyclical schedule to the learning rate.
+
+        See Cyclical Learning Rates for Training Neural Networks.
+        https://arxiv.org/abs/1506.01186
+
+        ```python
+        import tensorflow_addons as tfa
+
+        lr_schedule = tfa.optimizers.ExponentialCyclicalLearningRate(
+            initial_learning_rate=1e-4,
+            maximal_learning_rate=1e-2,
+            step_size=2000,
+            scale_mode="cycle",
+            gamma=0.96,
+            name="MyCyclicScheduler")
+
+        model.compile(optimizer=tf.keras.optimizers.SGD(
+            learning_rate=lr_schedule),
+            loss='sparse_categorical_crossentropy',
+            metrics=['accuracy'])
+
+        model.fit(data, labels, epochs=5)
+        ```
+
+        You can pass this schedule directly into a
+        `tf.keras.optimizers.Optimizer` as the learning rate.
+
+        Args:
+            initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
+                a Python number. The initial learning rate.
+            maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
+                a Python number. The maximum learning rate.
+            step_size: A scalar `float32` or `float64` `Tensor` or a Python
+                number. The number of iterations in half a cycle.
+            scale_mode: Either 'cycle' or 'iterations'. Whether the amplitude
+                scaling is applied per cycle or per iteration.
+            gamma: A scalar `float32` or `float64` `Tensor` or a Python
+                number. The constant of the exponential decay applied through
+                `scale_fn(x) = gamma**x`.
+            name: (Optional) Name for the operation.
+        """
+        super(ExponentialCyclicalLearningRate, self).__init__(
+            initial_learning_rate=initial_learning_rate,
+            maximal_learning_rate=maximal_learning_rate,
+            step_size=step_size,
+            scale_fn=lambda x: gamma**x,
+            scale_mode=scale_mode,
+            name=name,
+        )
+        self.gamma = gamma
+
+    def get_config(self):
+        # scale_fn is reconstructed from gamma, so it is not serialized.
+        return {
+            "initial_learning_rate": self.initial_learning_rate,
+            "maximal_learning_rate": self.maximal_learning_rate,
+            "step_size": self.step_size,
+            "scale_mode": self.scale_mode,
+            "gamma": self.gamma,
+        }
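Before moving on to the tests, it may help to trace the triangular waveform that `__call__` implements. The following is a small NumPy sketch of the same arithmetic with illustrative values, not part of the patch:

```python
import numpy as np

initial_lr, maximal_lr, step_size = 0.1, 1.0, 4000.0

for step in [0.0, 2000.0, 4000.0, 6000.0, 8000.0]:
    cycle = np.floor(1 + step / (2 * step_size))  # 1-based cycle index
    x = np.abs(step / step_size - 2 * cycle + 1)  # 1 at cycle ends, 0 at peak
    lr = initial_lr + (maximal_lr - initial_lr) * np.maximum(0.0, 1 - x)
    print(int(step), round(lr, 4))
# -> 0 0.1, 2000 0.55, 4000 1.0, 6000 0.55, 8000 0.1
```

With `scale_fn` equal to `lambda x: 1.` this is exactly the triangular policy; the triangular2 and exponential variants only rescale the `(1 - x)` term per cycle or per iteration.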
diff --git a/tensorflow_addons/optimizers/cyclical_learning_rate_test.py b/tensorflow_addons/optimizers/cyclical_learning_rate_test.py
new file mode 100644
index 0000000000..fbb451765d
--- /dev/null
+++ b/tensorflow_addons/optimizers/cyclical_learning_rate_test.py
@@ -0,0 +1,162 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Cyclical Learning Rate."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow_addons.optimizers import cyclical_learning_rate
+from tensorflow_addons.utils import test_utils
+
+
+def _maybe_serialized(lr_decay, serialize_and_deserialize):
+    if serialize_and_deserialize:
+        serialized = tf.keras.optimizers.schedules.serialize(lr_decay)
+        return tf.keras.optimizers.schedules.deserialize(serialized)
+    else:
+        return lr_decay
+
+
+@parameterized.named_parameters(("NotSerialized", False), ("Serialized", True))
+class CyclicalLearningRateTest(tf.test.TestCase, parameterized.TestCase):
+    @test_utils.run_in_graph_and_eager_modes(reset_test=True)
+    def testTriangularCyclicalLearningRate(self, serialize):
+        initial_learning_rate = 0.1
+        maximal_learning_rate = 1
+        step_size = 4000
+        step = tf.Variable(0)
+        triangular_cyclical_lr = (
+            cyclical_learning_rate.TriangularCyclicalLearningRate(
+                initial_learning_rate=initial_learning_rate,
+                maximal_learning_rate=maximal_learning_rate,
+                step_size=step_size,
+            ))
+        triangular_cyclical_lr = _maybe_serialized(triangular_cyclical_lr,
+                                                   serialize)
+
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+        expected = np.concatenate([
+            np.linspace(initial_learning_rate, maximal_learning_rate,
+                        num=step_size + 1),
+            np.linspace(maximal_learning_rate, initial_learning_rate,
+                        num=step_size + 1)[1:]
+        ])
+
+        for expected_value in expected:
+            self.assertAllClose(
+                self.evaluate(triangular_cyclical_lr(step)), expected_value,
+                1e-6)
+            self.evaluate(step.assign_add(1))
+
+    @test_utils.run_in_graph_and_eager_modes(reset_test=True)
+    def testTriangular2CyclicalLearningRate(self, serialize):
+        initial_learning_rate = 0.1
+        maximal_learning_rate = 1
+        step_size = 4000
+        step = tf.Variable(0)
+        triangular2_cyclical_lr = (
+            cyclical_learning_rate.Triangular2CyclicalLearningRate(
+                initial_learning_rate=initial_learning_rate,
+                maximal_learning_rate=maximal_learning_rate,
+                step_size=step_size,
+            ))
+        triangular2_cyclical_lr = _maybe_serialized(triangular2_cyclical_lr,
+                                                    serialize)
+
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+        middle_learning_rate = (
+            maximal_learning_rate + initial_learning_rate) / 2
+        expected = np.concatenate([
+            np.linspace(initial_learning_rate, maximal_learning_rate,
+                        num=step_size + 1),
+            np.linspace(maximal_learning_rate, initial_learning_rate,
+                        num=step_size + 1)[1:],
+            np.linspace(initial_learning_rate, middle_learning_rate,
+                        num=step_size + 1)[1:],
+            np.linspace(middle_learning_rate, initial_learning_rate,
+                        num=step_size + 1)[1:],
+        ])
+
+        for expected_value in expected:
+            self.assertAllClose(
+                self.evaluate(triangular2_cyclical_lr(step)), expected_value,
+                1e-6)
+            self.evaluate(step.assign_add(1))
+
+    @test_utils.run_in_graph_and_eager_modes(reset_test=True)
+    def testExponentialCyclicalLearningRate(self, serialize):
+        initial_learning_rate = 0.1
+        maximal_learning_rate = 1
+        step_size = 4000
+        gamma = 0.996
+
+        step = tf.Variable(0)
+        exponential_cyclical_lr = (
+            cyclical_learning_rate.ExponentialCyclicalLearningRate(
+                initial_learning_rate=initial_learning_rate,
+                maximal_learning_rate=maximal_learning_rate,
+                step_size=step_size,
+                gamma=gamma,
+            ))
+        exponential_cyclical_lr = _maybe_serialized(exponential_cyclical_lr,
+                                                    serialize)
+
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+
+        for i in range(8001):
+            non_bounded_value = np.abs(i / 4000. -
+                                       2 * np.floor(1 + i / (2 * 4000)) + 1)
+            expected = initial_learning_rate + (
+                maximal_learning_rate - initial_learning_rate) * np.maximum(
+                    0, (1 - non_bounded_value)) * (gamma**i)
+            self.assertAllClose(
+                self.evaluate(exponential_cyclical_lr(step)), expected, 1e-6)
+            self.evaluate(step.assign_add(1))
+
+    @test_utils.run_in_graph_and_eager_modes(reset_test=True)
+    def testCustomCyclicalLearningRate(self, serialize):
+        initial_learning_rate = 0.1
+        maximal_learning_rate = 1
+        step_size = 4000
+        scale_fn = lambda x: 1 / (5**(x * 0.0001))
+
+        step = tf.Variable(0)
+        custom_cyclical_lr = cyclical_learning_rate.CyclicalLearningRate(
+            initial_learning_rate=initial_learning_rate,
+            maximal_learning_rate=maximal_learning_rate,
+            step_size=step_size,
+            scale_fn=scale_fn,
+            scale_mode="iterations",
+        )
+        custom_cyclical_lr = _maybe_serialized(custom_cyclical_lr, serialize)
+
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+
+        for i in range(8001):
+            non_bounded_value = np.abs(i / 4000. -
+                                       2 * np.floor(1 + i / (2 * 4000)) + 1)
+            expected = initial_learning_rate + (
+                maximal_learning_rate - initial_learning_rate) * np.maximum(
+                    0, 1 - non_bounded_value) * scale_fn(i)
+            self.assertAllClose(
+                self.evaluate(custom_cyclical_lr(step)), expected, 1e-6)
+            self.evaluate(step.assign_add(1))
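As a quick sanity check outside the test harness, the schedules can be evaluated eagerly and round-tripped through the public Keras serialization helpers, which is the same path `_maybe_serialized` exercises above. A sketch, assuming TF 2.x eager execution and the `tfa` alias; the gamma value is illustrative:

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# With gamma close to 1 the cycle amplitude decays slowly per iteration.
lr_schedule = tfa.optimizers.ExponentialCyclicalLearningRate(
    initial_learning_rate=0.1,
    maximal_learning_rate=1.0,
    step_size=4000,
    gamma=0.9999)

# ~0.1 at step 0, then the first (already decayed) peak at step 4000.
print(float(lr_schedule(0)), float(lr_schedule(4000)))

# Round-trip through the Keras schedule (de)serialization helpers;
# register_keras_serializable makes the 'Addons>...' class name resolvable.
config = tf.keras.optimizers.schedules.serialize(lr_schedule)
restored = tf.keras.optimizers.schedules.deserialize(config)
assert np.isclose(float(restored(123)), float(lr_schedule(123)))
```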