spspmm raises error in cuda but works well in cpu #3097

Open
morgan-bc opened this issue Sep 8, 2021 · 13 comments

Comments

@morgan-bc

morgan-bc commented Sep 8, 2021

🐛 Bug

To Reproduce

The network is similar to Graph U-Net, but has only downsampling blocks. The code is:

import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.utils import add_self_loops, sort_edge_index, remove_self_loops
from torch_sparse import spspmm



class GCNConvBnReLu(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels, bias=False, improved=True)
        self.bn = gnn.BatchNorm(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index, edge_weight=None):
        x = self.conv(x, edge_index, edge_weight)
        x = self.bn(x)
        x = self.relu(x)
        return x


class MyNet(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, pool_ratios=0.5, depth=3):
        super().__init__()

        channels = in_channels
        self.depth = depth
        self.down_convs = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.down_convs.append(GCNConvBnReLu(channels, hidden_channels))
        self.fc = nn.Linear(hidden_channels, out_channels)

        for i in range(depth):
            self.down_convs.append(GCNConvBnReLu(hidden_channels, hidden_channels))
            self.pools.append(TopKPooling(hidden_channels, ratio=pool_ratios))


    def forward(self, x, edge_index, batch=None):
        depth = self.depth
        edge_weight = x.new_ones(edge_index.size(1))
        x = self.down_convs[0](x, edge_index, edge_weight)

        for i in range(1, depth + 1):
            print(edge_index.shape)
            print(edge_index.min(), edge_index.max(), x.size(0))

            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight, x.size(0))

            print(edge_index.shape)
            print(edge_index.min(), edge_index.max(), x.size(0))
            print('----------')

            x, edge_index, edge_weight, batch, _, _ = self.pools[i-1](x, edge_index, edge_weight, batch)
            x = self.down_convs[i](x, edge_index, edge_weight)

        x = gnn.global_mean_pool(x, batch)
        out = self.fc(x)
        return out

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight
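To cross-check augment_adj on small inputs without torch_sparse, the following pure-Python sketch mirrors its steps on a plain edge list. The helper name augment_adj_reference is hypothetical. It assumes add_self_loops fills new self-loops with weight 1.0 (its default fill value), and the dictionary-based product only stands in for spspmm. It is not how torch_sparse computes it.

```python
from collections import defaultdict


def augment_adj_reference(edges, weights, num_nodes):
    # Hypothetical helper: pure-Python mirror of MyNet.augment_adj for tiny graphs.
    # 1) remove_self_loops, then add self-loops with weight 1.0
    adj = defaultdict(float)
    for (r, c), w in zip(edges, weights):
        if r != c:
            adj[(r, c)] += w  # duplicate edges are summed
    for i in range(num_nodes):
        adj[(i, i)] += 1.0

    # Group entries by row so that row k of A can be looked up directly.
    rows = defaultdict(list)
    for (r, c), w in adj.items():
        rows[r].append((c, w))

    # 2) Sparse product A @ A: out[i, j] = sum_k A[i, k] * A[k, j]
    out = defaultdict(float)
    for i, row_i in rows.items():
        for k, w_ik in row_i:
            for j, w_kj in rows.get(k, []):
                out[(i, j)] += w_ik * w_kj

    # 3) remove_self_loops on the product, sorted by (row, col)
    kept = sorted((e, w) for e, w in out.items() if e[0] != e[1])
    return [e for e, _ in kept], [w for _, w in kept]


# Undirected path 0 - 1 - 2: squaring connects the 2-hop pair (0, 2).
edges, weights = augment_adj_reference([(0, 1), (1, 0), (1, 2), (2, 1)], [1.0] * 4, 3)
print(edges)    # [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
print(weights)  # [2.0, 1.0, 2.0, 2.0, 1.0, 2.0]
```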

Test MyNet as follows. The test data can be downloaded from Google Drive.

device = torch.device('cuda')
model = MyNet(3, 64, 4, 0.5, 3).to(device)
data1 = torch.load('success.pt').to(device)
y1 = model(data1.x, data1.edge_index)
data2 = torch.load('failed.pt').to(device)
y2 = model(data2.x, data2.edge_index)

Expected behavior

The error log is

  File "D:\Software\anaconda3\lib\site-packages\torch_sparse\spspmm.py", line 30, in spspmm
    C = matmul(A, B)
  File "D:\Software\anaconda3\lib\site-packages\torch_sparse\matmul.py", line 125, in matmul
    return spspmm(src, other, reduce)
  File "D:\Software\anaconda3\lib\site-packages\torch_sparse\matmul.py", line 102, in spspmm
    return spspmm_sum(src, other)
  File "D:\Software\anaconda3\lib\site-packages\torch_sparse\matmul.py", line 92, in spspmm_sum
    sparse_sizes=(M, K), is_sorted=True)
  File "D:\Software\anaconda3\lib\site-packages\torch_sparse\tensor.py", line 25, in __init__
    is_sorted=is_sorted)
  File "D:\Software\anaconda3\lib\site-packages\torch_sparse\storage.py", line 70, in __init__
    assert col.max().item() < sparse_sizes[1]
AssertionError

The code runs without error when I call model.eval() or set device='cpu'.

Environment

  • OS: Win10
  • Python version: 3.7
  • PyTorch version: 1.8.1
  • PyG version: 1.7.2
  • CUDA/cuDNN version: 10.2 / 8.0.5
  • GCC version:
  • Any other relevant information:

Additional context

@rusty1s
Member

rusty1s commented Sep 9, 2021

Thanks for reporting. Interestingly, it works for me. Can you show me the outputs when modified as follows?

for i in range(1, depth + 1):
    print(edge_index.shape)
    print(edge_index.min(), edge_index.max(), x.size(0))
    edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
                                               x.size(0))
    print(edge_index.shape)
    print(edge_index.min(), edge_index.max(), x.size(0))
    print('----------')
    ...

@morgan-bc
Author

I tested the failed case as follows:

    device = torch.device('cuda')   
    model = MyNet(3, 64, 4, 0.5, 3).to(device)
    data = torch.load('failed.pt').to(device)
    y = model(data.x, data.edge_index)

The outputs of the failed case on the CPU are:

torch.Size([2, 271284])
tensor(0) tensor(45897) 45898
torch.Size([2, 863939])
tensor(0) tensor(45897) 45898
----------
torch.Size([2, 397857])
tensor(0) tensor(22948) 22949
torch.Size([2, 1402064])
tensor(0) tensor(22948) 22949
----------
torch.Size([2, 656276])
tensor(0) tensor(11474) 11475
torch.Size([2, 2327093])
tensor(0) tensor(11474) 11475
----------

On CUDA, the code raises an error:

torch.Size([2, 271284])
tensor(0, device='cuda:0') tensor(45897, device='cuda:0') 45898
Traceback (most recent call last):

@rusty1s
Member

rusty1s commented Sep 12, 2021

This is weird, as your outputs do not violate the assert col.max().item() < sparse_sizes[1] assertion. Can you find out why this is the case nonetheless, e.g., by debugging the __init__ call in torch_sparse/storage.py?
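The assertion that fires is one of several invariants a CSR triple has to satisfy. As a debugging aid, here is a minimal pure-Python sketch (check_csr is a hypothetical helper) that reports every violated invariant instead of stopping at the first one. Tensors can be passed after conversion with .cpu().tolist().

```python
def check_csr(rowptr, col, num_rows, num_cols):
    # Hypothetical helper: list every violated CSR invariant (empty list if valid).
    problems = []
    if len(rowptr) != num_rows + 1:
        problems.append(f"rowptr has {len(rowptr)} entries, expected {num_rows + 1}")
    if rowptr and rowptr[0] != 0:
        problems.append(f"rowptr[0] is {rowptr[0]}, expected 0")
    if any(a > b for a, b in zip(rowptr, rowptr[1:])):
        problems.append("rowptr is not non-decreasing")
    if rowptr and rowptr[-1] != len(col):
        problems.append(f"rowptr[-1] is {rowptr[-1]}, but col has {len(col)} entries")
    # The upper bound here is the condition behind
    # `assert col.max().item() < sparse_sizes[1]` in storage.py.
    bad = [c for c in col if not 0 <= c < num_cols]
    if bad:
        problems.append(f"col has {len(bad)} value(s) outside [0, {num_cols}), e.g. {bad[0]}")
    return problems


print(check_csr([0, 1, 3], [1, 0, 1], 2, 2))           # []
print(check_csr([0, 1, 3], [1, 0, 1909964387], 2, 2))  # ['col has 1 value(s) outside [0, 2), e.g. 1909964387']
```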

@morgan-bc
Author

The bug occurs in the function spspmm_sum in torch_sparse/matmul.py. I added print statements to spspmm_sum as follows:

def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    assert src.sparse_size(1) == other.sparse_size(0)
    rowptrA, colA, valueA = src.csr()
    rowptrB, colB, valueB = other.csr()
    value = valueA
    if valueA is not None and valueA.dtype == torch.half:
        valueA = valueA.to(torch.float)
    if valueB is not None and valueB.dtype == torch.half:
        valueB = valueB.to(torch.float)
    M, K = src.sparse_size(0), other.sparse_size(1)
    rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
        rowptrA, colA, valueA, rowptrB, colB, valueB, K)
    
    print('--------------')
    print('A', rowptrA.shape, rowptrA.max().item(), colA.max().item(), valueA.max().item())
    print('B', rowptrB.shape, rowptrB.max().item(), colB.max().item(), valueB.max().item())
    print('C', rowptrC.shape, rowptrC.max().item(), colC.max().item(), valueC.max().item())
    print('--------------')

    if valueC is not None and value is not None:
        valueC = valueC.to(value.dtype)
    return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
                        sparse_sizes=(M, K), is_sorted=True)

The CPU outputs are:

A torch.Size([45899]) 317180 45897 1.0
B torch.Size([45899]) 317180 45897 1.0
C torch.Size([45899]) 909837 45897 12.0

The CUDA outputs are:

A torch.Size([45899]) 317180 45897 1.0
B torch.Size([45899]) 317180 45897 1.0
C torch.Size([45899]) 909848 1909964387 12.0

The CUDA outputs under torch.no_grad() are:

A torch.Size([45899]) 317180 45897 1.0
B torch.Size([45899]) 317180 45897 1.0
C torch.Size([45899]) 909848 45897 12.0

colC is computed incorrectly on CUDA: its maximum, 1909964387, is far beyond the number of columns (45898). Apart from this, rowptrC also differs between CPU and CUDA: its last entry, i.e. the number of non-zeros, is 909837 on the CPU but 909848 on CUDA.
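Because the two rowptrC arrays end at different values, the per-row non-zero counts must diverge somewhere. A small sketch to locate the first such row, assuming both arrays are available as plain lists (e.g. rowptrC.cpu().tolist() saved from each run); first_row_mismatch is a hypothetical helper:

```python
def first_row_mismatch(rowptr_cpu, rowptr_cuda):
    # Hypothetical helper: (row, cpu_count, cuda_count) for the first row whose
    # number of stored entries differs, or None if all rows agree.
    if len(rowptr_cpu) != len(rowptr_cuda):
        raise ValueError("rowptr arrays describe different numbers of rows")
    for i in range(len(rowptr_cpu) - 1):
        n_cpu = rowptr_cpu[i + 1] - rowptr_cpu[i]
        n_cuda = rowptr_cuda[i + 1] - rowptr_cuda[i]
        if n_cpu != n_cuda:
            return i, n_cpu, n_cuda
    return None


print(first_row_mismatch([0, 2, 5, 6], [0, 2, 6, 7]))  # (1, 3, 4)
```

Inspecting that row of A and B (and of colC) on both devices narrows the search to a handful of entries.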

@rusty1s
Member

rusty1s commented Sep 15, 2021

Thank you very much! Really helpful. Can you let me know how you tried to install torch-sparse (source, wheels, conda)? What happens if you install it from source? Does it resolve these issues?

@morgan-bc
Author

I installed torch-sparse via pip as described in the PyG documentation, and I also tried installing it from source. The problem occurs either way.

@amirbarda

Hi, I have exactly the same issue with CUDA 10.2, PyTorch 1.9.0 and Python 3.7.

With which combination of CUDA/PyTorch/Python does this bug not occur?

@rusty1s
Copy link
Member

rusty1s commented Dec 11, 2021

As far as I know, this may be dependent on the GPU (see rusty1s/pytorch_sparse#174), but I'm not entirely sure :(

@amirbarda

Thanks!
I have an RTX 2080Ti. rusty1s/pytorch_sparse#174 lists a similar (though not identical) card as one on which the bug does not show up.
I will try running on another card to check.

@yoterel

yoterel commented Dec 28, 2021

I also encounter this exact bug with the provided snippet, as well as with the vanilla Graph U-Net from the repository.
I suspected it might be caused by duplicate edges in the edge_index passed to forward, but that is not the case.

Reproduced with RTX 3080, cuda 11.3, pytorch 1.10.1, python 3.8
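The duplicate-edge hypothesis can be checked by counting repeated (row, col) pairs. A minimal sketch on plain lists; count_duplicate_edges is a hypothetical helper, and for a tensor the inputs would be edge_index[0].tolist() and edge_index[1].tolist():

```python
from collections import Counter


def count_duplicate_edges(row, col):
    # Hypothetical helper: number of surplus copies of repeated (row, col) pairs.
    counts = Counter(zip(row, col))
    return sum(n - 1 for n in counts.values() if n > 1)


# (0, 1) appears three times, so there are two surplus copies.
print(count_duplicate_edges([0, 1, 0, 0], [1, 0, 1, 1]))  # 2
```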

@rusty1s
Copy link
Member

rusty1s commented Dec 29, 2021

If you have data + code snippet to reproduce, please let me know :)

@kyu823

kyu823 commented Feb 13, 2022

I have the exact same issue with a Titan RTX and an RTX 3090. Is there any way to solve it?

@rusty1s
Member

rusty1s commented Feb 14, 2022

@wrccrwx @kimkyusik It's really a bummer that I cannot reproduce this issue. I'm really sorry. I basically followed the instructions from the cuSPARSE documentation when implementing our CUDA routine in spspmm_cuda.cu:

// assume matrices A, B and D are ready.
int baseC, nnzC;
csrgemm2Info_t info = NULL;
size_t bufferSize;
void *buffer = NULL;
// nnzTotalDevHostPtr points to host memory
int *nnzTotalDevHostPtr = &nnzC;
double alpha = -1.0;
double beta  =  1.0;
cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST);

// step 1: create an opaque structure
cusparseCreateCsrgemm2Info(&info);

// step 2: allocate buffer for csrgemm2Nnz and csrgemm2
cusparseDcsrgemm2_bufferSizeExt(handle, m, n, k, &alpha,
    descrA, nnzA, csrRowPtrA, csrColIndA,
    descrB, nnzB, csrRowPtrB, csrColIndB,
    &beta,
    descrD, nnzD, csrRowPtrD, csrColIndD,
    info,
    &bufferSize);
cudaMalloc(&buffer, bufferSize);

// step 3: compute csrRowPtrC
cudaMalloc((void**)&csrRowPtrC, sizeof(int)*(m+1));
cusparseXcsrgemm2Nnz(handle, m, n, k,
        descrA, nnzA, csrRowPtrA, csrColIndA,
        descrB, nnzB, csrRowPtrB, csrColIndB,
        descrD, nnzD, csrRowPtrD, csrColIndD,
        descrC, csrRowPtrC, nnzTotalDevHostPtr,
        info, buffer );
if (NULL != nnzTotalDevHostPtr){
    nnzC = *nnzTotalDevHostPtr;
}else{
    cudaMemcpy(&nnzC, csrRowPtrC+m, sizeof(int), cudaMemcpyDeviceToHost);
    cudaMemcpy(&baseC, csrRowPtrC, sizeof(int), cudaMemcpyDeviceToHost);
    nnzC -= baseC;
}

// step 4: finish sparsity pattern and value of C
cudaMalloc((void**)&csrColIndC, sizeof(int)*nnzC);
cudaMalloc((void**)&csrValC, sizeof(double)*nnzC);
// Remark: set csrValC to null if only sparsity pattern is required.
cusparseDcsrgemm2(handle, m, n, k, &alpha,
        descrA, nnzA, csrValA, csrRowPtrA, csrColIndA,
        descrB, nnzB, csrValB, csrRowPtrB, csrColIndB,
        &beta,
        descrD, nnzD, csrValD, csrRowPtrD, csrColIndD,
        descrC, csrValC, csrRowPtrC, csrColIndC,
        info, buffer);

// step 5: destroy the opaque structure
cusparseDestroyCsrgemm2Info(info);

Any chance you can debug where our routine crashes by installing torch-sparse from source?
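The recipe above splits the product into a symbolic phase (step 3, cusparseXcsrgemm2Nnz, which determines csrRowPtrC and nnzC) and a numeric phase (step 4, cusparseDcsrgemm2, which writes csrColIndC and csrValC into buffers sized from nnzC). The pure-Python sketch below mirrors that two-phase structure for the plain product C = A·B only, without the alpha scaling and the D term. It is an illustration, not torch_sparse's kernel, and csr_gemm_two_phase is a hypothetical name. Since freshly allocated device memory (cudaMalloc in the recipe) is not initialized, the sketch pre-fills col_c with a sentinel so that any slot the numeric phase fails to write stays visible.

```python
def csr_gemm_two_phase(rowptr_a, col_a, val_a, rowptr_b, col_b, val_b, sentinel=-1):
    # Hypothetical helper: C = A @ B in CSR, split into symbolic and numeric phases.
    m = len(rowptr_a) - 1

    def row_pattern(i):
        # Distinct columns j reachable as A[i, k] * B[k, j].
        cols = set()
        for p in range(rowptr_a[i], rowptr_a[i + 1]):
            k = col_a[p]
            cols.update(col_b[rowptr_b[k]:rowptr_b[k + 1]])
        return cols

    # Phase 1 (cf. csrgemm2Nnz): count entries per row and prefix-sum into rowptr_c.
    rowptr_c = [0]
    for i in range(m):
        rowptr_c.append(rowptr_c[-1] + len(row_pattern(i)))
    nnz_c = rowptr_c[-1]

    # Allocate outputs from nnz_c; the sentinel marks slots phase 2 never writes.
    col_c = [sentinel] * nnz_c
    val_c = [0.0] * nnz_c

    # Phase 2 (cf. csrgemm2): accumulate row i of C and write it into its slots.
    for i in range(m):
        acc = {}
        for p in range(rowptr_a[i], rowptr_a[i + 1]):
            k, a_ik = col_a[p], val_a[p]
            for q in range(rowptr_b[k], rowptr_b[k + 1]):
                acc[col_b[q]] = acc.get(col_b[q], 0.0) + a_ik * val_b[q]
        for offset, j in enumerate(sorted(acc)):
            col_c[rowptr_c[i] + offset] = j
            val_c[rowptr_c[i] + offset] = acc[j]
    return rowptr_c, col_c, val_c


# A = [[1, 1, 0], [1, 1, 1], [0, 1, 1]] in CSR; compute A @ A.
A = ([0, 2, 5, 7], [0, 1, 0, 1, 2, 1, 2], [1.0] * 7)
rowptr_c, col_c, val_c = csr_gemm_two_phase(*A, *A)
print(rowptr_c)      # [0, 3, 6, 9]
print(-1 in col_c)   # False
```

On the real output, the analogous checks are comparing nnzC with the CPU count and scanning colC for out-of-range values.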
