# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

r'''Layers for ranking model.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import tensorflow as tf

class DotInteract(tf.layers.Layer):
  r'''DLRM: Deep Learning Recommendation Model for Personalization and
  Recommendation Systems.

  See https://github.com/facebookresearch/dlrm for more information.
  '''
  def call(self, x):
    r'''Call the DLRM dot-interaction layer.

    Args:
      x: Tensor of shape [batch_size, num_fields, dimension].

    Returns:
      Tensor of shape [batch_size, num_fields * (num_fields - 1) // 2] with
      the dot product of every pair of distinct feature fields.
    '''
    # Dot products between all pairs of feature fields.
    x2 = tf.matmul(x, x, transpose_b=True)
    x2_dim = x2.shape[-1]
    # Keep only the strictly lower triangle to drop self-interactions and
    # duplicate pairs. `tf.boolean_mask` expects a boolean mask, so cast the
    # 0/1 difference explicitly.
    x2_ones = tf.ones_like(x2)
    x2_mask = tf.linalg.band_part(x2_ones, 0, -1)
    y = tf.boolean_mask(x2, tf.cast(x2_ones - x2_mask, tf.bool))
    y = tf.reshape(y, [-1, x2_dim * (x2_dim - 1) // 2])
    return y
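

# Illustrative sketch (not part of the original commit): with 4 feature
# fields of dimension 3, DotInteract keeps the strictly lower triangle of
# the 4 x 4 dot-product matrix, i.e. 4 * 3 // 2 = 6 terms per sample.
def _example_dot_interact():
  x = tf.random.normal([2, 4, 3])  # [batch_size, num_fields, dimension]
  return DotInteract()(x)  # shape: [2, 6]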


class Cross(tf.layers.Layer):
  r'''DCN V2: Improved Deep & Cross Network and Practical Lessons for
  Web-scale Learning to Rank Systems.

  See https://arxiv.org/abs/2008.13535 for more information.
  '''
  def call(self, x):
    r'''Call the DCN cross layer.

    Args:
      x: Tensor of shape [batch_size, num_fields, dimension].

    Returns:
      Tensor of shape [batch_size, num_fields * dimension].
    '''
    # Cross the input element-wise with a dense projection of itself, then
    # add a residual connection.
    x2 = tf.layers.dense(
        x, x.shape[-1],
        activation=tf.nn.relu,
        kernel_initializer=tf.truncated_normal_initializer(),
        bias_initializer=tf.zeros_initializer())
    y = x * x2 + x
    # Flatten fields and embedding dimension into a single feature axis.
    y = tf.reshape(y, [-1, x.shape[1] * x.shape[2]])
    return y
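

# Illustrative sketch (not part of the original commit): Cross keeps the
# input shape through the crossing step, then flattens fields and dimension
# into one feature axis.
def _example_cross():
  x = tf.random.normal([2, 4, 3])  # [batch_size, num_fields, dimension]
  return Cross()(x)  # shape: [2, 4 * 3] = [2, 12]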


class Ranking(tf.layers.Layer):
  r'''A simple ranking model.
  '''
  def __init__(
      self,
      embedding_columns,
      bottom_mlp=None,
      top_mlp=None,
      feature_interaction=None,
      **kwargs):
    r'''Constructor.

    Args:
      embedding_columns: List of embedding columns.
      bottom_mlp: List of bottom MLP dimensions.
      top_mlp: List of top MLP dimensions.
      feature_interaction: Feature interaction layer class.
      **kwargs: Keyword named properties.
    '''
    super().__init__(**kwargs)

    if bottom_mlp is None:
      bottom_mlp = [512, 256, 64]
    self.bottom_mlp = bottom_mlp
    if top_mlp is None:
      top_mlp = [1024, 1024, 512, 256, 1]
    self.top_mlp = top_mlp
    if feature_interaction is None:
      feature_interaction = DotInteract
    self.feature_interaction = feature_interaction
    self.embedding_columns = embedding_columns
    # All embedding columns must share a single embedding dimension so that
    # their outputs can be stacked for feature interaction.
    dimensions = {c.dimension for c in embedding_columns}
    if len(dimensions) > 1:
      raise ValueError('All embedding columns must share one dimension')
    self.dimension = list(dimensions)[0]

  def call(self, values, embeddings):
    r'''Call the DLRM model.

    Args:
      values: Tensor of dense features, shape [batch_size, num_values].
      embeddings: List of embedding lookup results, each of shape
        [batch_size, dimension].

    Returns:
      Ranking logits of shape [batch_size, top_mlp[-1]].
    '''
    with tf.name_scope('bottom_mlp'):
      # Compress dense features with log(1 + x) before the bottom MLP.
      bot_mlp_input = tf.math.log(values + 1.)
      for i, d in enumerate(self.bottom_mlp):
        bot_mlp_input = tf.layers.dense(
            bot_mlp_input, d,
            activation=tf.nn.relu,
            kernel_initializer=tf.glorot_normal_initializer(),
            bias_initializer=tf.random_normal_initializer(
                mean=0.0,
                stddev=math.sqrt(1.0 / d)),
            name=f'bottom_mlp_{i}')
      # Project to the shared embedding dimension so dense features can
      # interact with the sparse embeddings.
      bot_mlp_output = tf.layers.dense(
          bot_mlp_input, self.dimension,
          activation=tf.nn.relu,
          kernel_initializer=tf.glorot_normal_initializer(),
          bias_initializer=tf.random_normal_initializer(
              mean=0.0,
              stddev=math.sqrt(1.0 / self.dimension)),
          name='bottom_mlp_output')

    with tf.name_scope('feature_interaction'):
      # Stack the bottom MLP output with the embeddings as feature fields.
      feat_interact_input = tf.concat([bot_mlp_output] + embeddings, axis=-1)
      feat_interact_input = tf.reshape(
          feat_interact_input,
          [-1, 1 + len(embeddings), self.dimension])
      feat_interact_output = self.feature_interaction()(feat_interact_input)

    with tf.name_scope('top_mlp'):
      top_mlp_input = tf.concat([bot_mlp_output, feat_interact_output], axis=1)
      num_fields = len(self.embedding_columns)
      # Width of the top MLP input: pairwise interactions over the
      # num_fields + 1 feature fields plus the bottom MLP output.
      prev_d = (num_fields * (num_fields + 1)) / 2 + self.dimension
      for i, d in enumerate(self.top_mlp[:-1]):
        top_mlp_input = tf.layers.dense(
            top_mlp_input, d,
            activation=tf.nn.relu,
            kernel_initializer=tf.random_normal_initializer(
                mean=0.0,
                stddev=math.sqrt(2.0 / (prev_d + d))),
            bias_initializer=tf.random_normal_initializer(
                mean=0.0,
                stddev=math.sqrt(1.0 / d)),
            name=f'top_mlp_{i}')
        prev_d = d
      top_mlp_output = tf.layers.dense(
          top_mlp_input, self.top_mlp[-1],
          activation=tf.nn.sigmoid,
          kernel_initializer=tf.random_normal_initializer(
              mean=0.0,
              stddev=math.sqrt(2.0 / (prev_d + self.top_mlp[-1]))),
          bias_initializer=tf.random_normal_initializer(
              mean=0.0,
              stddev=math.sqrt(1.0 / self.top_mlp[-1])),
          name=f'top_mlp_{len(self.top_mlp) - 1}')
    return top_mlp_output
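

# Illustrative usage sketch (not part of the original commit). `_Column` is
# a hypothetical stand-in for a real embedding column; only its `dimension`
# attribute is consulted by `Ranking`. The 13 dense values and 26 embeddings
# mirror a DLRM-style Criteo setup.
class _Column(object):
  def __init__(self, dimension):
    self.dimension = dimension


def _example_ranking():
  model = Ranking([_Column(dimension=16) for _ in range(26)])
  values = tf.random.uniform([32, 13])  # non-negative dense features
  embeddings = [tf.random.normal([32, 16]) for _ in range(26)]
  return model(values, embeddings)  # logits of shape [32, 1]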


# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

r'''Functions for optimization.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

def lr_with_linear_warmup_and_polynomial_decay(
    global_step,
    initial_value=24.,
    scaling_factor=1.,
    warmup_steps=None,
    decay_steps=None,
    decay_start_step=None,
    decay_exp=2,
    epsilon=1.e-7):
  r'''Calculates learning rate with linear warmup and polynomial decay.

  Args:
    global_step: Variable representing the current step.
    initial_value: Initial value of learning rates.
    scaling_factor: Factor for scaling the initial value.
    warmup_steps: Steps of warmup.
    decay_steps: Steps of decay.
    decay_start_step: Start step of decay.
    decay_exp: Exponent part of decay.
    epsilon: Small value used as the learning-rate floor and as the
      tolerance for regime selection.

  Returns:
    New learning rate tensor.
  '''
  initial_lr = tf.constant(initial_value * scaling_factor, tf.float32)

  if warmup_steps is None:
    return initial_lr

  # Linear warmup: ramp from near zero up to the initial learning rate.
  global_step = tf.cast(global_step, tf.float32)
  warmup_steps = tf.constant(warmup_steps, tf.float32)
  warmup_rate = initial_lr / warmup_steps
  warmup_lr = initial_lr - (warmup_steps - global_step) * warmup_rate

  if decay_steps is None or decay_start_step is None:
    return warmup_lr

  # Polynomial decay: shrink towards epsilon after decay_start_step,
  # flattening out once decay_steps have elapsed.
  decay_start_step = tf.constant(decay_start_step, tf.float32)
  steps_since_decay_start = global_step - decay_start_step
  decay_steps = tf.constant(decay_steps, tf.float32)
  decayed_steps = tf.minimum(steps_since_decay_start, decay_steps)
  to_decay_rate = (decay_steps - decayed_steps) / decay_steps
  decay_lr = initial_lr * to_decay_rate**decay_exp
  decay_lr = tf.maximum(decay_lr, tf.constant(epsilon))

  # Select exactly one regime per step: warmup, constant, or decay.
  warmup_lambda = tf.cast(global_step < warmup_steps, tf.float32)
  decay_lambda = tf.cast(global_step > decay_start_step, tf.float32)
  initial_lambda = tf.cast(
      tf.math.abs(warmup_lambda + decay_lambda) < epsilon, tf.float32)

  lr = warmup_lambda * warmup_lr
  lr += decay_lambda * decay_lr
  lr += initial_lambda * initial_lr
  return lr
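

# Illustrative sketch with hypothetical step counts (not from this commit):
# the schedule warms up linearly for 1000 steps, holds the initial rate
# until step 5000, then decays polynomially over the next 2000 steps.
def _example_lr_schedule():
  step = tf.train.get_or_create_global_step()
  return lr_with_linear_warmup_and_polynomial_decay(
      step,
      initial_value=24.,
      warmup_steps=1000,
      decay_start_step=5000,
      decay_steps=2000)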


def sgd_decay_optimize(
    loss,
    lr_initial_value,
    lr_warmup_steps,
    lr_decay_start_step,
    lr_decay_steps):
  r'''Optimize using SGD and learning rate decay.

  Args:
    loss: Loss tensor to minimize.
    lr_initial_value: Initial learning rate.
    lr_warmup_steps: Steps of linear warmup.
    lr_decay_start_step: Step at which polynomial decay starts.
    lr_decay_steps: Steps of polynomial decay.

  Returns:
    Training op that minimizes the loss and increments the global step.
  '''
  step = tf.train.get_or_create_global_step()
  lr = lr_with_linear_warmup_and_polynomial_decay(
      step,
      initial_value=lr_initial_value,
      warmup_steps=lr_warmup_steps,
      decay_start_step=lr_decay_start_step,
      decay_steps=lr_decay_steps)
  opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
  return opt.minimize(loss, global_step=step)
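

# Illustrative usage (hypothetical values): wire a loss tensor into the SGD
# training op; the global step is created and incremented by the helper.
def _example_train_op(loss):
  return sgd_decay_optimize(
      loss,
      lr_initial_value=24.,
      lr_warmup_steps=1000,
      lr_decay_start_step=5000,
      lr_decay_steps=2000)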