From 063d6ae3c57af94e9d0bb4457ef949ebc6c74953 Mon Sep 17 00:00:00 2001 From: holldean <644564286@qq.com> Date: Tue, 28 Apr 2020 10:29:48 +0800 Subject: [PATCH] Add optimization_flop.py --- flop/factorize.py | 5 +- flop/layers.py | 66 ++++----------- flop/modeling_flop.py | 126 +++++++++++++++++++-------- flop/nn.py | 3 +- flop/optimization_flop.py | 174 ++++++++++++++++++++++++++++++++++++++ main.ipynb | 12 +-- 6 files changed, 288 insertions(+), 98 deletions(-) create mode 100644 flop/optimization_flop.py diff --git a/flop/factorize.py b/flop/factorize.py index 52433b2..c32f104 100644 --- a/flop/factorize.py +++ b/flop/factorize.py @@ -1,7 +1,6 @@ from tensorflow.python import pywrap_tensorflow import tensorflow as tf import numpy as np -import factorize import copy import os import re @@ -43,13 +42,13 @@ def save_factorized_model(bert_config_file, init_checkpoint, output_dir): tvar_names.append(var.name) for key in var_to_shape_map: if re.match(bias_pattern, key): - q = factorize.bias_map(key) + q = bias_map(key) q_var = [v for v in tvar if v.name == q][0] tf.logging.info("Tensor: %s %s", q, "*INIT_FROM_CKPT*") sess.run(tf.assign(q_var, reader.get_tensor(key))) tvar_names.remove(q) elif re.match(kernel_pattern, key): - p, q = factorize.kernel_map(key) + p, q = kernel_map(key) p_var = [v for v in tvar if v.name == p][0] q_var = [v for v in tvar if v.name == q][0] u, s, v = np.linalg.svd(reader.get_tensor(key)) diff --git a/flop/layers.py b/flop/layers.py index 3eaaca5..8f2e4b7 100644 --- a/flop/layers.py +++ b/flop/layers.py @@ -20,11 +20,12 @@ from __future__ import print_function import tensorflow as tf +import math import common import nn from tensorflow.python.layers import base # pylint: disable=g-direct-tensorflow-import -from tensorflow.contrib.layers.python.layers import utils as layer_utils +# from tensorflow.contrib.layers.python.layers import utils as layer_utils from tensorflow.python.ops import variables as tf_variables # pylint: disable=g-direct-tensorflow-import @@ -35,13 +36,8 @@ class FlopFullyConnected(base.Layer): """Base implementation of a fully connected layer with FLOP. Args: x: Input, float32 tensor. - num_outputs: Int representing size of output tensor. - activation: If None, a linear activation is used. - bias_initializer: Initalizer of the bias vector. - bias_regularizer: Optional regularizer for the bias vector. log_alpha_initializer: Specified initializer of the log_alpha term. is_training: Boolean specifying whether it is training or eval. - use_bias: Boolean specifying whether bias vector should be used. eps: Small epsilon value to prevent math op saturation. beta: The beta parameter, which controls the "temperature" of the distribution. Defaults to 2/3 from the above paper. 
@@ -55,19 +51,15 @@ class FlopFullyConnected(base.Layer): """ def __init__(self, - num_outputs, - activation, - bias_initializer, - bias_regularizer, - log_alpha_initializer, activity_regularizer=None, is_training=True, trainable=True, - use_bias=True, - eps=common.EPSILON, - beta=common.BETA, - limit_l=common.LIMIT_L, - limit_r=common.LIMIT_R, + init_mean=0.5, + init_std=0.01, + eps=1e-6, + beta=1.0, + limit_l=-0.1, + limit_r=1.1, name="flop_mask", **kwargs): super(FlopFullyConnected, self).__init__( @@ -75,13 +67,9 @@ def __init__(self, name=name, activity_regularizer=activity_regularizer, **kwargs) - self.num_outputs = num_outputs - self.activation = activation - self.bias_initializer = bias_initializer - self.bias_regularizer = bias_regularizer - self.log_alpha_initializer = log_alpha_initializer self.is_training = is_training - self.use_bias = use_bias + self.init_mean = init_mean + self.init_std = init_std self.eps = eps self.beta = beta self.limit_l = limit_l @@ -90,37 +78,23 @@ def __init__(self, def build(self, input_shape): input_shape = input_shape.as_list() - assert input_shape[0] == input_shape[1] - input_hidden_size = input_shape[1] - diag_size = input_shape[0] - if not self.log_alpha_initializer: - # default log alpha set s.t. \alpha / (\alpha + 1) = .1 - self.log_alpha_initializer = tf.random_normal_initializer( - mean=2.197, stddev=0.01, dtype=self.dtype) + mean = math.log(1 - self.init_mean) - math.log(self.init_mean) + self.log_alpha_initializer = tf.random_normal_initializer( + mean=mean, stddev=self.init_std, dtype=self.dtype) self.log_alpha = tf.get_variable( "log_alpha", - shape=diag_size, + shape=input_hidden_size, initializer=self.log_alpha_initializer, dtype=self.dtype, trainable=True) - layer_utils.add_variable_to_collection( - self.log_alpha, - [THETA_LOGALPHA_COLLECTION], None) - - if self.use_bias: - self.bias = self.add_variable( - name="bias", - shape=(self.num_outputs,), - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - trainable=True, - dtype=self.dtype) - else: - self.bias = None + # layer_utils.add_variable_to_collection( + # self.log_alpha, + # [THETA_LOGALPHA_COLLECTION], None) + self.built = True def call(self, inputs): @@ -139,10 +113,6 @@ def call(self, inputs): limit_l=self.limit_l, limit_r=self.limit_r) - if self.use_bias: - x = tf.nn.bias_add(x, self.bias) - if self.activation is not None: - return self.activation(x) return x diff --git a/flop/modeling_flop.py b/flop/modeling_flop.py index d6dd790..14262c0 100644 --- a/flop/modeling_flop.py +++ b/flop/modeling_flop.py @@ -1,4 +1,5 @@ from modeling import * +import layers class BertModelHardConcrete(BertModel): @@ -78,7 +79,7 @@ def __init__(self, # Run the stacked transformer. # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 
- self.all_encoder_layers = transformer_model_train( + self.all_encoder_layers = transformer_model_flop( input_tensor=self.embedding_output, attention_mask=attention_mask, hidden_size=config.hidden_size, @@ -89,7 +90,8 @@ def __init__(self, hidden_dropout_prob=config.hidden_dropout_prob, attention_probs_dropout_prob=config.attention_probs_dropout_prob, initializer_range=config.initializer_range, - do_return_all_layers=True) + do_return_all_layers=True, + is_training=is_training) self.sequence_output = self.all_encoder_layers[-1] # The "pooler" converts the encoded sequence tensor of shape @@ -109,20 +111,21 @@ def __init__(self, kernel_initializer=create_initializer(config.initializer_range)) -def attention_layer_train(from_tensor, - to_tensor, - attention_mask=None, - num_attention_heads=1, - size_per_head=512, - query_act=None, - key_act=None, - value_act=None, - attention_probs_dropout_prob=0.0, - initializer_range=0.02, - do_return_2d_tensor=False, - batch_size=None, - from_seq_length=None, - to_seq_length=None): +def attention_layer_flop(from_tensor, + to_tensor, + attention_mask=None, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + do_return_2d_tensor=False, + batch_size=None, + from_seq_length=None, + to_seq_length=None, + is_training=True): def transpose_for_scores(input_tensor, batch_size, num_attention_heads, seq_length, width): @@ -169,9 +172,16 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, name="query_p", kernel_initializer=create_initializer(initializer_range)) + # Attention: log_alpha_initializer, eps, beta, limit_l, limit_r! + query_layer_mask = layers.FlopFullyConnected( + name="query_g", + log_alpha_initializer=None, + is_training=is_training) + + query_layer_mask_output = query_layer_mask(query_layer_p) query_layer = tf.layers.dense( - query_layer_p, + query_layer_mask_output, num_attention_heads * size_per_head, activation=query_act, name="query_q", @@ -194,8 +204,16 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, name="key_p", kernel_initializer=create_initializer(initializer_range)) + # Attention: log_alpha_initializer, eps, beta, limit_l, limit_r! + key_layer_mask = layers.FlopFullyConnected( + name="key_g", + log_alpha_initializer=None, + is_training=is_training) + + key_layer_mask_output = key_layer_mask(key_layer_p) + key_layer = tf.layers.dense( - key_layer_p, + key_layer_mask_output, num_attention_heads * size_per_head, activation=key_act, name="key_q", @@ -218,8 +236,16 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, name="value_p", kernel_initializer=create_initializer(initializer_range)) + # Attention: log_alpha_initializer, eps, beta, limit_l, limit_r! 
+ value_layer_mask = layers.FlopFullyConnected( + name="value_g", + log_alpha_initializer=None, + is_training=is_training) + + value_layer_mask_output = value_layer_mask(value_layer_p) + value_layer = tf.layers.dense( - value_layer_p, + value_layer_mask_output, num_attention_heads * size_per_head, activation=value_act, name="value_q", @@ -298,17 +324,18 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, return context_layer -def transformer_model_train(input_tensor, - attention_mask=None, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - intermediate_act_fn=gelu, - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - initializer_range=0.02, - do_return_all_layers=False): +def transformer_model_flop(input_tensor, + attention_mask=None, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + intermediate_act_fn=gelu, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + initializer_range=0.02, + do_return_all_layers=False, + is_training=True): if hidden_size % num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " @@ -340,7 +367,7 @@ def transformer_model_train(input_tensor, with tf.variable_scope("attention"): attention_heads = [] with tf.variable_scope("self"): - attention_head = attention_layer_train( + attention_head = attention_layer_flop( from_tensor=layer_input, to_tensor=layer_input, attention_mask=attention_mask, @@ -351,7 +378,8 @@ def transformer_model_train(input_tensor, do_return_2d_tensor=True, batch_size=batch_size, from_seq_length=seq_length, - to_seq_length=seq_length) + to_seq_length=seq_length, + is_training=is_training) attention_heads.append(attention_head) attention_output = None @@ -373,8 +401,17 @@ def transformer_model_train(input_tensor, name="dense_p", kernel_initializer=create_initializer(initializer_range)) + # Attention: log_alpha_initializer, eps, beta, limit_l, limit_r! + attention_output_mask = layers.FlopFullyConnected( + name="dense_g", + log_alpha_initializer=None, + is_training=is_training) + + attention_output_mask_output = attention_output_mask( + attention_output_p) + attention_output = tf.layers.dense( - attention_output_p, + attention_output_mask_output, hidden_size, name="dense_q", kernel_initializer=create_initializer(initializer_range)) @@ -400,8 +437,17 @@ def transformer_model_train(input_tensor, name='dense_p', kernel_initializer=create_initializer(initializer_range)) + # Attention: log_alpha_initializer, eps, beta, limit_l, limit_r! + intermediate_output_mask = layers.FlopFullyConnected( + name="dense_g", + log_alpha_initializer=None, + is_training=is_training) + + intermediate_output_mask_output = intermediate_output_mask( + intermediate_output_p) + intermediate_output = tf.layers.dense( - intermediate_output_p, + intermediate_output_mask_output, intermediate_size, activation=intermediate_act_fn, name='dense_q', @@ -423,8 +469,16 @@ def transformer_model_train(input_tensor, name="dense_p", kernel_initializer=create_initializer(initializer_range)) + # Attention: log_alpha_initializer, eps, beta, limit_l, limit_r! 
+ layer_output_mask = layers.FlopFullyConnected( + name="dense_g", + log_alpha_initializer=None, + is_training=is_training) + + layer_output_mask_output = layer_output_mask(layer_output_p) + layer_output = tf.layers.dense( - layer_output_p, + layer_output_mask_output, hidden_size, name="dense_q", kernel_initializer=create_initializer(initializer_range)) @@ -530,7 +584,7 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument is_training = (mode == tf.estimator.ModeKeys.TRAIN) - (total_loss, per_example_loss, logits, probabilities) = create_model( + (total_loss, per_example_loss, logits, probabilities) = create_model_train( bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, num_labels, use_one_hot_embeddings) diff --git a/flop/nn.py b/flop/nn.py index 91cfe58..2b8e5c6 100644 --- a/flop/nn.py +++ b/flop/nn.py @@ -20,8 +20,7 @@ import tensorflow as tf -from state_of_sparsity.layers.l0_regularization import common -from state_of_sparsity.layers.utils import layer_utils +import common def matmul_train( diff --git a/flop/optimization_flop.py b/flop/optimization_flop.py new file mode 100644 index 0000000..0c88ea7 --- /dev/null +++ b/flop/optimization_flop.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions and classes related to optimization (weight updates).""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import tensorflow as tf + + +def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): + """Creates an optimizer training op.""" + global_step = tf.train.get_or_create_global_step() + + learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) + + # Implements linear decay of the learning rate. + learning_rate = tf.train.polynomial_decay( + learning_rate, + global_step, + num_train_steps, + end_learning_rate=0.0, + power=1.0, + cycle=False) + + # Implements linear warmup. I.e., if global_step < num_warmup_steps, the + # learning rate will be `global_step/num_warmup_steps * init_lr`. + if num_warmup_steps: + global_steps_int = tf.cast(global_step, tf.int32) + warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) + + global_steps_float = tf.cast(global_steps_int, tf.float32) + warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) + + warmup_percent_done = global_steps_float / warmup_steps_float + warmup_learning_rate = init_lr * warmup_percent_done + + is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) + learning_rate = ( + (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) + + # It is recommended that you use this optimizer for fine tuning, since this + # is how the model was trained (note that the Adam m/v variables are NOT + # loaded from init_checkpoint.) 
+ optimizer = AdamWeightDecayOptimizer( + learning_rate=learning_rate, + weight_decay_rate=0.01, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) + + if use_tpu: + optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) + + tvars = tf.trainable_variables() + grads = tf.gradients(loss, tvars) + + # This is how the model was pre-trained. + (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) + + train_op = optimizer.apply_gradients( + zip(grads, tvars), global_step=global_step) + + # Normally the global step update is done inside of `apply_gradients`. + # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use + # a different optimizer, you should probably take this line out. + new_global_step = global_step + 1 + train_op = tf.group(train_op, [global_step.assign(new_global_step)]) + return train_op + + +class AdamWeightDecayOptimizer(tf.train.Optimizer): + """A basic Adam optimizer that includes "correct" L2 weight decay.""" + + def __init__(self, + learning_rate, + weight_decay_rate=0.0, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=None, + name="AdamWeightDecayOptimizer"): + """Constructs an AdamWeightDecayOptimizer.""" + super(AdamWeightDecayOptimizer, self).__init__(False, name) + + self.learning_rate = learning_rate + self.weight_decay_rate = weight_decay_rate + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.exclude_from_weight_decay = exclude_from_weight_decay + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + """See base class.""" + assignments = [] + for (grad, param) in grads_and_vars: + if grad is None or param is None: + continue + + param_name = self._get_variable_name(param.name) + + m = tf.get_variable( + name=param_name + "/adam_m", + shape=param.shape.as_list(), + dtype=tf.float32, + trainable=False, + initializer=tf.zeros_initializer()) + v = tf.get_variable( + name=param_name + "/adam_v", + shape=param.shape.as_list(), + dtype=tf.float32, + trainable=False, + initializer=tf.zeros_initializer()) + + # Standard Adam update. + next_m = ( + tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) + next_v = ( + tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, + tf.square(grad))) + + update = next_m / (tf.sqrt(next_v) + self.epsilon) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD.
+ if self._do_use_weight_decay(param_name): + update += self.weight_decay_rate * param + + update_with_lr = self.learning_rate * update + + next_param = param - update_with_lr + + assignments.extend( + [param.assign(next_param), + m.assign(next_m), + v.assign(next_v)]) + return tf.group(*assignments, name=name) + + def _do_use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if not self.weight_decay_rate: + return False + if self.exclude_from_weight_decay: + for r in self.exclude_from_weight_decay: + if re.search(r, param_name) is not None: + return False + return True + + def _get_variable_name(self, param_name): + """Get the variable name from the tensor name.""" + m = re.match("^(.*):\\d+$", param_name) + if m is not None: + param_name = m.group(1) + return param_name diff --git a/main.ipynb b/main.ipynb index 42f5fcf..db78f87 100644 --- a/main.ipynb +++ b/main.ipynb @@ -28,13 +28,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": "WARNING:tensorflow:From /data0/ultraman/BERT-Pruning/bert/modeling.py:93: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n\n" - } - ], + "outputs": [], "source": [ "tf.logging.set_verbosity(tf.logging.DEBUG)\n", "flags = {\n", @@ -81,7 +75,7 @@ { "output_type": "stream", "name": "stdout", - "text": "WARNING:tensorflow:From /data0/ultraman/BERT-Pruning/flop/modeling_hardconcrete.py:44: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n\nWARNING:tensorflow:From /data0/ultraman/BERT-Pruning/bert/modeling.py:409: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n\nWARNING:tensorflow:From /data0/ultraman/BERT-Pruning/bert/modeling.py:490: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.\n\nWARNING:tensorflow:From /data0/ultraman/BERT-Pruning/flop/modeling_hardconcrete.py:166: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\nInstructions for updating:\nUse keras.layers.Dense instead.\nWARNING:tensorflow:From /home/ultraman/anaconda3/envs/tensorflow-1.15.0/lib/python3.7/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\nInstructions for updating:\nPlease use `layer.__call__` method instead.\n" + "text": "WARNING:tensorflow:From /data0/ultraman/BERT-Pruning/bert/modeling.py:93: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n\nWARNING:tensorflow:From /data0/ultraman/BERT-Pruning/flop/modeling_flop.py:48: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n\nWARNING:tensorflow:From /data0/ultraman/BERT-Pruning/bert/modeling.py:409: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n\nWARNING:tensorflow:From /data0/ultraman/BERT-Pruning/bert/modeling.py:490: The name tf.assert_less_equal is deprecated. 
Please use tf.compat.v1.assert_less_equal instead.\n\nWARNING:tensorflow:From /data0/ultraman/BERT-Pruning/flop/modeling_flop.py:173: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\nInstructions for updating:\nUse keras.layers.Dense instead.\nWARNING:tensorflow:From /home/ultraman/anaconda3/envs/tensorflow-1.15.0/lib/python3.7/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\nInstructions for updating:\nPlease use `layer.__call__` method instead.\nWARNING:tensorflow:From /data0/ultraman/BERT-Pruning/flop/layers.py:83: calling RandomNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\nInstructions for updating:\nCall initializer instance with the dtype argument instead of passing it to the constructor\n" } ], "source": [ @@ -124,7 +118,7 @@ "source": [ "sess = tf.Session()\n", " # 保存路径\n", - "tenboard_dir = './tensorboard/fractorized/123'\n", + "tenboard_dir = './tensorboard/fractorized/777'\n", "\n", "# 指定一个文件用来保存图\n", "writer = tf.summary.FileWriter(tenboard_dir)\n",
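
The dense_p / dense_q pairs introduced in modeling_flop.py, with a FlopFullyConnected mask (query_g, key_g, value_g, dense_g) in between, follow the FLOP (factorized L0 pruning) scheme: each original dense kernel W is replaced by a factorization P * diag(z) * Q, where z is a trainable hard concrete gate, so inner dimensions whose gate settles at zero can be removed after training, shrinking both factors. factorize.py warm-starts P and Q from a pretrained checkpoint with an SVD; the assignment lines fall outside the hunk shown above, so the sqrt(S) split below is an assumed (but conventional) choice, and svd_warm_start is an illustrative helper, not a function from this repository.

import numpy as np

def svd_warm_start(kernel, rank=None):
    # Split a pretrained kernel W into P (the *_p kernel) and Q (the *_q
    # kernel) so that P @ Q reproduces W. Spreading sqrt(S) over both
    # factors keeps their scales comparable.
    u, s, vt = np.linalg.svd(kernel, full_matrices=False)
    if rank is not None:
        u, s, vt = u[:, :rank], s[:rank], vt[:rank, :]
    p = u * np.sqrt(s)            # [in_dim, r]
    q = np.sqrt(s)[:, None] * vt  # [r, out_dim]
    return p, q

w = np.random.randn(768, 768).astype(np.float32)
p, q = svd_warm_start(w)
assert np.allclose(p @ q, w, atol=1e-2)

With full rank the factorization is exact; sparsity only appears once training drives some gate entries to zero, at which point the matching columns of P and rows of Q can be dropped.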
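
FlopFullyConnected itself only owns the per-dimension log_alpha vector; the actual gating happens in nn.matmul_train / nn.matmul_eval from flop/nn.py, whose bodies this patch does not show. The sketch below is a NumPy-only illustration of the hard concrete gate those functions are assumed to compute (Louizos et al.'s L0 formulation with the stretch limits limit_l=-0.1, limit_r=1.1 and temperature beta used above); the exact sign attached to log_alpha is defined in flop/nn.py and may differ from this sketch.

import numpy as np

def hard_concrete_gate(log_alpha, beta=1.0, limit_l=-0.1, limit_r=1.1,
                       eps=1e-6, is_training=True, rng=None):
    # Illustrative sketch, not the repository's implementation.
    if is_training:
        rng = rng or np.random.default_rng(0)
        u = rng.uniform(eps, 1.0 - eps, size=np.shape(log_alpha))
        # Reparameterized sample from the stretched concrete distribution.
        s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log(1.0 - u) + log_alpha) / beta))
    else:
        # Deterministic gate at eval time: drop the uniform noise.
        s = 1.0 / (1.0 + np.exp(-log_alpha / beta))
    # Stretch to (limit_l, limit_r) and clip so exact zeros and ones
    # occur with non-zero probability.
    return np.clip(s * (limit_r - limit_l) + limit_l, 0.0, 1.0)

# With init_mean=0.5 the initializer in layers.py centres log_alpha on
# log(1 - 0.5) - log(0.5) = 0, i.e. the gates start roughly half open.
z = hard_concrete_gate(np.zeros(768), is_training=False)

For init_mean values other than 0.5, the mapping from init_mean to the initial gate value depends on the sign convention inside nn.matmul_train, so the new initializer in layers.py is worth checking against that code before training.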
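
optimization_flop.py starts out as a copy of BERT's optimization.py, so it is wired up exactly like the stock optimizer in run_classifier.py; a FLOP-specific sparsity objective (for example a penalty on the expected number of open gates) is not part of this file in this patch. A minimal sketch of the call under the TF 1.x graph API used throughout this repo, assuming the flop/ directory is on the Python path and using a toy scalar loss in place of the output of create_model_train:

import tensorflow as tf
from optimization_flop import create_optimizer

# Toy stand-in for the classification loss, just to exercise the signature.
w = tf.get_variable("toy_weight", initializer=3.0)
loss = tf.square(w)

train_op = create_optimizer(
    loss=loss,
    init_lr=2e-5,            # linearly warmed up, then decayed to 0
    num_train_steps=10000,
    num_warmup_steps=1000,   # lr = step / num_warmup_steps * init_lr during warmup
    use_tpu=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)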