Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sparse element-wise multiplication #323

Merged
merged 5 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions test/test_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from itertools import product

import pytest
import torch

from torch_sparse import SparseTensor, mul
from torch_sparse.testing import devices, dtypes, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_sparse_mul(dtype, device):
    """Sparse-sparse `*` keeps only coordinates present in both operands,
    multiplying their values; the result must also be TorchScript-able."""
    mat_a = SparseTensor(
        row=torch.tensor([0, 0, 1, 2, 2], device=device),
        col=torch.tensor([0, 2, 1, 0, 1], device=device),
        value=tensor([1, 2, 4, 1, 3], dtype, device),
    )
    mat_b = SparseTensor(
        row=torch.tensor([0, 0, 1, 2, 2], device=device),
        col=torch.tensor([1, 2, 2, 1, 2], device=device),
        value=tensor([2, 3, 1, 2, 4], dtype, device),
    )

    out_row, out_col, out_value = (mat_a * mat_b).coo()

    # Only (0, 2) and (2, 1) appear in both inputs: 2*3 and 3*2.
    assert out_row.tolist() == [0, 2]
    assert out_col.tolist() == [2, 1]
    assert out_value.tolist() == [6, 6]

    # The overloaded `mul` must still compile under TorchScript.
    @torch.jit.script
    def jit_mul(A: SparseTensor, B: SparseTensor) -> SparseTensor:
        return mul(A, B)

    jit_mul(mat_a, mat_b)


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_sparse_mul_empty(dtype, device):
    """Multiplying two sparse matrices with disjoint sparsity patterns
    yields an empty result."""
    mat_a = SparseTensor(
        row=torch.tensor([0], device=device),
        col=torch.tensor([1], device=device),
        value=tensor([1], dtype, device),
    )
    mat_b = SparseTensor(
        row=torch.tensor([1], device=device),
        col=torch.tensor([0], device=device),
        value=tensor([2], dtype, device),
    )

    out_row, out_col, out_value = (mat_a * mat_b).coo()

    # No coordinate is shared, so every output component is empty.
    assert out_row.tolist() == []
    assert out_col.tolist() == []
    assert out_value.tolist() == []
100 changes: 81 additions & 19 deletions torch_sparse/mul.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,83 @@
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import gather_csr

from torch_sparse.tensor import SparseTensor


@torch.jit._overload  # noqa: F811
def mul(src, other):  # noqa: F811
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


@torch.jit._overload  # noqa: F811
def mul(src, other):  # noqa: F811
    # type: (SparseTensor, SparseTensor) -> SparseTensor
    pass


def mul(src, other):  # noqa: F811
    """Element-wise multiplication of a :class:`SparseTensor`.

    If :obj:`other` is a dense :class:`Tensor`, it must be a row vector of
    shape :obj:`[num_rows, 1]` or a column vector of shape
    :obj:`[1, num_cols]`, which is broadcast over the non-zero entries of
    :obj:`src`.  If :obj:`other` is a :class:`SparseTensor`, the result
    contains only the coordinates present in *both* inputs, with their
    values multiplied; both inputs must be coalesced and carry values.

    Raises:
        ValueError: On a dense-operand size mismatch, on a non-coalesced
            sparse operand, or when a sparse operand has no values.
    """
    if isinstance(other, Tensor):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
            # Expand one scalar per row to one scalar per non-zero entry.
            other = gather_csr(other.squeeze(1), rowptr)
        elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
            other = other.squeeze(0)[col]
        else:
            raise ValueError(
                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
                f'(1, {src.size(1)}, ...), but got size {other.size()}.')

        if value is not None:
            value = other.to(value.dtype).mul_(value)
        else:
            value = other
        return src.set_value(value, layout='coo')

    assert isinstance(other, SparseTensor)

    # Duplicate detection below relies on each input holding every
    # coordinate at most once.
    if not src.is_coalesced():
        raise ValueError("The `src` tensor is not coalesced")
    if not other.is_coalesced():
        raise ValueError("The `other` tensor is not coalesced")

    rowA, colA, valueA = src.coo()
    rowB, colB, valueB = other.coo()

    row = torch.cat([rowA, rowB], dim=0)
    col = torch.cat([colA, colB], dim=0)

    if valueA is not None and valueB is not None:
        value = torch.cat([valueA, valueB], dim=0)
    else:
        raise ValueError('Both sparse tensors must contain values')

    M = max(src.size(0), other.size(0))
    N = max(src.size(1), other.size(1))
    sparse_sizes = (M, N)

    # Flatten (row, col) into a single non-negative key and sort; the
    # leading `-1` sentinel guarantees the first entry never matches a
    # "previous" key.
    idx = col.new_full((col.numel() + 1, ), -1)
    idx[1:] = row * sparse_sizes[1] + col
    perm = idx[1:].argsort()
    idx[1:] = idx[1:][perm]

    row, col, value = row[perm], col[perm], value[perm]

    # After sorting, a coordinate shared by both inputs shows up as two
    # adjacent equal keys; keep the second of each pair and multiply it
    # with its predecessor.
    valid_mask = idx[1:] == idx[:-1]
    valid_idx = valid_mask.nonzero().view(-1)

    return SparseTensor(
        row=row[valid_mask],
        col=col[valid_mask],
        value=value[valid_idx - 1] * value[valid_idx],
        sparse_sizes=sparse_sizes,
    )


def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
Expand All @@ -43,8 +99,11 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value_(value, layout='coo')


def mul_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
def mul_nnz(
src: SparseTensor,
other: torch.Tensor,
layout: Optional[str] = None,
) -> SparseTensor:
value = src.storage.value()
if value is not None:
value = value.mul(other.to(value.dtype))
Expand All @@ -53,8 +112,11 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
return src.set_value(value, layout=layout)


def mul_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
def mul_nnz_(
src: SparseTensor,
other: torch.Tensor,
layout: Optional[str] = None,
) -> SparseTensor:
value = src.storage.value()
if value is not None:
value = value.mul_(other.to(value.dtype))
Expand Down