From bf2591a7765edda20f0f18249804e371e3884ac5 Mon Sep 17 00:00:00 2001 From: Ashwin Aji Date: Mon, 22 Jan 2024 23:24:23 -0800 Subject: [PATCH] [ROCm] fixes ambiguous calls to `shfl*` where there is no explicit type conversion from `c10::Half` to `__half` (#360) [ROCm] fixes ambiguous calls to `shfl*` where there is no explicit type conversion from `c10::Half` to `__half` --- csrc/cuda/utils.cuh | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/csrc/cuda/utils.cuh b/csrc/cuda/utils.cuh index ba4f3a11..747a8e2c 100644 --- a/csrc/cuda/utils.cuh +++ b/csrc/cuda/utils.cuh @@ -6,9 +6,10 @@ AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") -__device__ __inline__ at::Half -__shfl_sync(const unsigned mask, const at::Half var, const int srcLane) { - return __shfl_sync(mask, var.operator __half(), srcLane); +__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask, + const at::Half var, + const unsigned int delta) { + return __shfl_up_sync(mask, var.operator __half(), delta); } __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, @@ -17,6 +18,27 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, return __shfl_down_sync(mask, var.operator __half(), delta); } +__device__ __inline__ at::Half __shfl_sync(const unsigned mask, + const at::Half var, + const int delta) { + return __shfl_sync(mask, var.operator __half(), delta); +} + +__device__ __inline__ at::Half __shfl_up(const at::Half var, + const unsigned int delta) { + return __shfl_up(var.operator __half(), delta); +} + +__device__ __inline__ at::Half __shfl_down(const at::Half var, + const unsigned int delta) { + return __shfl_down(var.operator __half(), delta); +} + +__device__ __inline__ at::Half +__shfl(const at::Half var, const int delta) { + return __shfl(var.operator __half(), delta); +} + #ifdef USE_ROCM __device__ __inline__ at::Half __ldg(const at::Half* ptr) { return __ldg(reinterpret_cast(ptr));