diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index cf4fb6c9df1a9..c8262af0617d5 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3173,6 +3173,7 @@
   variants: function, method
   dispatch:
     SparseCPU, SparseCUDA: mul_sparse
+    SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr
     MkldnnCPU: mkldnn_mul
     ZeroTensor: mul_zerotensor
 
@@ -3182,6 +3183,7 @@
   variants: method
   dispatch:
     SparseCPU, SparseCUDA: mul_sparse_
+    SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr_
     MkldnnCPU: mkldnn_mul_
 
 - func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
@@ -3192,6 +3194,7 @@
     CPU, CUDA: mul_out
     SparseCPU: mul_out_sparse_cpu
     SparseCUDA: mul_out_sparse_cuda
+    SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr
     MkldnnCPU: mkldnn_mul_out
 
   # For C++ only, until we have conversion from C++ numbers to Tensor
@@ -5035,7 +5038,7 @@
 
 - func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
   use_const_ref_for_mutable_tensors: True
-  variants: function
+  variants: function, method
   dispatch:
     SparseCPU, SparseCUDA: resize_as_sparse_
     SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_
diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
index f98ab775926bb..712a9844be065 100644
--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -707,6 +707,34 @@ Tensor& mul_sparse_(Tensor& self, const Tensor& other) {
   return at::mul_out(self, self, other);  // redispatch!
 }
 
+Tensor& mul_out_sparse_csr(const Tensor& t_, const Tensor& src_, Tensor& r) {
+  // // TODO: Use a specialized CSR kernel for performance if needed
+  TORCH_CHECK(t_.is_sparse_csr(), "mul(dense, sparse_csr) is not supported");
+  TORCH_CHECK(src_.is_sparse_csr(), "mul(sparse_csr, dense) is not supported");
+  TORCH_CHECK(r.is_sparse_csr(), "Expected result Tensor to be of format CSR");
+  Tensor t = t_.to_sparse();
+  Tensor src = src_.to_sparse();
+  Tensor tmp_result = t.mul(src);
+  auto r_sparse_csr = tmp_result.to_sparse_csr();
+  r.resize_as_sparse_(r_sparse_csr);
+  r.copy_(r_sparse_csr);
+  return r;
+}
+
+Tensor mul_sparse_csr(const Tensor& self, const Tensor& other) {
+  auto commonDtype = at::result_type(self, other);
+  TORCH_CHECK(self.is_sparse_csr(), "mul(dense, sparse_csr) is not supported");
+  TORCH_CHECK(other.is_sparse_csr(), "mul(sparse_csr, dense) is not supported");
+  auto result_options = self.options().dtype(commonDtype);
+  // CSR is 2d!
+  Tensor result = at::empty({0, 0}, result_options);
+  return at::mul_out(result, self, other); // redispatch!
+}
+
+Tensor& mul_sparse_csr_(Tensor& self, const Tensor& other) {
+  return at::mul_out(self, self, other); // redispatch!
+}
+
 SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, SparseTensor& r) {
   if (src_.dim() == 0) {
     return mul_out_sparse_zerodim(r, t_, src_);
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 7275c61642419..543a570d41d80 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -1072,6 +1072,36 @@ def _test_spadd_shape(nnz, shape):
         _test_spadd_shape(10, [100, 1])
         _test_spadd_shape(10, [1, 100])
 
+    @dtypes(torch.float, torch.double)
+    def test_mul(self, device, dtype):
+        def _test_spadd_shape(fn, nnz, shape):
+            x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
+            y = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
+
+            res = fn(y, x)
+            expected = fn(y.to_dense(), x.to_dense()).to_sparse_csr()
+            self.assertEqual(res, expected)
+
+        _test_spadd_shape(torch.mul, 100, [100, 100])
+        _test_spadd_shape(torch.mul, 0, [100, 100])
+        _test_spadd_shape(torch.mul, 100, [100, 1])
+        _test_spadd_shape(torch.mul, 100, [1, 100])
+
+        s = torch.sparse_coo_tensor([[0], [1]], [5.0], (2, 3), device=device)
+        s = s.to_sparse_csr()
+        t23 = s.to_dense()
+
+        if device == 'cpu':
+            with self.assertRaisesRegex(RuntimeError, r"mul\(sparse_csr, dense\) is not supported"):
+                s * t23
+            with self.assertRaisesRegex(RuntimeError, r"mul\(dense, sparse_csr\) is not supported"):
+                t23 * s
+        elif device == 'cuda':
+            with self.assertRaisesRegex(NotImplementedError, "CUDA"):
+                s * t23
+            with self.assertRaisesRegex(NotImplementedError, "CUDA"):
+                t23 * s
+
     @skipCPUIfNoMklSparse
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_sparse_add(self, device, dtype):
diff --git a/torch/overrides.py b/torch/overrides.py
index 0438511367aba..010568057385f 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1173,6 +1173,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.resize: lambda self, *size: -1,
         Tensor.resize_: lambda self, size: -1,
         Tensor.resize_as: lambda self, other: -1,
+        Tensor.resize_as_sparse_: lambda self, other: -1,
         Tensor.retain_grad: lambda self: -1,
         Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
         Tensor.select_scatter: lambda self, src, dim, index: -1,