Conversation
k_S = K.embedding(test_sparse_data, test_sparse_weight, 4, 5)
k_D = K.embedding(test_dense_data, test_dense_weight, 4, 5)

assert k_S.shape == k_D.shape
value verification?
Added. Also see the related issue about using the contrib API: apache/mxnet#12465
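A minimal sketch of such a value check, assuming K.eval can resolve the symbols returned by K.embedding above into numpy arrays (the tolerance follows the other sparse tests in this file):

from numpy.testing import assert_allclose

k_s_val = K.eval(k_S)  # embedding lookup computed from the sparse weight
k_d_val = K.eval(k_D)  # embedding lookup computed from the dense weight

# Shapes already match; the values should match as well.
assert_allclose(k_s_val, k_d_val, atol=1e-05)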
keras/backend/mxnet_backend.py
Outdated
# Use mxnet.sym.contrib.SparseEmbedding API - https://mxnet.apache.org/api/python/symbol/contrib.html
sym = mx.sym.contrib.SparseEmbedding(data, weight=weight, input_dim=input_dim, output_dim=output_dim,
                                     deterministic=True)
sym = mx.sym.Embedding(data, weight=weight, input_dim=input_dim, output_dim=output_dim)
If sparse_grad is set, you overwrite sym here.
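A hedged sketch of one way to branch instead of overwriting (both calls are the ones shown in the diff above; the exact API may change once the MXNet v1.3 signature is available):

if sparse_grad:
    # Sparse gradient path via the contrib operator.
    sym = mx.sym.contrib.SparseEmbedding(data, weight=weight, input_dim=input_dim,
                                         output_dim=output_dim, deterministic=True)
else:
    # Default dense embedding path.
    sym = mx.sym.Embedding(data, weight=weight, input_dim=input_dim, output_dim=output_dim)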
@@ -104,6 +105,43 @@ def test_sparse_dot(self):
assert k_s.shape == k_d.shape
assert_allclose(k_s, k_d, atol=1e-05)

def _forward_pass(self, x):
nit: probably rename to get_value() or get_data(), something like that?
Thank you for your contribution! Nice to see this progress!
One concern here: K.embedding() in mxnet_backend.py is only used in the Keras Embedding layer class; it is not used anywhere else. So to test its functionality, you need to test on the Embedding class instead of the single K.embedding operator, and use the sparse_grad param of the Embedding layer. Example usage from a Keras user's perspective would be the following:
from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model
inputs = Input(..., sparse=True, ...)
embedding = Embedding(..., ...)(inputs)
x = Dense(...)(embedding)
predictions = Dense(...)(x)
model = Model(inputs=inputs, outputs=predictions)
model.compile(...)
model.fit(...)
see reference: https://keras.io/getting-started/functional-api-guide/
There are also many examples using the Embedding layer class (the imdb examples). Sparse embedding should produce the same result as normal embedding.
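Filling in the placeholders, a runnable sketch along those lines could look like the following (the shapes and hyperparameters are illustrative; sparse_grad is the Embedding argument introduced in this PR and only takes effect on the MXNet backend):

import numpy as np
from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model

# Integer token ids, sequence length 20, vocabulary of 1000.
inputs = Input(shape=(20,), dtype='int32')
embedding = Embedding(input_dim=1000, output_dim=64, sparse_grad=True)(inputs)
x = LSTM(32)(embedding)
predictions = Dense(1, activation='sigmoid')(x)

model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Random stand-in data just to exercise the forward and backward pass.
data = np.random.randint(0, 1000, size=(128, 20))
labels = np.random.randint(0, 2, size=(128, 1))
model.fit(data, labels, batch_size=32, epochs=1)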
@@ -1202,13 +1202,18 @@ def gather(reference, indices):


@keras_mxnet_symbol
def embedding(data, weight, input_dim, output_dim):
def embedding(data, weight, input_dim, output_dim, sparse_grad=False):
K.embedding() is called in the keras.layers.Embedding class for the MXNet backend. Please also update the function signature there, and test in an end-to-end example that uses the Embedding layer (e.g. the examples using the imdb dataset).
We are passing a default value of False for sparse_grad in the API signature, so a change there is not necessary.
outputs = executor.forward(is_train=K.learning_phase())
return outputs

def test_sparse_embedding(self):
Could you add a test similar to tests/keras/layers/embeddings_test.py? Use layer_test to test the Embedding layer class.
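A hedged sketch of such a layer_test-based case, modeled on the existing tests in that file (the kwargs and shapes are illustrative; only the sparse_grad argument is new in this PR):

from keras.utils.test_utils import layer_test
from keras.layers.embeddings import Embedding

# layer_test builds, runs, serializes and deserializes the layer end to end.
layer_test(Embedding,
           kwargs={'output_dim': 4, 'input_dim': 10,
                   'input_length': 2, 'sparse_grad': True},
           input_shape=(3, 2),
           input_dtype='int32',
           expected_output_dtype='float32')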
As per this issue, we will need to wait for the MXNet v1.3 release to be able to use the new API signature.
@kalyc - Can we move ahead with this, as we discussed using the mxnet --preview package?
Force-pushed from 32423b9 to ca183dc
Updated the PR and tested with an end-to-end model.
Result:
Thanks for the contribution, a few additional comments:
keras/layers/embeddings.py
Outdated
@@ -140,7 +141,10 @@ def call(self, inputs):
# K.gather is not working with Embedding layer using MXNet backend
# Refer to this issue: https://github.com/awslabs/keras-apache-mxnet/issues/63
if K.backend() == "mxnet":
out = K.embedding(inputs, self.embeddings, self.input_dim, self.output_dim)
if self.sparse_grad:
out = K.embedding(inputs, self.embeddings, self.input_dim, self.output_dim, sparse_grad=True)
Why is it always True? Should it be sparse_grad=self.sparse_grad?
The condition will pass only when self.sparse_grad=True; updated the function call for more clarity.
Can we simplify this to avoid the if/else:
out = K.embedding(inputs, self.embeddings, self.input_dim, self.output_dim, sparse_grad=self.sparse_grad)
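A sketch of the whole branch with that simplification, assuming the non-MXNet path keeps upstream Keras's K.gather lookup:

if K.backend() == 'mxnet':
    # sparse_grad defaults to False, so the dense behavior is unchanged.
    out = K.embedding(inputs, self.embeddings, self.input_dim,
                      self.output_dim, sparse_grad=self.sparse_grad)
else:
    out = K.gather(self.embeddings, inputs)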
@@ -78,14 +78,14 @@ def __init__(self, input_dim, output_dim,
embeddings_constraint=None,
mask_zero=False,
input_length=None,
sparse_grad=False,
Add this to the docstring, explaining the usage and noting that it is only for the MXNet backend.
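A possible wording for the Arguments section of the docstring (a suggestion, not final text):

# sparse_grad: Boolean. Whether to compute sparse gradients for the
#     embedding weights. Supported only with the MXNet backend; it has
#     no effect on other backends. Requires the MXNet sparse embedding
#     support discussed in this PR (MXNet v1.3 / preview package).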
@@ -160,6 +160,5 @@ def test_sparse_concat_axis_non_zero(self):
assert k_s_d.shape == k_d.shape
assert_allclose(k_s_d, k_d, atol=1e-05)

Please add a sparse test here; add a case where sparse_grad is True: https://github.com/awslabs/keras-apache-mxnet/blob/master/tests/keras/layers/embeddings_test.py
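A hedged sketch of such a case that also checks values, following the earlier point that sparse embedding should produce the same result as normal embedding (names, weights, and shapes are illustrative):

import numpy as np
from keras.layers import Input, Embedding
from keras.models import Model

def build_model(sparse_grad):
    # Both models share the same fixed weights so their outputs are comparable.
    weights = [np.arange(40, dtype='float32').reshape(10, 4)]
    inp = Input(shape=(2,), dtype='int32')
    out = Embedding(input_dim=10, output_dim=4, weights=weights,
                    sparse_grad=sparse_grad)(inp)
    return Model(inp, out)

data = np.random.randint(0, 10, size=(3, 2))
dense_out = build_model(False).predict(data)
sparse_out = build_model(True).predict(data)
np.testing.assert_allclose(sparse_out, dense_out, atol=1e-05)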
Force-pushed from 47887bf to f1b5d23
LGTM! Thanks!
Summary
Add a minimal test for sparse embedding operator support
Related Issues
Missing Sparse operator support
PR Overview