This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Add Sparse embedding support #164

Merged
merged 7 commits into from
Sep 28, 2018
28 changes: 22 additions & 6 deletions keras/backend/mxnet_backend.py
@@ -2,12 +2,13 @@
from __future__ import print_function

import warnings
from collections import defaultdict
from functools import wraps
from numbers import Number
from subprocess import CalledProcessError

import mxnet as mx
import numpy as np
from subprocess import CalledProcessError
from numbers import Number
from functools import wraps
from collections import defaultdict

from .common import floatx, epsilon, image_data_format

@@ -1203,12 +1204,16 @@ def gather(reference, indices):


@keras_mxnet_symbol
def embedding(data, weight, input_dim, output_dim):
def embedding(data, weight, input_dim, output_dim, sparse_grad=False):
Reviewer:
K.embedding() is called from the keras.layers.Embedding class for the MXNet backend. Please also update the function signature there, and test it in an end-to-end example that uses the Embedding layer (e.g., the examples using the imdb dataset).

Author:
We pass a default value of False for sparse_grad in the API signature, so a change there is not necessary.

# check if inputs are KerasSymbol
if isinstance(data, KerasSymbol):
data = data.symbol
if isinstance(weight, KerasSymbol):
weight = weight.symbol
if sparse_grad:
# Refer https://mxnet.incubator.apache.org/api/python/symbol/sparse.html#mxnet.symbol.sparse.Embedding
return KerasSymbol(mx.sym.Embedding(data, weight=weight, input_dim=input_dim, output_dim=output_dim,
sparse_grad=True))
return KerasSymbol(mx.sym.Embedding(data, weight=weight, input_dim=input_dim, output_dim=output_dim))
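For reference, here is a minimal sketch (not part of this PR) of calling the new backend hook directly. It assumes the keras-apache-mxnet fork with the MXNet backend selected, where keras.backend exposes embedding(); the tensors, shapes, and names below are illustrative only.

```python
# Hedged sketch: exercise the backend-level embedding() with and without sparse_grad.
import numpy as np
from keras import backend as K

data = K.variable(np.array([[1, 2], [3, 4]]), dtype='int32')  # token indices, shape (2, 2)
weight = K.variable(np.random.rand(10, 4))                    # 10 x 4 embedding table

dense_out = K.embedding(data, weight, input_dim=10, output_dim=4)
sparse_out = K.embedding(data, weight, input_dim=10, output_dim=4, sparse_grad=True)

# The forward result is the same in both cases; sparse_grad=True only asks MXNet
# to store the weight gradient as row_sparse during the backward pass.
print(K.eval(dense_out).shape)  # expected: (2, 2, 4)
```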


@@ -2693,7 +2698,8 @@ def rnn(step_function, inputs, initial_states,
warnings.warn('MXNet Backend: `unroll=False` is not supported yet in RNN. Since the input_shape is known, '
'setting `unroll=True` and continuing the execution.'
'More Details - '
'https://github.com/awslabs/keras-apache-mxnet/tree/master/docs/mxnet_backend/using_rnn_with_mxnet_backend.md', # nopep8
'https://github.com/awslabs/keras-apache-mxnet/tree/master/docs/mxnet_backend/using_rnn_with_mxnet_backend.md',
# nopep8
stacklevel=2) # nopep8

# Split the inputs across time dimension and generate the list of inputs
@@ -4836,6 +4842,7 @@ class Model(engine.Model):
"""The `Model` class adds training & evaluation routines to a `Network`. This class extends
keras.engine.Model to add MXNet Module to perform training and inference with MXNet backend.
"""

def __init__(self, *args, **kwargs):
if 'name' not in kwargs:
prefix = self.__class__.__name__.lower()
@@ -5225,6 +5232,7 @@ class Sequential(sequential.Sequential, engine.Model):
"""Linear stack of layers. This class extends keras.engine.Sequential to add MXNet Module to perform training
and inference with MXNet backend.
"""

def __init__(self, layers=None, *args, **kwargs):
if 'name' not in kwargs:
prefix = self.__class__.__name__.lower()
@@ -5251,6 +5259,7 @@ class MXOptimizer(optimizers.Optimizer, mx.optimizer.Optimizer):
This is required because we cannot use Keras optimizer directly as MXNet backend does not
support symbolic optimizers.
"""

def __init__(self, lr, decay):
super(MXOptimizer, self).__init__()
self.lr = variable(lr)
@@ -5278,6 +5287,7 @@ class SGD(MXOptimizer, mx.optimizer.SGD):
decay: float >= 0. Learning rate decay over each update.
nesterov: boolean. Whether to apply Nesterov momentum.
"""

def __init__(self, lr=0.01, momentum=0., decay=0.,
nesterov=False, clipnorm=None, **kwargs):
mx.optimizer.SGD.__init__(self, learning_rate=lr, momentum=momentum, clip_gradient=clipnorm, **kwargs)
@@ -5309,6 +5319,7 @@ class Adagrad(MXOptimizer, mx.optimizer.AdaGrad):
# References
- [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) # nopep8
"""

def __init__(self, lr=0.01, epsilon=1e-8, decay=0., clipnorm=None, **kwargs):
mx.optimizer.AdaGrad.__init__(self, learning_rate=lr, eps=epsilon, clip_gradient=clipnorm, **kwargs)
MXOptimizer.__init__(self, lr, decay)
@@ -5345,6 +5356,7 @@ class Adadelta(MXOptimizer, mx.optimizer.AdaDelta):
# References
- [Adadelta - an adaptive learning rate method](http://arxiv.org/abs/1212.5701)
"""

def __init__(self, lr=1.0, rho=0.95, epsilon=1e-8, decay=0., clipnorm=None, **kwargs):
mx.optimizer.AdaDelta.__init__(self, rho=rho, epsilon=epsilon, clip_gradient=clipnorm, **kwargs)
MXOptimizer.__init__(self, lr, decay)
@@ -5376,6 +5388,7 @@ class Adam(MXOptimizer, mx.optimizer.Adam):
- [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
- [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
"""

def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
epsilon=1e-8, decay=0., clipnorm=None, **kwargs):
mx.optimizer.Adam.__init__(self, learning_rate=lr, beta1=beta_1, beta2=beta_2,
@@ -5406,6 +5419,7 @@ class Adamax(MXOptimizer, mx.optimizer.Adamax):
# References
- [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
"""

def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999, decay=0., clipnorm=None,
epsilon=1e-8, **kwargs):
mx.optimizer.Adamax.__init__(self, learning_rate=lr, beta1=beta_1, beta2=beta_2,
@@ -5441,6 +5455,7 @@ class Nadam(MXOptimizer, mx.optimizer.Nadam):
- [Nadam report](http://cs229.stanford.edu/proj2015/054_report.pdf)
- [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/~fritz/absps/momentum.pdf) # nopep8
"""

def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8, decay=0., clipnorm=None,
schedule_decay=0.004, **kwargs):
mx.optimizer.Nadam.__init__(self, learning_rate=lr, beta1=beta_1, beta2=beta_2, epsilon=epsilon,
@@ -5475,6 +5490,7 @@ class RMSprop(MXOptimizer, mx.optimizer.RMSProp):
# References
- [rmsprop: Divide the gradient by a running average of its recent magnitude](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) # nopep8
"""

def __init__(self, lr=0.001, rho=0.9, epsilon=1e-8, decay=0., clipnorm=None, **kwargs):
mx.optimizer.RMSProp.__init__(self, learning_rate=lr, gamma1=rho, epsilon=epsilon,
clip_gradient=clipnorm, **kwargs)
13 changes: 11 additions & 2 deletions keras/layers/embeddings.py
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
"""Embedding layer.
"""
from __future__ import absolute_import
@@ -59,6 +60,10 @@ class Embedding(Layer):
This argument is required if you are going to connect
`Flatten` then `Dense` layers upstream
(without it, the shape of the dense outputs cannot be computed).
sparse_grad: Used only for the MXNet backend.
When set to True, the gradient's storage type is row_sparse, and a
row-sparse gradient is computed in the backward pass.
Refer to: https://mxnet.incubator.apache.org/api/python/symbol/sparse.html#mxnet.symbol.sparse.Embedding

# Input shape
2D tensor with shape: `(batch_size, sequence_length)`.
@@ -78,14 +83,14 @@ def __init__(self, input_dim, output_dim,
embeddings_constraint=None,
mask_zero=False,
input_length=None,
sparse_grad=False,
Reviewer:
Add this to the docstring, explaining the usage and noting that it is only for the MXNet backend.

**kwargs):
if 'input_shape' not in kwargs:
if input_length:
kwargs['input_shape'] = (input_length,)
else:
kwargs['input_shape'] = (None,)
super(Embedding, self).__init__(**kwargs)

self.input_dim = input_dim
self.output_dim = output_dim
self.embeddings_initializer = initializers.get(embeddings_initializer)
@@ -95,6 +100,7 @@ def __init__(self, input_dim, output_dim,
self.mask_zero = mask_zero
self.supports_masking = mask_zero
self.input_length = input_length
self.sparse_grad = sparse_grad

def build(self, input_shape):
self.embeddings = self.add_weight(
@@ -140,7 +146,10 @@ def call(self, inputs):
# K.gather is not working with Embedding layer using MXNet backend
# Refer to this issue: https://github.com/awslabs/keras-apache-mxnet/issues/63
if K.backend() == "mxnet":
out = K.embedding(inputs, self.embeddings, self.input_dim, self.output_dim)
if self.sparse_grad:
out = K.embedding(inputs, self.embeddings, self.input_dim, self.output_dim, sparse_grad=self.sparse_grad)
else:
out = K.embedding(inputs, self.embeddings, self.input_dim, self.output_dim)
else:
out = K.gather(self.embeddings, inputs)
return out
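To make the call() dispatch above concrete, a hedged usage sketch follows; the layer names, shapes, and hyperparameters are made up, and on non-MXNet backends the K.gather branch is taken, so sparse_grad has no effect on the forward computation there.

```python
# Illustrative only: the Embedding layer stores sparse_grad and, on the MXNet
# backend, forwards it to K.embedding() inside call().
from keras.layers import Embedding, Input
from keras.models import Model

inputs = Input(shape=(10,), dtype='int32')
embedded = Embedding(input_dim=1000, output_dim=64, input_length=10,
                     sparse_grad=True)(inputs)  # row_sparse weight gradient on MXNet
model = Model(inputs, embedded)
print(model.output_shape)  # (None, 10, 64)
```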
1 change: 0 additions & 1 deletion tests/keras/backend/mxnet_sparse_test.py
@@ -160,6 +160,5 @@ def test_sparse_concat_axis_non_zero(self):
assert k_s_d.shape == k_d.shape
assert_allclose(k_s_d, k_d, atol=1e-05)


Reviewer:
Please add a sparse test here: add a case where sparse_grad is true, in https://github.com/awslabs/keras-apache-mxnet/blob/master/tests/keras/layers/embeddings_test.py

if __name__ == '__main__':
pytest.main([__file__])
7 changes: 7 additions & 0 deletions tests/keras/layers/embeddings_test.py
@@ -26,6 +26,13 @@ def test_embedding():
input_shape=(3, 2, 5),
input_dtype='int32',
expected_output_dtype=K.floatx())
layer_test(Embedding,
kwargs={'output_dim': 4, 'input_dim': 10, 'mask_zero': True, 'input_length': (None, 5),
'sparse_grad': True},
input_shape=(3, 2, 5),
input_dtype='int32',
expected_output_dtype=K.floatx()
)


if __name__ == '__main__':
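Beyond layer_test, the reviewer's end-to-end suggestion could look roughly like the sketch below, with random data standing in for the imdb example; the model architecture and hyperparameters are arbitrary, and this snippet is not part of the PR.

```python
# Hedged end-to-end sketch: train for one epoch so the backward pass actually
# produces the row_sparse gradient requested by sparse_grad=True (MXNet backend).
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding, Flatten, Dense

model = Sequential()
model.add(Embedding(input_dim=1000, output_dim=64, input_length=10, sparse_grad=True))
model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='sgd', loss='binary_crossentropy')

x = np.random.randint(0, 1000, size=(32, 10))  # random token ids
y = np.random.randint(0, 2, size=(32, 1))      # random binary labels
model.fit(x, y, epochs=1, batch_size=8, verbose=0)
assert model.predict(x).shape == (32, 1)
```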