From 67c1e8ea19e82c3f2a5706674dd81f15ab5002a2 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Wed, 21 Jul 2021 08:10:30 +0000
Subject: [PATCH] fix iterations

---
 .../optimizers/gradient_accumulator.py        | 100 ++++++++----------
 .../tests/gradient_accumulator_test.py        |  50 +++------
 2 files changed, 60 insertions(+), 90 deletions(-)

diff --git a/tensorflow_addons/optimizers/gradient_accumulator.py b/tensorflow_addons/optimizers/gradient_accumulator.py
index 268e019454..57051f8e9e 100644
--- a/tensorflow_addons/optimizers/gradient_accumulator.py
+++ b/tensorflow_addons/optimizers/gradient_accumulator.py
@@ -49,20 +49,13 @@ def __init__(
         super().__init__(name, **kwargs)
         self._optimizer = tf.keras.optimizers.get(inner_optimizer)
         self._step = None
-        self._gradients = {}
         self._accum_steps = accum_steps
         self._reduction = reduction
 
         def _accum_grad(grads_and_vars):
-            with tf.init_scope():
-                if not self._gradients:
-                    for grad, var in grads_and_vars:
-                        self._gradients[var.ref()] = tf.Variable(
-                            tf.zeros_like(var), trainable=False
-                        )
             new_grads_and_vars = []
             for grad, var in grads_and_vars:
-                handle = self._gradients[var.ref()]
+                handle = self.get_slot(var, "ga")
 
                 if isinstance(grad, tf.IndexedSlices):
                     handle.scatter_add(grad)
@@ -84,9 +77,11 @@ def _get_grad():
                         values = tf.gather(new_grad, indices)
                         dense_shape = tf.constant(new_grad.shape.as_list())
                         handle.assign(
-                            tf.zeros_like(handle), use_locking=self._use_locking
+                            tf.zeros_like(handle),
+                            use_locking=self._use_locking,
+                            read_value=False,
                         )
-                        return values, tf.cast(indices, tf.int32), dense_shape
+                        return values, tf.cast(indices, grad.indices.dtype), dense_shape
 
                     values, indices, dense_shape = tf.cond(
                         self.step % self._accum_steps == 0,
@@ -100,14 +95,18 @@ def _get_grad():
                     new_grad = tf.IndexedSlices(values, indices, dense_shape)
                     new_grads_and_vars.append((new_grad, var))
                 else:
-                    handle.assign_add(grad)
+                    handle.assign_add(
+                        grad, use_locking=self._use_locking, read_value=False
+                    )
 
                     def _get_grad():
                         new_grad = handle.read_value()
                         if self._reduction == "MEAN":
                             new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
                         handle.assign(
-                            tf.zeros_like(handle), use_locking=self._use_locking
+                            tf.zeros_like(handle),
+                            use_locking=self._use_locking,
+                            read_value=False,
                         )
                         return new_grad
 
@@ -119,11 +118,39 @@ def _get_grad():
                     new_grads_and_vars.append((new_grad, var))
             return new_grads_and_vars
 
-        self._optimizer.gradient_transformers.append(_accum_grad)
+        self.gradient_transformers.append(_accum_grad)
         self._iterations = self._optimizer.iterations
 
     def _create_slots(self, var_list):
         self._optimizer._create_slots(var_list=var_list)
+        for var in var_list:
+            self.add_slot(var, "ga")
+
+    def _resource_apply_dense(self, grad, handle, apply_state):
+        if "apply_state" in self._optimizer._dense_apply_args:
+            return self.inner_optimizer._resource_apply_dense(grad, handle, apply_state)
+        else:
+            return self.inner_optimizer._resource_apply_dense(grad, handle)
+
+    def _resource_apply_sparse(self, grad, handle, indices, apply_state):
+        if "apply_state" in self._optimizer._sparse_apply_args:
+            return self.inner_optimizer._resource_apply_sparse(
+                grad, handle, indices, apply_state=apply_state
+            )
+        else:
+            return self.inner_optimizer._resource_apply_sparse(grad, handle, indices)
+
+    def _resource_apply_sparse_duplicate_indices(
+        self, grad, handle, indices, apply_state=None
+    ):
+        if "apply_state" in self._optimizer._sparse_apply_args:
+            return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
+                grad, handle, indices, apply_state=apply_state
+            )
+        else:
+            return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
+                grad, handle, indices
+            )
 
     @property
     def step(self):
@@ -133,7 +160,6 @@ def step(self):
             self._step = self.add_weight(
                 "iter",
                 shape=[],
-                initializer="ones",
                 dtype=tf.int64,
                 trainable=False,
                 aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
@@ -151,49 +177,15 @@ def step(self, variable):
         self._step = variable
         self._weights.append(self._step)
 
-    @property
-    def gradients(self):
-        """The accumulated gradients on the current replica."""
-        if not self._gradients:
-            raise ValueError(
-                "The accumulator should be called first to initialize the gradients"
-            )
-        return list(
-            gradient.read_value() if gradient is not None else gradient
-            for _, gradient in self._gradients
-        )
-
     def apply_gradients(self, grads_and_vars, name=None, **kwargs):
-        train_op = self._optimizer.apply_gradients(grads_and_vars, name, **kwargs)
-        with tf.control_dependencies([train_op]):
-            with tf.control_dependencies(
-                [
-                    self._optimizer.iterations.assign_add(
-                        tf.cast(self.step % self._accum_steps == 0, tf.int64),
-                        read_value=False,
-                    )
-                ]
-            ):
-                return self.step.assign_add(1, read_value=False)
-
-    def reset(self):
-        """Resets the accumulated gradients on the current replica."""
-        assign_ops = []
-        if not self._gradients:
-            return assign_ops
-
-        for _, gradient in self._gradients:
-            if gradient is not None:
-                assign_ops.append(
-                    gradient.assign(
-                        tf.zeros_like(gradient),
-                        use_locking=self._use_locking,
-                        read_value=False,
-                    )
+        with tf.control_dependencies([self.step.assign_add(1, read_value=False)]):
+            train_op = super().apply_gradients(grads_and_vars, name, **kwargs)
+            with tf.control_dependencies([train_op]):
+                return self.iterations.assign_sub(
+                    tf.cast(self.step % self._accum_steps != 0, tf.int64),
+                    read_value=False,
                 )
-        return tf.group(assign_ops)
-
     @property
     def inner_optimizer(self):
         """The optimizer that this LossScaleOptimizer is wrapping."""
diff --git a/tensorflow_addons/optimizers/tests/gradient_accumulator_test.py b/tensorflow_addons/optimizers/tests/gradient_accumulator_test.py
index 7fe4171edd..18d8d890f1 100644
--- a/tensorflow_addons/optimizers/tests/gradient_accumulator_test.py
+++ b/tensorflow_addons/optimizers/tests/gradient_accumulator_test.py
@@ -17,12 +17,12 @@
 import numpy as np
 import pytest
 import tensorflow as tf
-from tensorflow_addons.utils import test_utils
 
 from tensorflow_addons.optimizers import GradientAccumulator
 
 
 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
 def test_run():
     var0 = tf.Variable([1.0, 2.0])
     var1 = tf.Variable([3.0, 4.0])
@@ -35,14 +35,18 @@
     opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0), accum_steps)
 
+    strategy = tf.distribute.get_strategy()
     for _ in range(accum_steps + 1):
-        opt.apply_gradients(grads_and_vars)
+        strategy.run(opt.apply_gradients, [grads_and_vars])
 
     np.testing.assert_allclose(var0.read_value(), [0.6, 1.6])
     np.testing.assert_allclose(var1.read_value(), [2.96, 3.96])
+    np.testing.assert_allclose(opt.iterations.read_value(), 1)
+    np.testing.assert_allclose(opt.step.read_value(), accum_steps + 1)
 
 
 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
 def test_sparse():
     var0 = tf.Variable([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]])
     var1 = tf.Variable([[3.0, 4.0, 0.0]])
@@ -60,38 +64,13 @@ def test_sparse():
     grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
     opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0))
+    strategy = tf.distribute.get_strategy()
     for _ in range(8):
-        opt.apply_gradients(grads_and_vars)
+        strategy.run(opt.apply_gradients, [grads_and_vars])
 
     np.testing.assert_allclose(var0.read_value(), [[1.0, 2.0, 0.0], [0.2, 1.2, 0.0]])
     np.testing.assert_allclose(var1.read_value(), [[2.92, 3.92, 0.0]])
 
 
-@pytest.mark.usefixtures("maybe_run_functions_eagerly")
-@pytest.mark.needs_gpu
-def test_sparse_multi_gpus():
-    strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing())
-    with strategy.scope():
-        var0 = tf.Variable([[1.0, 2.0, 0.0]])
-        var1 = tf.Variable([[3.0, 4.0, 0.0]])
-
-        grads0 = tf.IndexedSlices(
-            tf.constant([[0.1, 0.1, 0.0]]),
-            tf.constant([0]),
-            tf.constant([1, 3]),
-        )
-        grads1 = tf.IndexedSlices(
-            tf.constant([[0.01, 0.01, 0.0]]),
-            tf.constant([0]),
-            tf.constant([1, 3]),
-        )
-
-        grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
-        opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0))
-    strategy.run(opt.apply_gradients, [grads_and_vars])
-    np.testing.assert_allclose(var0.read_value(), [[1.0, 2.0, 0.0]])
-    np.testing.assert_allclose(var1.read_value(), [[3.0, 4.0, 0.0]])
-
-
 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
 def test_dense():
     grad = tf.Variable([[0.1]])
@@ -133,7 +112,7 @@ def test_config():
 
 
 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
-@pytest.mark.needs_gpu
+@pytest.mark.with_device([tf.distribute.MirroredStrategy])
 def test_fit_simple_linear_model():
     seed = 0x2019
     np.random.seed(seed)
@@ -142,13 +121,12 @@
     x = np.random.standard_normal((num_examples, 3))
     w = np.random.standard_normal((3, 1))
     y = np.dot(x, w) + np.random.standard_normal((num_examples, 1)) * 1e-4
-    strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing())
-    with strategy.scope():
-        model = tf.keras.models.Sequential()
-        model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
 
-        opt = GradientAccumulator("sgd")
-        model.compile(opt, loss="mse")
+    model = tf.keras.models.Sequential()
+    model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
+
+    opt = GradientAccumulator("sgd")
+    model.compile(opt, loss="mse")
 
     model.fit(x, y, epochs=5)
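
Usage note (not part of the patch): the sketch below assumes a TensorFlow Addons build that already contains this change. It mirrors what the updated `test_run` asserts: after `accum_steps + 1` calls to `apply_gradients`, the wrapper's `step` counter equals `accum_steps + 1`, while the inner optimizer's `iterations` has advanced only once, since `apply_gradients` now increments `step` first and then subtracts from `iterations` on every non-boundary step.

```python
# Minimal sketch, assuming tensorflow-addons with this patch applied.
import tensorflow as tf
from tensorflow_addons.optimizers import GradientAccumulator

var = tf.Variable([1.0, 2.0])
grad = tf.constant([0.1, 0.1])
accum_steps = 4

# Gradients are accumulated into the per-variable "ga" slots and are only
# applied to `var` once every `accum_steps` calls.
opt = GradientAccumulator(tf.keras.optimizers.SGD(learning_rate=1.0), accum_steps)

for _ in range(accum_steps + 1):
    opt.apply_gradients([(grad, var)])

# Same bookkeeping the updated test_run checks: the inner optimizer has
# stepped once, while the accumulator counted every call.
assert int(opt.iterations) == 1
assert int(opt.step) == accum_steps + 1
```

Under a `tf.distribute` strategy, wrap the `apply_gradients` call in `strategy.run(...)`, as the updated tests do.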