diff --git a/csrc/common.h b/csrc/common.h index b56fc14..2b2b1b6 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -5,6 +5,32 @@ #include #include +class OptionalCUDAGuard { + int set_device_ = -1; + int current_device_ = -1; + + public: + OptionalCUDAGuard(int device) : set_device_(device) { + cudaError_t err = cudaGetDevice(¤t_device_); + std::stringstream ss; + if (err != cudaSuccess) { + ss << "cudaGetDevice failed with error code " << cudaGetErrorString(err); + TORCH_CHECK(err == cudaSuccess, ss.str()); + } + if (current_device_ == device) { + return; + } + err = cudaSetDevice(device); + if (err != cudaSuccess) { + ss << "cudaGetDevice failed with error code " << cudaGetErrorString(err); + TORCH_CHECK(err == cudaSuccess, ss.str()); + } + } + ~OptionalCUDAGuard() { + if (set_device_ != current_device_) cudaSetDevice(current_device_); + } +}; + #define gpuErrchk(ret) gpuAssert((ret), __FILE__, __LINE__); inline void gpuAssert(cudaError_t code, const char* file, int line) { diff --git a/csrc/dequant_impl_packed.cu b/csrc/dequant_impl_packed.cu index 99dc573..be67ed3 100644 --- a/csrc/dequant_impl_packed.cu +++ b/csrc/dequant_impl_packed.cu @@ -298,6 +298,7 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel( const c10::optional& outliers_indices, //[num_cen, c_size, ol_in_f] const c10::optional& outliers_centroids, //[num_c, c_size, out_vec_len] const c10::optional& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias) { + OptionalCUDAGuard cudaguard(q_indice.device().index()); int base_groupsize = centroids.size(-1); // how many elements in a vector int res_groupsize = residual_centroids.has_value() ? residual_centroids.value().size(-1) : 0; // TORCH_CHECK((res_groupsize===base_groupsize||res_groupsize==0), "res_groupsize===base_groupsize is false, must be @@ -443,6 +444,7 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel( const c10::optional& outliers_centroids, //[num_c, c_size, out_vec_len] const c10::optional& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias, const c10::optional& bias) { + OptionalCUDAGuard cudaguard(input.device().index()); const int base_groupsize = centroids.size(-1); int index_bits = log2(centroids.size(1)); int res_index_bits = residual_centroids.has_value() ? log2(residual_centroids.value().size(1)) : 0; diff --git a/csrc/utils.cuh b/csrc/utils.cuh index f66154b..6409cae 100644 --- a/csrc/utils.cuh +++ b/csrc/utils.cuh @@ -122,11 +122,12 @@ __device__ __forceinline__ uint32_t iterator_packed_tensor(const uint32_t* ptr, int second = end_bits / 32; start_bits = start_bits % 32; end_bits = end_bits % 32; - uint32_t sec_v = ptr[second]; uint32_t v = (ptr[first] >> (start_bits)) & ((1 << WBITS) - 1); - if (first == second) { + if (first == second || end_bits == 0) { return v; } else { + // second position might be out of bound + uint32_t sec_v = ptr[second]; v |= ((sec_v) & ((1 << (end_bits)) - 1)) << (32 - start_bits); return v; }