This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Add sparse sum, mean and dot operator support #162

Merged - 8 commits into awslabs:dev on Sep 5, 2018

Conversation

@kalyc commented on Aug 28, 2018

Summary

Add sparse support for sum, mean and dot operators
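
As a rough illustration of what this enables at the backend level, here is a minimal sketch assuming the MXNet backend with this change is active; the exact call pattern used by the merged tests may differ, and K.variable accepting a SciPy sparse matrix directly is part of the assumption.

import numpy as np
from scipy import sparse
from keras import backend as K

# A small SciPy CSR matrix as the sparse input.
x_sparse = sparse.csr_matrix(np.array([[0., 1., 0.],
                                       [2., 0., 3.]], dtype=np.float32))

x = K.variable(x_sparse)                           # sparse Keras variable (assumed supported)
w = K.variable(np.ones((3, 4), dtype=np.float32))  # dense kernel

print(K.eval(K.sum(x, axis=0)))    # sparse-aware sum
print(K.eval(K.mean(x, axis=0)))   # sparse-aware mean
print(K.eval(K.dot(x, w)))         # sparse x dense dot product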

Related Issues

Continuing #159 - Missing sparse operators

PR Overview

  • [y] This PR requires new unit tests [y/n] (make sure tests are included)
  • [n] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)
  • [y] This PR is backwards compatible [y/n]
  • [n] This PR changes the current API [y/n]

@roywei left a comment

Thanks for the contribution. A few comments inline, and one additional comment:

  • Can we test this on an end-to-end model, for example the one in the sparse benchmark scripts?

@@ -1702,6 +1702,107 @@ def test_sparse_concat(self):
        assert k_s_d.shape == k_d.shape
        assert_allclose(k_s_d, k_d, atol=1e-05)

    @pytest.mark.skipif((K.backend() != 'mxnet'),
                        reason='Testing only for MXNet backend')
    def test_sparse_sum(self):

Why keep a duplicate test here? We can just use the separate file for testing, as @sandeep-krishnamurthy suggested.

Author: My bad, looks like I didn't commit the changed backend_test file. Will update. I have moved the sparse tests to the mxnet_sparse_test file.

@@ -576,6 +594,15 @@ def eval(x):
    return x


def _forward_pass(x):

Suggest moving this function together with the other internal helper functions if we plan to reuse it.

Author: Done.

raise NotImplementedError('MXNet Backend: Sparse operations are not supported yet.')
if hasattr(tensor, 'tocoo'):
return tensor.toarray()
elif isinstance(tensor, mx.sym.Symbol):

Why do we need to check for an MXNet symbol? MXNet symbols only exist in the mxnet_backend file. From a Keras user's perspective, they will not use an MXNet symbol during model construction/training.

Author: I checked for all possible tensor data structures that Keras-MXNet supports in the is_tensor method.
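
For context, hasattr(tensor, 'tocoo') is duck typing for SciPy sparse matrices, while the isinstance check catches raw MXNet symbols. A small standalone illustration (the describe helper is purely illustrative and not part of the backend):

import mxnet as mx
import numpy as np
from scipy import sparse

def describe(tensor):
    # SciPy sparse matrices (csr_matrix, coo_matrix, ...) all expose tocoo(),
    # so hasattr(tensor, 'tocoo') identifies them without naming SciPy types.
    if hasattr(tensor, 'tocoo'):
        return 'scipy sparse, dense shape {}'.format(tensor.toarray().shape)
    # Raw MXNet symbols can also reach the backend helpers, hence the
    # isinstance(tensor, mx.sym.Symbol) branch in the diff above.
    if isinstance(tensor, mx.sym.Symbol):
        return 'mxnet symbol {}'.format(tensor.name)
    return 'other: {}'.format(type(tensor).__name__)

print(describe(sparse.csr_matrix(np.eye(3))))   # scipy sparse, dense shape (3, 3)
print(describe(mx.sym.Variable('x')))           # mxnet symbol x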

@kalyc (Author) commented on Aug 29, 2018

Addressed comments. Testing on the benchmark model is not yet done.

@roywei left a comment

LGTM, looking forward to seeing a sparse example.

@sandeep-krishnamurthy left a comment

Nice work! A few suggested changes; please fix before merging.

    if isinstance(_forward_pass(tensor)[0], mx.ndarray.sparse.CSRNDArray) or \
            isinstance(_forward_pass(tensor)[0], mx.ndarray.sparse.RowSparseNDArray):
        return True
    elif hasattr(tensor, 'tocoo'):

Nit: a very minor perf improvement if you swap these if/elif branches; hasattr(..) is faster than isinstance plus _forward_pass, etc.

Author: Done.
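
For reference, the reordered check might look roughly like this; _forward_pass here is a hypothetical stand-in for the PR's helper, and the surrounding function body is assumed rather than copied from the merged code:

import mxnet as mx


def _forward_pass(sym):
    # Hypothetical stand-in: bind the symbol and run one forward pass so the
    # storage type of the resulting NDArray(s) can be inspected.
    return sym.simple_bind(mx.cpu(), grad_req='null').forward()


def is_sparse(tensor):
    # Cheap duck-typing check first: SciPy sparse matrices expose tocoo().
    if hasattr(tensor, 'tocoo'):
        return True
    # Only fall back to the more expensive forward pass for MXNet symbols.
    if isinstance(tensor, mx.sym.Symbol):
        out = _forward_pass(tensor)[0]
        return isinstance(out, (mx.ndarray.sparse.CSRNDArray,
                                mx.ndarray.sparse.RowSparseNDArray))
    return False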

sym._keras_shape = tuple([d if d != 0 else None for d in shape])
sym._mxnet_placeholder = True
sym._uses_learning_phase = False
print(sym)

I think you missed removing this print.

Author: Removed.


class TestMXNetSparse(object):

    @pytest.mark.skipif((K.backend() != 'mxnet'),

You can skip all tests in a file this way - https://github.com/awslabs/keras-apache-mxnet/blob/master/tests/keras/layers/wrappers_test.py#L15 - and avoid a skipif for each test.

Author: Done.
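
For reference, the module-level skip from the linked wrappers_test.py pattern looks roughly like this (test names and bodies are illustrative):

import pytest
from keras import backend as K

# One module-level marker skips every test in this file on non-MXNet backends,
# replacing a per-test @pytest.mark.skipif decorator.
pytestmark = pytest.mark.skipif(K.backend() != 'mxnet',
                                reason='Testing only for MXNet backend')


class TestMXNetSparse(object):

    def test_sparse_sum(self):
        assert True  # illustrative placeholder body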

x_r = np.array([0, 2, 2, 3], dtype=np.int64)
x_c = np.array([4, 3, 2, 3], dtype=np.int64)

x_sparse_matrix = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5))

Lines 25 to 29 repeat across this file eight times or so. Doesn't it make sense to extract a common generate_test_sparse_tensor function instead of repeating the same code block?

Author: Refactored.
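
For illustration, the extracted helper might look like this; the function name follows the reviewer's suggestion, the row/column indices and shape come from the diff above, and the non-zero values are arbitrary example data:

import numpy as np
from scipy import sparse


def generate_test_sparse_tensor():
    # Shared fixture: build the 4x5 CSR matrix once instead of repeating the
    # same data/row/column arrays in every test.
    x_d = np.array([1, 7, 2, 3], dtype=np.float32)   # arbitrary example values
    x_r = np.array([0, 2, 2, 3], dtype=np.int64)     # row indices from the diff
    x_c = np.array([4, 3, 2, 3], dtype=np.int64)     # column indices from the diff
    return sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5))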

@kalyc (Author) commented on Sep 5, 2018

Addressed refactoring comments.

@sandeep-krishnamurthy commented

Thanks. LGTM. Going ahead with merging these operators. Please follow up with an end-to-end example using a sparse tensor (maybe think of having this end-to-end example work for your benchmarking as well).

@sandeep-krishnamurthy merged commit d961bf3 into awslabs:dev on Sep 5, 2018

@kalyc (Author) commented on Sep 5, 2018

Yes, will do - we will need to update the existing benchmark script to use these operators as well.
