diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index cf4fb6c9df1a9..c8262af0617d5 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3173,6 +3173,7 @@
   variants: function, method
   dispatch:
     SparseCPU, SparseCUDA: mul_sparse
+    SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr
     MkldnnCPU: mkldnn_mul
     ZeroTensor: mul_zerotensor
 
@@ -3182,6 +3183,7 @@
   variants: method
   dispatch:
     SparseCPU, SparseCUDA: mul_sparse_
+    SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr_
     MkldnnCPU: mkldnn_mul_
 
 - func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
@@ -3192,6 +3194,7 @@
     CPU, CUDA: mul_out
     SparseCPU: mul_out_sparse_cpu
     SparseCUDA: mul_out_sparse_cuda
+    SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr
     MkldnnCPU: mkldnn_mul_out
 
 # For C++ only, until we have conversion from C++ numbers to Tensor
@@ -5035,7 +5038,7 @@
 
 - func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
   use_const_ref_for_mutable_tensors: True
-  variants: function
+  variants: function, method
   dispatch:
     SparseCPU, SparseCUDA: resize_as_sparse_
     SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_
diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
index f98ab775926bb..712a9844be065 100644
--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -707,6 +707,34 @@ Tensor& mul_sparse_(Tensor& self, const Tensor& other) {
   return at::mul_out(self, self, other); // redispatch!
 }
 
+Tensor& mul_out_sparse_csr(const Tensor& t_, const Tensor& src_, Tensor& r) {
+  // TODO: Use a specialized CSR kernel for performance if needed
+  TORCH_CHECK(t_.is_sparse_csr(), "mul(dense, sparse_csr) is not supported");
+  TORCH_CHECK(src_.is_sparse_csr(), "mul(sparse_csr, dense) is not supported");
+  TORCH_CHECK(r.is_sparse_csr(), "Expected result Tensor to be of format CSR");
+  Tensor t = t_.to_sparse();
+  Tensor src = src_.to_sparse();
+  Tensor tmp_result = t.mul(src);
+  auto r_sparse_csr = tmp_result.to_sparse_csr();
+  r.resize_as_sparse_(r_sparse_csr);
+  r.copy_(r_sparse_csr);
+  return r;
+}
+
+Tensor mul_sparse_csr(const Tensor& self, const Tensor& other) {
+  auto commonDtype = at::result_type(self, other);
+  TORCH_CHECK(self.is_sparse_csr(), "mul(dense, sparse_csr) is not supported");
+  TORCH_CHECK(other.is_sparse_csr(), "mul(sparse_csr, dense) is not supported");
+  auto result_options = self.options().dtype(commonDtype);
+  // CSR is 2d!
+  Tensor result = at::empty({0, 0}, result_options);
+  return at::mul_out(result, self, other); // redispatch!
+}
+
+Tensor& mul_sparse_csr_(Tensor& self, const Tensor& other) {
+  return at::mul_out(self, self, other); // redispatch!
+}
+
 SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, SparseTensor& r) {
   if (src_.dim() == 0) {
     return mul_out_sparse_zerodim(r, t_, src_);
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 7275c61642419..543a570d41d80 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -1072,6 +1072,36 @@ def _test_spadd_shape(nnz, shape):
         _test_spadd_shape(10, [100, 1])
         _test_spadd_shape(10, [1, 100])
 
+    @dtypes(torch.float, torch.double)
+    def test_mul(self, device, dtype):
+        def _test_spadd_shape(fn, nnz, shape):
+            x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
+            y = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
+
+            res = fn(y, x)
+            expected = fn(y.to_dense(), x.to_dense()).to_sparse_csr()
+            self.assertEqual(res, expected)
+
+        _test_spadd_shape(torch.mul, 100, [100, 100])
+        _test_spadd_shape(torch.mul, 0, [100, 100])
+        _test_spadd_shape(torch.mul, 100, [100, 1])
+        _test_spadd_shape(torch.mul, 100, [1, 100])
+
+        s = torch.sparse_coo_tensor([[0], [1]], [5.0], (2, 3), device=device)
+        s = s.to_sparse_csr()
+        t23 = s.to_dense()
+
+        if device == 'cpu':
+            with self.assertRaisesRegex(RuntimeError, r"mul\(sparse_csr, dense\) is not supported"):
+                s * t23
+            with self.assertRaisesRegex(RuntimeError, r"mul\(dense, sparse_csr\) is not supported"):
+                t23 * s
+        elif device == 'cuda':
+            with self.assertRaisesRegex(NotImplementedError, "CUDA"):
+                s * t23
+            with self.assertRaisesRegex(NotImplementedError, "CUDA"):
+                t23 * s
+
     @skipCPUIfNoMklSparse
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_sparse_add(self, device, dtype):
diff --git a/torch/overrides.py b/torch/overrides.py
index 0438511367aba..010568057385f 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1173,6 +1173,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     Tensor.resize: lambda self, *size: -1,
     Tensor.resize_: lambda self, size: -1,
     Tensor.resize_as: lambda self, other: -1,
+    Tensor.resize_as_sparse_: lambda self, other: -1,
     Tensor.retain_grad: lambda self: -1,
     Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
     Tensor.select_scatter: lambda self, src, dim, index: -1,
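
For reference, a short usage sketch of the behavior this patch enables: elementwise mul now works when both operands are sparse CSR (internally falling back to the COO kernel via to_sparse()/to_sparse_csr()), while mixed dense/CSR operands are rejected by the new TORCH_CHECKs. This is a minimal sketch assuming a build that includes this patch; the tensor values are illustrative, not taken from the PR.

import torch

# Build two small CSR tensors; dense -> COO -> CSR mirrors the
# conversions exercised by mul_out_sparse_csr above.
a = torch.tensor([[0., 2.], [3., 0.]]).to_sparse().to_sparse_csr()
b = torch.tensor([[1., 4.], [0., 5.]]).to_sparse().to_sparse_csr()

# Both operands CSR: dispatches to mul_sparse_csr and returns a CSR tensor.
c = torch.mul(a, b)
assert c.layout == torch.sparse_csr
assert torch.equal(c.to_dense(), torch.tensor([[0., 8.], [0., 0.]]))

# Mixed CSR/dense operands hit the new TORCH_CHECKs; on CPU this raises:
#   RuntimeError: mul(sparse_csr, dense) is not supported
try:
    a * a.to_dense()
except RuntimeError as e:
    print(e)

Note the supporting changes: mul_out_sparse_csr calls r.resize_as_sparse_(...) as a method, which is why resize_as_sparse_ gains "variants: function, method" in native_functions.yaml and a matching entry in torch/overrides.py.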