This repository has been archived by the owner on Jul 1, 2024. It is now read-only.
forked from keras-team/keras
Add Sparse embedding support #164
Merged
Changes from all commits (7 commits)
0752908  Update tests (kalyc)
e3cb730  Add check for values in sparse embedding test (kalyc)
f76b9f8  Update embedding API support in the layers/embeddings class (kalyc)
c452537  Add support in Embedding layer class for sparse data (kalyc)
8742eec  Remove unused import and variable (kalyc)
f1b5d23  Add layer_test for sparse embedding (kalyc)
120acd7  Add file encoding (kalyc)
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 """Embedding layer.
 """
 from __future__ import absolute_import
@@ -59,6 +60,10 @@ class Embedding(Layer):
             This argument is required if you are going to connect
             `Flatten` then `Dense` layers upstream
             (without it, the shape of the dense outputs cannot be computed).
+        sparse_grad: Used only for the MXNet backend.
+            When set to True, the gradient's storage type is row_sparse,
+            i.e. a row-sparse gradient is computed in the backward pass.
+            Refer to: https://mxnet.incubator.apache.org/api/python/symbol/sparse.html#mxnet.symbol.sparse.Embedding

     # Input shape
         2D tensor with shape: `(batch_size, sequence_length)`.
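To see why a row-sparse gradient helps here: in the backward pass of an embedding lookup, only the weight-matrix rows that were actually looked up receive a nonzero gradient. A backend-free sketch of that structure (plain Python; the helper name is illustrative, not part of the real API):

```python
def embedding_grad_rows(indices, vocab_size):
    """Return the sorted list of weight-matrix rows that get a nonzero
    gradient when `indices` were looked up in the forward pass.

    A dense gradient stores all `vocab_size` rows (mostly zeros);
    MXNet's row_sparse storage keeps only these touched rows.
    """
    return sorted(set(i for i in indices if 0 <= i < vocab_size))

# A batch touching 3 of 10,000 vocabulary rows: the dense gradient
# would hold 10,000 rows, the row-sparse one just 3.
rows = embedding_grad_rows([5, 42, 5, 7], vocab_size=10000)
print(rows)  # [5, 7, 42]
```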
@@ -78,14 +83,14 @@ def __init__(self, input_dim, output_dim,
                  embeddings_constraint=None,
                  mask_zero=False,
                  input_length=None,
+                 sparse_grad=False,
                  **kwargs):
         if 'input_shape' not in kwargs:
             if input_length:
                 kwargs['input_shape'] = (input_length,)
             else:
                 kwargs['input_shape'] = (None,)
         super(Embedding, self).__init__(**kwargs)

         self.input_dim = input_dim
         self.output_dim = output_dim
         self.embeddings_initializer = initializers.get(embeddings_initializer)

Review comment on `sparse_grad=False,`: Add this to the doc string, explaining the usage and noting that it is only for the MXNet backend.
@@ -95,6 +100,7 @@ def __init__(self, input_dim, output_dim,
         self.mask_zero = mask_zero
         self.supports_masking = mask_zero
         self.input_length = input_length
+        self.sparse_grad = sparse_grad

     def build(self, input_shape):
         self.embeddings = self.add_weight(
@@ -140,7 +146,10 @@ def call(self, inputs):
         # K.gather is not working with Embedding layer using MXNet backend
         # Refer to this issue: https://github.com/awslabs/keras-apache-mxnet/issues/63
         if K.backend() == "mxnet":
-            out = K.embedding(inputs, self.embeddings, self.input_dim, self.output_dim)
+            if self.sparse_grad:
+                out = K.embedding(inputs, self.embeddings, self.input_dim, self.output_dim, sparse_grad=self.sparse_grad)
+            else:
+                out = K.embedding(inputs, self.embeddings, self.input_dim, self.output_dim)
         else:
             out = K.gather(self.embeddings, inputs)
         return out
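The dispatch added in `call()` can be sketched backend-free; stub callables stand in for the backend operations (the names below are illustrative, not the real backend API):

```python
def embedding_call(backend, sparse_grad, embedding_op, gather_op):
    """Mirror of the dispatch in Embedding.call(): the MXNet backend
    uses the embedding operator (optionally with a row-sparse
    gradient); every other backend falls back to gather."""
    if backend == "mxnet":
        if sparse_grad:
            return embedding_op(sparse_grad=True)
        return embedding_op(sparse_grad=False)
    return gather_op()

# Stubs that record which path ran:
embedding_op = lambda sparse_grad: ("embedding", sparse_grad)
gather_op = lambda: ("gather", None)

print(embedding_call("mxnet", True, embedding_op, gather_op))       # ('embedding', True)
print(embedding_call("mxnet", False, embedding_op, gather_op))      # ('embedding', False)
print(embedding_call("tensorflow", True, embedding_op, gather_op))  # ('gather', None)
```

Since `sparse_grad=False` is the keyword's default, the two MXNet branches produce the same call when the flag is off; the explicit `else` branch simply keeps the original call site unchanged.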
@@ -160,6 +160,5 @@ def test_sparse_concat_axis_non_zero(self):
         assert k_s_d.shape == k_d.shape
         assert_allclose(k_s_d, k_d, atol=1e-05)

 if __name__ == '__main__':
     pytest.main([__file__])

Review comment: Please add a sparse test here, with a case where sparse_grad is True: https://github.com/awslabs/keras-apache-mxnet/blob/master/tests/keras/layers/embeddings_test.py
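The test pattern the reviewer asks for generally feeds the layer equivalent sparse and dense inputs and compares the outputs. A backend-free sketch of the densify-and-compare step (plain Python; the COO-style triplet format here is illustrative, not the framework's sparse type):

```python
def densify(indices, values, shape):
    """Turn a COO-style (indices, values, shape) triplet into a
    nested-list dense matrix, standing in for a sparse input."""
    rows, cols = shape
    dense = [[0.0] * cols for _ in range(rows)]
    for (r, c), v in zip(indices, values):
        dense[r][c] = v
    return dense

dense = densify([(0, 1), (2, 0)], [3.0, 5.0], (3, 2))
print(dense)  # [[0.0, 3.0], [0.0, 0.0], [5.0, 0.0]]
# In the real test, the layer's output on the dense input and on the
# sparse input would then be compared with assert_allclose(..., atol=1e-05).
```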
Review comment: K.embedding() is called in the keras.layers.Embedding class for the MXNet backend. Please also update the function signature there, and test it in an end-to-end example that uses the Embedding layer (e.g. the examples using the imdb dataset).
Reply: We are passing a default value of `False` for `sparse_grad` in the API signature, so a change there is not necessary.
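The reply's point is that a keyword argument with a default keeps every existing call site valid. A minimal sketch of that backward-compatibility argument (the function below is a hypothetical stand-in, not the real K.embedding signature):

```python
def embedding(inputs, weights, input_dim, output_dim, sparse_grad=False):
    """Stand-in backend op: old callers omit sparse_grad and keep the
    previous dense-gradient behaviour; new callers opt in explicitly."""
    return {"sparse_grad": sparse_grad}

# Existing call sites keep working unchanged:
print(embedding("x", "w", 10, 4))                    # {'sparse_grad': False}
# New call sites can opt in:
print(embedding("x", "w", 10, 4, sparse_grad=True))  # {'sparse_grad': True}
```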