
Commit 4e4a9fd
Revert "add mul min max mean reduce"
This reverts commit 67933a4.

revert
YibinLiu666 committed Dec 2, 2023
1 parent 7beb8c7 commit 4e4a9fd
Showing 10 changed files with 75 additions and 1,586 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/backward.yaml
@@ -1758,7 +1758,7 @@

- backward_op : put_along_axis_grad
forward : put_along_axis (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign") -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor values, Tensor out, Tensor out_grad, int axis, str reduce)
args : (Tensor arr, Tensor indices, Tensor out_grad, int axis, str reduce)
output : Tensor(arr_grad), Tensor(values_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
175 changes: 0 additions & 175 deletions paddle/phi/backends/gpu/gpu_primitives.h
@@ -395,181 +395,6 @@ CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
CudaAtomicAdd(imag, val.imag));
}

// For atomicMul.
CUDA_ATOMIC_WRAPPER(Mul, int) {
int res = *address, old = res; // NOLINT
do {
old = res;
res = atomicCAS(address, // NOLINT
old, // NOLINT
val * old); // NOLINT
} while (old != res);
return res;
}
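
This wrapper, like the integer variants that follow it, implements atomic multiplication with a compare-and-swap retry loop: read the current word, compute old * val, and try to publish it with atomicCAS; if another thread changed the word in the meantime, the CAS reports a different value and the loop retries with the fresh contents. Below is a host-side analogue of that pattern written against std::atomic rather than the CUDA intrinsics; the atomic_fetch_mul name is made up for this sketch and is not part of the file.

#include <atomic>
#include <cstdio>

// Host-side sketch of the CAS-retry pattern used by the wrappers above:
// multiply an atomic integer by val and return the value stored before
// the update (hypothetical helper, illustration only).
int atomic_fetch_mul(std::atomic<int> &address, int val) {
  int old = address.load();
  // compare_exchange_weak stores old * val only if address still holds old;
  // on failure it reloads old with the current value and the loop retries.
  while (!address.compare_exchange_weak(old, old * val)) {
  }
  return old;
}

int main() {
  std::atomic<int> x{3};
  int prev = atomic_fetch_mul(x, 4);
  std::printf("previous=%d current=%d\n", prev, x.load());  // previous=3 current=12
}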

CUDA_ATOMIC_WRAPPER(Mul, unsigned int) {
unsigned int res = *address, old = res; // NOLINT
do {
old = res;
res = atomicCAS(address, // NOLINT
old, // NOLINT
val * old); // NOLINT
} while (old != res);
return res;
}
// The CUDA API uses unsigned long long int; we cannot use uint64_t here
// because unsigned long long int is not necessarily uint64_t.
CUDA_ATOMIC_WRAPPER(Mul, unsigned long long int) { // NOLINT
unsigned long long int old = *address, assumed; // NOLINT

do {
assumed = old;
old = atomicCAS(address, assumed, val * assumed);
} while (assumed != old);
return old;
}

CUDA_ATOMIC_WRAPPER(Mul, int64_t) {
// Here we check that long long int is indeed int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
long long int res = *address, old = res; // NOLINT
do {
old = res;
res = (long long int)atomicCAS( // NOLINT
(unsigned long long int *)address, // NOLINT
(unsigned long long int)old, // NOLINT
(unsigned long long int)val * (unsigned long long int)old); // NOLINT
} while (old != res);
return res;
}

CUDA_ATOMIC_WRAPPER(Mul, float) {
int *const address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;

do {
assumed = old;
old = atomicCAS(
address_as_i, assumed, __float_as_int(val * __int_as_float(assumed)));
} while (assumed != old);

return __int_as_float(old);
}

CUDA_ATOMIC_WRAPPER(Mul, double) {
unsigned long long int *const address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT

do {
assumed = old;

old = atomicCAS(address_as_ull,
assumed,
__double_as_longlong(val * __longlong_as_double(assumed)));
} while (assumed != old);

return __longlong_as_double(old);
}
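
atomicCAS only operates on integer words, so the float and double wrappers reinterpret the operand's bits as an integer of the same width, run the CAS loop on that integer, and decode the result — this is what __float_as_int / __int_as_float and __double_as_longlong / __longlong_as_double do on the device. Below is a small host-side sketch of that round trip, using memcpy in place of the device intrinsics; it is an illustration, not part of this file.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Host-side equivalents of __float_as_int / __int_as_float: copy the raw
// bits between a float and a 32-bit integer without converting the value.
static std::uint32_t float_as_uint(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return u;
}

static float uint_as_float(std::uint32_t u) {
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  float stored = 2.5f;
  float val = 3.0f;
  // The CAS loop compares and swaps these integer bit patterns, but the
  // multiplication itself is done on the decoded float values.
  std::uint32_t assumed = float_as_uint(stored);
  std::uint32_t desired = float_as_uint(val * uint_as_float(assumed));
  std::printf("%g\n", uint_as_float(desired));  // 7.5
}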

#ifdef PADDLE_CUDA_FP16
inline static __device__ uint32_t mul_to_low_half(uint32_t val, float x) {
phi::dtype::float16 low_half;
// The float16 stored in the lower 16 bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<phi::dtype::float16>(static_cast<float>(low_half) * x);
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t mul_to_high_half(uint32_t val, float x) {
phi::dtype::float16 high_half;
// The float16 stored in the upper 16 bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half =
static_cast<phi::dtype::float16>(static_cast<float>(high_half) * x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
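
mul_to_low_half and mul_to_high_half update a single float16 lane of a packed 32-bit word: the target half is extracted, widened to float, multiplied, narrowed back, and spliced into the word while the mask preserves the other half. The splice itself is ordinary integer masking, sketched below with raw 16-bit payloads instead of real float16 values; the helper names are made up for this illustration.

#include <cstdint>
#include <cstdio>

// Replace one 16-bit half of a packed 32-bit word while preserving the
// other half -- the same splice the helpers above perform after they have
// recomputed the float16 payload.
static std::uint32_t splice_low_half(std::uint32_t word, std::uint16_t lo) {
  return (word & 0xFFFF0000u) | lo;
}

static std::uint32_t splice_high_half(std::uint32_t word, std::uint16_t hi) {
  return (word & 0x0000FFFFu) | (static_cast<std::uint32_t>(hi) << 16);
}

int main() {
  std::uint32_t word = 0xAAAABBBBu;
  std::printf("%08X\n", static_cast<unsigned>(splice_low_half(word, 0x1234)));   // AAAA1234
  std::printf("%08X\n", static_cast<unsigned>(splice_high_half(word, 0x1234)));  // 1234BBBB
}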

CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::float16) {
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The float16 value stays in the lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, mul_to_low_half(assumed, val_f));
} while (old != assumed);
phi::dtype::float16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The float16 value stays in the upper 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, mul_to_high_half(assumed, val_f));
} while (old != assumed);
phi::dtype::float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif
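
atomicCAS needs a naturally aligned 32-bit word, so the float16 wrapper above (and the bfloat16 one below) clears bit 1 of the element's address to find the containing aligned word; the same bit says whether the element occupies the low or the high half, assuming the little-endian layout the device code relies on. Below is a standalone sketch of that address arithmetic with a hypothetical helper operating on host pointers instead of device memory.

#include <cstdint>
#include <cstdio>

// Map the address of a 2-byte element to its enclosing 4-byte-aligned word
// and report which half of that word it occupies (little-endian layout).
static std::uint32_t *containing_word(std::uint16_t *elem, bool *is_high_half) {
  std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(elem);
  *is_high_half = (addr & 0x02) != 0;  // offset 2 within the word -> high half
  return reinterpret_cast<std::uint32_t *>(addr - (addr & 0x02));
}

int main() {
  alignas(4) std::uint16_t buf[4] = {0, 1, 2, 3};
  for (int i = 0; i < 4; ++i) {
    bool high = false;
    std::uint32_t *word = containing_word(&buf[i], &high);
    // Elements 0 and 2 land in the low halves, 1 and 3 in the high halves.
    std::printf("elem %d -> word %p, %s half\n", i,
                static_cast<void *>(word), high ? "high" : "low");
  }
}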

inline static __device__ uint32_t bf16_mul_to_low_half(uint32_t val, float x) {
phi::dtype::bfloat16 low_half;
// The bfloat16 stored in the lower 16 bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half =
static_cast<phi::dtype::bfloat16>(static_cast<float>(low_half) * x);
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t bf16_mul_to_high_half(uint32_t val, float x) {
phi::dtype::bfloat16 high_half;
// The bfloat16 stored in the upper 16 bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half =
static_cast<phi::dtype::bfloat16>(static_cast<float>(high_half) * x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::bfloat16) {
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The bfloat16 value stays in the lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_mul_to_low_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The bfloat16 value stays in the upper 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_mul_to_high_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old >> 16;
return ret;
}
}

// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
89 changes: 21 additions & 68 deletions paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
@@ -25,13 +25,11 @@ namespace phi {

template <typename T, typename Context>
void PutAlongAxisGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& x UNUSED,
const DenseTensor& index,
const DenseTensor& value,
const DenseTensor& out,
const DenseTensor& out_grad,
int axis,
const std::string& reduce,
const std::string& reduce UNUSED,
DenseTensor* x_grad,
DenseTensor* value_grad) {
PADDLE_ENFORCE_EQ(
@@ -42,76 +40,31 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
const auto& index_type = index.dtype();
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (reduce == "assign") {
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_input_grad_kernel<T, int32_t>(
// Here we pass an unused argument, out_grad, because it is
// convenient to instantiate this group of template functions
// with the same argument list.
out_grad,
axis,
index,
*x_grad,
dev_ctx);
} else {
phi::funcs::cpu_scatter_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx);
}
} else if (reduce == "multiply" || reduce == "mul" || reduce == "amin" ||
reduce == "amax") {
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_mul_min_max_input_grad_kernel<T, int32_t>(
out_grad, axis, index, out, x, value, *x_grad, reduce, dev_ctx);
} else {
phi::funcs::cpu_scatter_mul_min_max_input_grad_kernel<T, int64_t>(
out_grad, axis, index, out, x, value, *x_grad, reduce, dev_ctx);
}
} else if (reduce == "mean") {
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_mean_input_grad_kernel<T, int32_t>(
// Here we pass an unused argument, out_grad, because it is
// convenient to instantiate this group of template functions
// with the same argument list.
out_grad,
axis,
index,
*x_grad,
dev_ctx);
} else {
phi::funcs::cpu_scatter_mean_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx);
}
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_input_grad_kernel<T, int32_t>(
// Here we pass an unused argument, out_grad, because it is
// convenient to instantiate this group of template functions
// with the same argument list.
out_grad,
axis,
index,
*x_grad,
dev_ctx);
} else {
phi::funcs::cpu_scatter_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx);
}
}

if (value_grad) {
value_grad->Resize(index.dims());
dev_ctx.template Alloc<T>(value_grad);
if (reduce == "assign") {
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_value_grad_kernel<T, int32_t>(
out_grad, axis, index, *value_grad, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::cpu_scatter_value_grad_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
} else if (reduce == "add" || reduce == "mean") {
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_add_mean_value_grad_kernel<T, int32_t>(
out_grad, axis, index, out, x, value, *value_grad, reduce, dev_ctx);
} else {
phi::funcs::cpu_scatter_add_mean_value_grad_kernel<T, int64_t>(
out_grad, axis, index, out, x, value, *value_grad, reduce, dev_ctx);
}
} else if (reduce == "mul" || reduce == "multiply" || reduce == "amin" ||
reduce == "amax") {
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_mul_min_max_value_grad_kernel<T, int32_t>(
out_grad, axis, index, out, x, value, *value_grad, reduce, dev_ctx);
} else {
phi::funcs::cpu_scatter_mul_min_max_value_grad_kernel<T, int64_t>(
out_grad, axis, index, out, x, value, *value_grad, reduce, dev_ctx);
}
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_value_grad_kernel<T, int32_t>(
out_grad, axis, index, *value_grad, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::cpu_scatter_value_grad_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
}
}
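
After the revert only the assign reduction reaches this kernel, and its gradients reduce to a scatter/gather pair: x_grad is out_grad with the scattered positions cleared (the original x values there never reached the output), while value_grad gathers out_grad at those positions. That is the assumed behaviour of cpu_scatter_input_grad_kernel and cpu_scatter_value_grad_kernel, whose bodies are not part of this diff; below is a rough 1-D illustration.

#include <cstdio>
#include <vector>

// Rough 1-D illustration of the assign-mode gradients dispatched above
// (the real helpers operate along an axis of an n-d tensor).
int main() {
  std::vector<float> out_grad = {1.f, 2.f, 3.f, 4.f};
  std::vector<int> index = {1, 3};  // positions overwritten by the forward scatter

  // x_grad: copy of out_grad with the overwritten positions zeroed.
  std::vector<float> x_grad = out_grad;
  for (int i : index) x_grad[i] = 0.f;

  // value_grad: gather of out_grad at the scattered positions.
  std::vector<float> value_grad;
  for (int i : index) value_grad.push_back(out_grad[i]);

  for (float g : x_grad) std::printf("%g ", g);      // 1 0 3 0
  std::printf("\n");
  for (float g : value_grad) std::printf("%g ", g);  // 2 4
  std::printf("\n");
}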
24 changes: 0 additions & 24 deletions paddle/phi/kernels/cpu/put_along_axis_kernel.cc
@@ -62,30 +62,6 @@ void PutAlongAxisKernel(const Context& dev_ctx,
phi::funcs::cpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "mean") {
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_mean_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::cpu_scatter_mean_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "amax") {
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_max_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::cpu_scatter_max_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "amin") {
if (index_type == DataType::INT32) {
phi::funcs::cpu_scatter_min_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::cpu_scatter_min_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else {
PADDLE_THROW(errors::InvalidArgument(
"can not support reduce: '%s' for scatter kernel, only "