mul(sparse_csr, sparse_csr) using mul(sparse, sparse)
Basic fallback implementation. Let's make this faster once used.

NOTE: This is stacked on top of pytorch#74294
Pull Request resolved: pytorch#74266
Approved by: https://github.com/pearu, https://github.com/malfet
cpuhrsch authored and pytorchmergebot committed Mar 25, 2022
1 parent cd929f4 commit 7fe0b6a
Showing 4 changed files with 63 additions and 1 deletion.
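
Before the per-file diffs, a quick illustration of the user-facing behavior this adds: element-wise multiplication of two sparse CSR tensors. This is a minimal sketch, not part of the commit; the tensor values are made up.

```python
import torch

# Element-wise multiply of two sparse CSR tensors; after this commit the call
# dispatches to mul_sparse_csr instead of raising.
a = torch.tensor([[1., 0.], [2., 3.]]).to_sparse_csr()
b = torch.tensor([[4., 5.], [0., 6.]]).to_sparse_csr()

out = torch.mul(a, b)          # or simply a * b
print(out.layout)              # torch.sparse_csr
print(out.to_dense())          # tensor([[ 4.,  0.],
                               #         [ 0., 18.]])
```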
5 changes: 4 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -3173,6 +3173,7 @@
variants: function, method
dispatch:
SparseCPU, SparseCUDA: mul_sparse
SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr
MkldnnCPU: mkldnn_mul
ZeroTensor: mul_zerotensor

@@ -3182,6 +3183,7 @@
variants: method
dispatch:
SparseCPU, SparseCUDA: mul_sparse_
SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr_
MkldnnCPU: mkldnn_mul_

- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
@@ -3192,6 +3194,7 @@
CPU, CUDA: mul_out
SparseCPU: mul_out_sparse_cpu
SparseCUDA: mul_out_sparse_cuda
SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr
MkldnnCPU: mkldnn_mul_out

# For C++ only, until we have conversion from C++ numbers to Tensor
@@ -5035,7 +5038,7 @@

- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: function
variants: function, method
dispatch:
SparseCPU, SparseCUDA: resize_as_sparse_
SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_
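Together these registrations cover the functional, in-place, and `out=` overloads of `mul`, and the `variants: function, method` change exposes `resize_as_sparse_` as a Tensor method, which the CSR fallback below relies on (`r.resize_as_sparse_(r_sparse_csr)`). A hedged sketch of the in-place and `out=` forms (illustrative only; the empty CSR `out` tensor built via `torch.sparse_csr_tensor` is my own construction, not from the commit):

```python
import torch

a = torch.tensor([[1., 0.], [2., 3.]]).to_sparse_csr()
b = torch.tensor([[4., 5.], [0., 6.]]).to_sparse_csr()

# out= variant: dispatches to mul_out_sparse_csr; the result tensor must
# already be in CSR format (here an empty 2x2 CSR placeholder).
out = torch.sparse_csr_tensor(
    torch.tensor([0, 0, 0]),                 # crow_indices (nnz = 0)
    torch.tensor([], dtype=torch.int64),     # col_indices
    torch.tensor([]),                        # values
    size=(2, 2),
)
torch.mul(a, b, out=out)

# In-place variant: dispatches to mul_sparse_csr_, which redispatches to mul.out.
a.mul_(b)
```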
28 changes: 28 additions & 0 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -707,6 +707,34 @@ Tensor& mul_sparse_(Tensor& self, const Tensor& other) {
return at::mul_out(self, self, other); // redispatch!
}

Tensor& mul_out_sparse_csr(const Tensor& t_, const Tensor& src_, Tensor& r) {
// TODO: Use a specialized CSR kernel for performance if needed
TORCH_CHECK(t_.is_sparse_csr(), "mul(dense, sparse_csr) is not supported");
TORCH_CHECK(src_.is_sparse_csr(), "mul(sparse_csr, dense) is not supported");
TORCH_CHECK(r.is_sparse_csr(), "Expected result Tensor to be of format CSR");
Tensor t = t_.to_sparse();
Tensor src = src_.to_sparse();
Tensor tmp_result = t.mul(src);
auto r_sparse_csr = tmp_result.to_sparse_csr();
r.resize_as_sparse_(r_sparse_csr);
r.copy_(r_sparse_csr);
return r;
}

Tensor mul_sparse_csr(const Tensor& self, const Tensor& other) {
auto commonDtype = at::result_type(self, other);
TORCH_CHECK(self.is_sparse_csr(), "mul(dense, sparse_csr) is not supported");
TORCH_CHECK(other.is_sparse_csr(), "mul(sparse_csr, dense) is not supported");
auto result_options = self.options().dtype(commonDtype);
// CSR is 2d!
Tensor result = at::empty({0, 0}, result_options);
return at::mul_out(result, self, other); // redispatch!
}

Tensor& mul_sparse_csr_(Tensor& self, const Tensor& other) {
return at::mul_out(self, self, other); // redispatch!
}

SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, SparseTensor& r) {
if (src_.dim() == 0) {
return mul_out_sparse_zerodim(r, t_, src_);
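`mul_out_sparse_csr` above is the fallback named in the commit title: convert both CSR operands to COO, reuse the existing `mul(sparse, sparse)` kernel, then convert the result back to CSR. Roughly the same path expressed in Python for illustration (a sketch, not code from the commit):

```python
import torch

def mul_sparse_csr_fallback(t_csr, src_csr):
    # Mirror of mul_out_sparse_csr: CSR -> COO, multiply with the existing
    # sparse COO kernel, then convert the result back to CSR.
    t_coo = t_csr.to_sparse()
    src_coo = src_csr.to_sparse()
    return t_coo.mul(src_coo).to_sparse_csr()

a = torch.tensor([[1., 0.], [2., 3.]]).to_sparse_csr()
b = torch.tensor([[4., 5.], [0., 6.]]).to_sparse_csr()
assert torch.equal(mul_sparse_csr_fallback(a, b).to_dense(),
                   a.to_dense() * b.to_dense())
```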
30 changes: 30 additions & 0 deletions test/test_sparse_csr.py
@@ -1072,6 +1072,36 @@ def _test_spadd_shape(nnz, shape):
_test_spadd_shape(10, [100, 1])
_test_spadd_shape(10, [1, 100])

@dtypes(torch.float, torch.double)
def test_mul(self, device, dtype):
def _test_spadd_shape(fn, nnz, shape):
x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
y = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)

res = fn(y, x)
expected = fn(y.to_dense(), x.to_dense()).to_sparse_csr()
self.assertEqual(res, expected)

_test_spadd_shape(torch.mul, 100, [100, 100])
_test_spadd_shape(torch.mul, 0, [100, 100])
_test_spadd_shape(torch.mul, 100, [100, 1])
_test_spadd_shape(torch.mul, 100, [1, 100])

s = torch.sparse_coo_tensor([[0], [1]], [5.0], (2, 3), device=device)
s = s.to_sparse_csr()
t23 = s.to_dense()

if device == 'cpu':
with self.assertRaisesRegex(RuntimeError, r"mul\(sparse_csr, dense\) is not supported"):
s * t23
with self.assertRaisesRegex(RuntimeError, r"mul\(dense, sparse_csr\) is not supported"):
t23 * s
elif device == 'cuda':
with self.assertRaisesRegex(NotImplementedError, "CUDA"):
s * t23
with self.assertRaisesRegex(NotImplementedError, "CUDA"):
t23 * s

@skipCPUIfNoMklSparse
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_sparse_add(self, device, dtype):
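The new `test_mul` case can be run in isolation with something like `python test/test_sparse_csr.py -k test_mul` (assuming the standard PyTorch unittest runner; the exact invocation is not part of the commit).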
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -1173,6 +1173,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.resize: lambda self, *size: -1,
Tensor.resize_: lambda self, size: -1,
Tensor.resize_as: lambda self, other: -1,
Tensor.resize_as_sparse_: lambda self, other: -1,
Tensor.retain_grad: lambda self: -1,
Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
Tensor.select_scatter: lambda self, src, dim, index: -1,
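Because `resize_as_sparse_` gained a method variant, it also needs an entry in the testing-overrides table above so `__torch_function__` coverage tests recognize it. A minimal sketch of how that table is consumed (illustrative; not from the commit):

```python
import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()
# After this change the new method has a placeholder lambda registered
# alongside the other Tensor methods, so override-coverage tests won't flag it.
assert torch.Tensor.resize_as_sparse_ in overrides
print(overrides[torch.Tensor.resize_as_sparse_])   # placeholder: lambda self, other: -1
```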
