Add sparse sum, mean and dot operator support #162
Thanks for the contribution. A few comments inline, and one additional comment:
- Can we test this on an end-to-end model, for example the one in the sparse benchmark scripts?
tests/keras/backend/backend_test.py
Outdated
@@ -1702,6 +1702,107 @@ def test_sparse_concat(self):
        assert k_s_d.shape == k_d.shape
        assert_allclose(k_s_d, k_d, atol=1e-05)

    @pytest.mark.skipif((K.backend() != 'mxnet'),
                        reason='Testing only for MXNet backend')
    def test_sparse_sum(self):
Why keep a duplicate test here? We can just use the separate file for testing, as @sandeep-krishnamurthy suggested.
My bad - looks like I didn't commit the changed backend_test file. Will update. Have moved the sparse tests to mxnet_sparse_test file
keras/backend/mxnet_backend.py
Outdated
@@ -576,6 +594,15 @@ def eval(x):
    return x


def _forward_pass(x):
Suggest moving this function next to the other internal helper functions if you plan to reuse it.
done
raise NotImplementedError('MXNet Backend: Sparse operations are not supported yet.')
if hasattr(tensor, 'tocoo'):
    return tensor.toarray()
elif isinstance(tensor, mx.sym.Symbol):
Why do we need to check for an MXNet symbol? MXNet symbols only exist in the mxnet_backend file. From a Keras user's perspective, they will not use an MXNet symbol during model construction or training.
I checked for all possible tensor data structures that Keras-MXNet supports in the is_tensor method.
Addressed comments. Testing on benchmark model not yet done.
LGTM, looking forward to seeing a sparse example.
Nice work!
A few suggested changes. Please fix before merging.
keras/backend/mxnet_backend.py
Outdated
if isinstance(_forward_pass(tensor)[0], mx.ndarray.sparse.CSRNDArray) or \
        isinstance(_forward_pass(tensor)[0], mx.ndarray.sparse.RowSparseNDArray):
    return True
elif hasattr(tensor, 'tocoo'):
nit: there is a very minor perf improvement if you swap the if and elif branches. hasattr(..) is faster than isinstance combined with _forward_pass, etc.
done
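The reordering suggested in this thread can be sketched roughly as below. This is only an illustration, not the backend code: `FakeScipyMatrix`, `FakeSparseResult`, `FakeSymbol` and `expensive_forward_pass` are hypothetical stand-ins, so the snippet runs without MXNet or SciPy. The point is that the cheap `hasattr` check comes first, and that the forward pass is evaluated once rather than twice.

```python
class FakeScipyMatrix(object):
    """Hypothetical stand-in for a scipy.sparse matrix (it exposes `tocoo`)."""

    def tocoo(self):
        return self


class FakeSparseResult(object):
    """Hypothetical stand-in for a sparse NDArray returned by a forward pass."""


class FakeSymbol(object):
    """Hypothetical stand-in for a symbolic tensor; its storage type is
    only known after evaluating it."""

    def __init__(self, result):
        self.result = result


# Records every (pretend) expensive evaluation so the ordering can be checked.
forward_pass_calls = []


def expensive_forward_pass(tensor):
    """Pretend evaluation of a symbol; returns a list of outputs."""
    forward_pass_calls.append(tensor)
    return [tensor.result]


def is_sparse(tensor):
    # Cheap attribute lookup first: SciPy-style inputs never reach the
    # forward pass.
    if hasattr(tensor, 'tocoo'):
        return True
    # Expensive path: evaluate once and reuse the result.
    if isinstance(tensor, FakeSymbol):
        output = expensive_forward_pass(tensor)[0]
        return isinstance(output, FakeSparseResult)
    return False
```

In the real backend the second branch would presumably compare against the two MXNet sparse types from the diff; `isinstance` accepts a tuple of types, which would avoid the repeated `_forward_pass` call.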
keras/backend/mxnet_backend.py
Outdated
sym._keras_shape = tuple([d if d != 0 else None for d in shape])
sym._mxnet_placeholder = True
sym._uses_learning_phase = False
print(sym)
I think you forgot to remove this print.
removed
class TestMXNetSparse(object):

    @pytest.mark.skipif((K.backend() != 'mxnet'),
You can skip all tests in a file this way - https://github.com/awslabs/keras-apache-mxnet/blob/master/tests/keras/layers/wrappers_test.py#L15
and avoid skipif for each test.
done
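A file-wide skip of this kind is usually written with pytest's module-level `pytestmark` variable. The sketch below assumes that is what the linked line does. `current_backend` is a hypothetical helper that reads the `KERAS_BACKEND` environment variable so the snippet runs without Keras installed; in the real test file the condition would use `K.backend()`.

```python
import os

import pytest


def current_backend():
    """Hypothetical stand-in for K.backend(); reads the KERAS_BACKEND
    environment variable and falls back to 'tensorflow'."""
    return os.environ.get('KERAS_BACKEND', 'tensorflow')


# A module-level `pytestmark` applies the mark to every test in this file,
# so individual tests no longer need their own skipif decorator.
pytestmark = pytest.mark.skipif(current_backend() != 'mxnet',
                                reason='Testing only for MXNet backend')


class TestMXNetSparse(object):

    def test_backend_is_mxnet(self):
        # Only runs when the skip condition above is False.
        assert current_backend() == 'mxnet'
```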
x_r = np.array([0, 2, 2, 3], dtype=np.int64)
x_c = np.array([4, 3, 2, 3], dtype=np.int64)

x_sparse_matrix = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5))
Lines 25 to 29 repeat across this file 8 times or so.
Doesn't it make sense to extract a common generate_test_sparse_tensor function instead of repeating the same code block?
refactored
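One possible shape for such a helper is sketched below. The row indices, column indices and matrix shape come from the diff hunk above; the values in `x_d` are not visible in the hunk and are assumed here, as is the choice to return a dense copy alongside the sparse matrix.

```python
import numpy as np
from scipy import sparse


def generate_test_sparse_tensor():
    """Build the fixture shared by the sparse tests.

    Returns a (scipy CSR matrix, dense ndarray) pair holding the same data.
    """
    # Assumed values: the actual data array is not shown in the diff.
    x_d = np.array([1, 7, 2, 3], dtype=np.float32)
    # Row and column indices as in the diff hunk.
    x_r = np.array([0, 2, 2, 3], dtype=np.int64)
    x_c = np.array([4, 3, 2, 3], dtype=np.int64)

    x_sparse_matrix = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5))
    x_dense_matrix = x_sparse_matrix.toarray()
    return x_sparse_matrix, x_dense_matrix
```

Each test could then call `generate_test_sparse_tensor()` and compare the sparse result against the dense one, for example with `assert_allclose`.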
Addressed refactoring comments.
Thanks. LGTM. Going ahead with merging these operators. Please follow up with an end-to-end example using sparse tensors (maybe have this end-to-end example work for your benchmarking as well).
Yes, will do - we will need to update the existing benchmark script to use these operators as well.
Summary
Add sparse support for sum, mean and dot operators
Related Issues
Continuing - #159
Missing sparse operators
PR Overview