From 8ed0867d8281930b203e5efbbfa7c98bcc1191e3 Mon Sep 17 00:00:00 2001 From: "Aji, Ashwin" Date: Mon, 22 Jan 2024 18:23:00 -0500 Subject: [PATCH] [ROCm] fixes ambiguous calls to `shfl*` where there is no explicit type conversion from `c10::Half` to `__half` --- csrc/cuda/utils.cuh | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/csrc/cuda/utils.cuh b/csrc/cuda/utils.cuh index ba4f3a11..747a8e2c 100644 --- a/csrc/cuda/utils.cuh +++ b/csrc/cuda/utils.cuh @@ -6,9 +6,10 @@ AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") -__device__ __inline__ at::Half -__shfl_sync(const unsigned mask, const at::Half var, const int srcLane) { - return __shfl_sync(mask, var.operator __half(), srcLane); +__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask, + const at::Half var, + const unsigned int delta) { + return __shfl_up_sync(mask, var.operator __half(), delta); } __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, @@ -17,6 +18,27 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, return __shfl_down_sync(mask, var.operator __half(), delta); } +__device__ __inline__ at::Half __shfl_sync(const unsigned mask, + const at::Half var, + const int delta) { + return __shfl_sync(mask, var.operator __half(), delta); +} + +__device__ __inline__ at::Half __shfl_up(const at::Half var, + const unsigned int delta) { + return __shfl_up(var.operator __half(), delta); +} + +__device__ __inline__ at::Half __shfl_down(const at::Half var, + const unsigned int delta) { + return __shfl_down(var.operator __half(), delta); +} + +__device__ __inline__ at::Half +__shfl(const at::Half var, const int delta) { + return __shfl(var.operator __half(), delta); +} + #ifdef USE_ROCM __device__ __inline__ at::Half __ldg(const at::Half* ptr) { return __ldg(reinterpret_cast(ptr));