Skip to content

Commit

Permalink
[OpenMP] Update atomic helpers to just use headers (llvm#122185)
Browse files Browse the repository at this point in the history
Summary:
Previously we had some indirection here, this patch updates these
utilities to just be normal template functions. We use SFINAE to manage
the special case handling for floats. Also this strips address spaces so
it can be used more generally.
  • Loading branch information
jhuber6 authored and DKLoehr committed Jan 17, 2025
1 parent 28bb265 commit c03d106
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 217 deletions.
2 changes: 1 addition & 1 deletion offload/DeviceRTL/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ set(bc_flags -c -foffload-lto -std=c++17 -fvisibility=hidden
${clang_opt_flags} --offload-device-only
-nocudalib -nogpulib -nogpuinc -nostdlibinc
-fopenmp -fopenmp-cuda-mode
-Wno-unknown-cuda-version
-Wno-unknown-cuda-version -Wno-openmp-target
-DOMPTARGET_DEVICE_RUNTIME
-I${include_directory}
-I${devicertl_base_directory}/../include
Expand Down
41 changes: 41 additions & 0 deletions offload/DeviceRTL/include/DeviceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,47 @@

namespace utils {

template <typename T> struct type_identity {
using type = T;
};

template <typename T, T v> struct integral_constant {
inline static constexpr T value = v;
};

/// Freestanding SFINAE helpers.
template <class T> struct remove_cv : type_identity<T> {};
template <class T> struct remove_cv<const T> : type_identity<T> {};
template <class T> struct remove_cv<volatile T> : type_identity<T> {};
template <class T> struct remove_cv<const volatile T> : type_identity<T> {};
template <class T> using remove_cv_t = typename remove_cv<T>::type;

using true_type = integral_constant<bool, true>;
using false_type = integral_constant<bool, false>;

template <typename T, typename U> struct is_same : false_type {};
template <typename T> struct is_same<T, T> : true_type {};
template <typename T, typename U>
inline constexpr bool is_same_v = is_same<T, U>::value;

template <typename T> struct is_floating_point {
inline static constexpr bool value =
is_same_v<remove_cv<T>, float> || is_same_v<remove_cv<T>, double>;
};
template <typename T>
inline constexpr bool is_floating_point_v = is_floating_point<T>::value;

template <bool B, typename T = void> struct enable_if;
template <typename T> struct enable_if<true, T> : type_identity<T> {};
template <bool B, typename T = void>
using enable_if_t = typename enable_if<B, T>::type;

template <class T> struct remove_addrspace : type_identity<T> {};
template <class T, int N>
struct remove_addrspace<T [[clang::address_space(N)]]> : type_identity<T> {};
template <class T>
using remove_addrspace_t = typename remove_addrspace<T>::type;

/// Return the value \p Var from thread Id \p SrcLane in the warp if the thread
/// is identified by \p Mask.
int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane, int32_t Width);
Expand Down
169 changes: 123 additions & 46 deletions offload/DeviceRTL/include/Synchronization.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
#define OMPTARGET_DEVICERTL_SYNCHRONIZATION_H

#include "DeviceTypes.h"
#include "DeviceUtils.h"

namespace ompx {
#pragma omp begin declare target device_type(nohost)

namespace ompx {
namespace atomic {

enum OrderingTy {
Expand Down Expand Up @@ -48,51 +50,124 @@ uint32_t inc(uint32_t *Addr, uint32_t V, OrderingTy Ordering,
/// result is stored in \p *Addr;
/// {

#define ATOMIC_COMMON_OP(TY) \
TY add(TY *Addr, TY V, OrderingTy Ordering); \
TY mul(TY *Addr, TY V, OrderingTy Ordering); \
TY load(TY *Addr, OrderingTy Ordering); \
void store(TY *Addr, TY V, OrderingTy Ordering); \
bool cas(TY *Addr, TY ExpectedV, TY DesiredV, OrderingTy OrderingSucc, \
OrderingTy OrderingFail);

#define ATOMIC_FP_ONLY_OP(TY) \
TY min(TY *Addr, TY V, OrderingTy Ordering); \
TY max(TY *Addr, TY V, OrderingTy Ordering);

#define ATOMIC_INT_ONLY_OP(TY) \
TY min(TY *Addr, TY V, OrderingTy Ordering); \
TY max(TY *Addr, TY V, OrderingTy Ordering); \
TY bit_or(TY *Addr, TY V, OrderingTy Ordering); \
TY bit_and(TY *Addr, TY V, OrderingTy Ordering); \
TY bit_xor(TY *Addr, TY V, OrderingTy Ordering);

#define ATOMIC_FP_OP(TY) \
ATOMIC_FP_ONLY_OP(TY) \
ATOMIC_COMMON_OP(TY)

#define ATOMIC_INT_OP(TY) \
ATOMIC_INT_ONLY_OP(TY) \
ATOMIC_COMMON_OP(TY)

// This needs to be kept in sync with the header. Also the reason we don't use
// templates here.
ATOMIC_INT_OP(int8_t)
ATOMIC_INT_OP(int16_t)
ATOMIC_INT_OP(int32_t)
ATOMIC_INT_OP(int64_t)
ATOMIC_INT_OP(uint8_t)
ATOMIC_INT_OP(uint16_t)
ATOMIC_INT_OP(uint32_t)
ATOMIC_INT_OP(uint64_t)
ATOMIC_FP_OP(float)
ATOMIC_FP_OP(double)

#undef ATOMIC_INT_ONLY_OP
#undef ATOMIC_FP_ONLY_OP
#undef ATOMIC_COMMON_OP
#undef ATOMIC_INT_OP
#undef ATOMIC_FP_OP
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
bool cas(Ty *Address, V ExpectedV, V DesiredV, atomic::OrderingTy OrderingSucc,
atomic::OrderingTy OrderingFail) {
return __scoped_atomic_compare_exchange(Address, &ExpectedV, &DesiredV, false,
OrderingSucc, OrderingFail,
__MEMORY_SCOPE_DEVICE);
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
V add(Ty *Address, V Val, atomic::OrderingTy Ordering) {
return __scoped_atomic_fetch_add(Address, Val, Ordering,
__MEMORY_SCOPE_DEVICE);
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
V load(Ty *Address, atomic::OrderingTy Ordering) {
return add(Address, Ty(0), Ordering);
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
void store(Ty *Address, V Val, atomic::OrderingTy Ordering) {
__scoped_atomic_store_n(Address, Val, Ordering, __MEMORY_SCOPE_DEVICE);
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
V mul(Ty *Address, V Val, atomic::OrderingTy Ordering) {
Ty TypedCurrentVal, TypedResultVal, TypedNewVal;
bool Success;
do {
TypedCurrentVal = atomic::load(Address, Ordering);
TypedNewVal = TypedCurrentVal * Val;
Success = atomic::cas(Address, TypedCurrentVal, TypedNewVal, Ordering,
atomic::relaxed);
} while (!Success);
return TypedResultVal;
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<!utils::is_floating_point_v<V>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
return __scoped_atomic_fetch_max(Address, Val, Ordering,
__MEMORY_SCOPE_DEVICE);
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, float>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
if (Val >= 0)
return utils::convertViaPun<float>(
max((int32_t *)Address, utils::convertViaPun<int32_t>(Val), Ordering));
return utils::convertViaPun<float>(
min((uint32_t *)Address, utils::convertViaPun<uint32_t>(Val), Ordering));
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, double>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
if (Val >= 0)
return utils::convertViaPun<double>(
max((int64_t *)Address, utils::convertViaPun<int64_t>(Val), Ordering));
return utils::convertViaPun<double>(
min((uint64_t *)Address, utils::convertViaPun<uint64_t>(Val), Ordering));
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<!utils::is_floating_point_v<V>, V>
min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
return __scoped_atomic_fetch_min(Address, Val, Ordering,
__MEMORY_SCOPE_DEVICE);
}

// TODO: Implement this with __atomic_fetch_max and remove the duplication.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, float>, V>
min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
if (Val >= 0)
return utils::convertViaPun<float>(
min((int32_t *)Address, utils::convertViaPun<int32_t>(Val), Ordering));
return utils::convertViaPun<float>(
max((uint32_t *)Address, utils::convertViaPun<uint32_t>(Val), Ordering));
}

// TODO: Implement this with __atomic_fetch_max and remove the duplication.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, double>, V>
min(Ty *Address, utils::remove_addrspace_t<Ty> Val,
atomic::OrderingTy Ordering) {
if (Val >= 0)
return utils::convertViaPun<double>(
min((int64_t *)Address, utils::convertViaPun<int64_t>(Val), Ordering));
return utils::convertViaPun<double>(
max((uint64_t *)Address, utils::convertViaPun<uint64_t>(Val), Ordering));
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
V bit_or(Ty *Address, V Val, atomic::OrderingTy Ordering) {
return __scoped_atomic_fetch_or(Address, Val, Ordering,
__MEMORY_SCOPE_DEVICE);
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
V bit_and(Ty *Address, V Val, atomic::OrderingTy Ordering) {
return __scoped_atomic_fetch_and(Address, Val, Ordering,
__MEMORY_SCOPE_DEVICE);
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
V bit_xor(Ty *Address, V Val, atomic::OrderingTy Ordering) {
return __scoped_atomic_fetch_xor(Address, Val, Ordering,
__MEMORY_SCOPE_DEVICE);
}

static inline uint32_t atomicExchange(uint32_t *Address, uint32_t Val,
atomic::OrderingTy Ordering) {
uint32_t R;
__scoped_atomic_exchange(Address, &Val, &R, Ordering, __MEMORY_SCOPE_DEVICE);
return R;
}

///}

Expand Down Expand Up @@ -145,4 +220,6 @@ void system(atomic::OrderingTy Ordering);

} // namespace ompx

#pragma omp end declare target

#endif
Loading

0 comments on commit c03d106

Please sign in to comment.