diff --git a/python/cugraph-equivariant/cugraph_equivariant/tests/conftest.py b/python/cugraph-equivariant/cugraph_equivariant/tests/conftest.py index c7c6bad07db..806e03e6d76 100644 --- a/python/cugraph-equivariant/cugraph_equivariant/tests/conftest.py +++ b/python/cugraph-equivariant/cugraph_equivariant/tests/conftest.py @@ -29,3 +29,11 @@ def example_scatter_data(): } return src_feat, dst_indices, results + + +@pytest.fixture +def empty_scatter_data(): + src_feat = torch.empty((0, 41)) + dst_indices = torch.empty((0,)) + + return src_feat, dst_indices diff --git a/python/cugraph-equivariant/cugraph_equivariant/tests/test_scatter.py b/python/cugraph-equivariant/cugraph_equivariant/tests/test_scatter.py index ff8048468ee..d28a32edcb1 100644 --- a/python/cugraph-equivariant/cugraph_equivariant/tests/test_scatter.py +++ b/python/cugraph-equivariant/cugraph_equivariant/tests/test_scatter.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize("reduce", ["sum", "mean", "prod", "amax", "amin"]) def test_scatter_reduce(example_scatter_data, reduce): - device = torch.device("cuda:0") + device = torch.device("cuda") src, index, out_true = example_scatter_data src = src.to(device) index = index.to(device) @@ -26,3 +26,15 @@ def test_scatter_reduce(example_scatter_data, reduce): out = scatter_reduce(src, index, dim=0, dim_size=None, reduce=reduce) assert torch.allclose(out.cpu(), out_true[reduce]) + + +def test_scatter_reduce_empty(empty_scatter_data): + device = torch.device("cuda") + src, index = empty_scatter_data + src = src.to(device) + index = index.to(device) + + out = scatter_reduce(src, index, dim=0, dim_size=None) + + assert out.numel() == 0 + assert out.size(1) == src.size(1) diff --git a/python/cugraph-equivariant/cugraph_equivariant/utils/scatter.py b/python/cugraph-equivariant/cugraph_equivariant/utils/scatter.py index 45cc541fc7b..909fbc99365 100644 --- a/python/cugraph-equivariant/cugraph_equivariant/utils/scatter.py +++ b/python/cugraph-equivariant/cugraph_equivariant/utils/scatter.py @@ -34,10 +34,9 @@ def scatter_reduce( size = list(src.size()) if dim_size is not None: - assert dim_size >= int(index.max()) + 1 size[dim] = dim_size else: - size[dim] = int(index.max()) + 1 + size[dim] = 0 if index.numel() == 0 else int(index.max()) + 1 out = torch.zeros(size, dtype=src.dtype, device=src.device) return out.scatter_reduce_(dim, index, src, reduce, include_self=False)