diff --git a/csrc/cuda/convert_cuda.cu b/csrc/cuda/convert_cuda.cu index 30f7d273..f3c57050 100644 --- a/csrc/cuda/convert_cuda.cu +++ b/csrc/cuda/convert_cuda.cu @@ -25,7 +25,7 @@ __global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data, torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) { CHECK_CUDA(ind); - cudaSetDevice(ind.get_device()); + c10::cuda::MaybeSetDevice(ind.get_device()); auto out = torch::empty({M + 1}, ind.options()); @@ -55,7 +55,7 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data, torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) { CHECK_CUDA(ptr); - cudaSetDevice(ptr.get_device()); + c10::cuda::MaybeSetDevice(ptr.get_device()); auto out = torch::empty({E}, ptr.options()); auto ptr_data = ptr.data_ptr(); diff --git a/csrc/cuda/diag_cuda.cu b/csrc/cuda/diag_cuda.cu index 7b608b44..23a01fe0 100644 --- a/csrc/cuda/diag_cuda.cu +++ b/csrc/cuda/diag_cuda.cu @@ -43,7 +43,7 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col, int64_t M, int64_t N, int64_t k) { CHECK_CUDA(row); CHECK_CUDA(col); - cudaSetDevice(row.get_device()); + c10::cuda::MaybeSetDevice(row.get_device()); auto E = row.size(0); auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k); diff --git a/csrc/cuda/rw_cuda.cu b/csrc/cuda/rw_cuda.cu index e7e821f7..24254446 100644 --- a/csrc/cuda/rw_cuda.cu +++ b/csrc/cuda/rw_cuda.cu @@ -33,7 +33,7 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, CHECK_CUDA(rowptr); CHECK_CUDA(col); CHECK_CUDA(start); - cudaSetDevice(rowptr.get_device()); + c10::cuda::MaybeSetDevice(rowptr.get_device()); CHECK_INPUT(rowptr.dim() == 1); CHECK_INPUT(col.dim() == 1); diff --git a/csrc/cuda/spmm_cuda.cu b/csrc/cuda/spmm_cuda.cu index c58e8f84..f67b560c 100644 --- a/csrc/cuda/spmm_cuda.cu +++ b/csrc/cuda/spmm_cuda.cu @@ -99,7 +99,7 @@ spmm_cuda(torch::Tensor rowptr, torch::Tensor col, if (optional_value.has_value()) CHECK_CUDA(optional_value.value()); CHECK_CUDA(mat); - cudaSetDevice(rowptr.get_device()); + c10::cuda::MaybeSetDevice(rowptr.get_device()); CHECK_INPUT(rowptr.dim() == 1); CHECK_INPUT(col.dim() == 1); @@ -201,7 +201,7 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr, CHECK_CUDA(col); CHECK_CUDA(mat); CHECK_CUDA(grad); - cudaSetDevice(row.get_device()); + c10::cuda::MaybeSetDevice(row.get_device()); mat = mat.contiguous(); grad = grad.contiguous();