[jiterate] addcmul : complex
As per title
Pull Request resolved: pytorch#74533
Approved by: https://github.com/anjali411
kshitij12345 authored and pytorchmergebot committed Mar 23, 2022
1 parent f7ee308 commit 85d8647
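
For reference, addcmul computes self + value * tensor1 * tensor2 elementwise. With this change, complex inputs on CUDA are dispatched to the jiterator (jitted_gpu_kernel), which JIT-compiles the elementwise kernel at runtime when AT_USE_JITERATOR() is enabled, and otherwise fall back to a plain gpu_kernel. A minimal smoke test along the following lines (a sketch, not part of this commit) would exercise the new complex path:

#include <ATen/ATen.h>
#include <c10/util/complex.h>

// Hypothetical check: complex CUDA inputs now take the jiterator path for
// addcmul when AT_USE_JITERATOR() is enabled at build time.
int main() {
  auto opts = at::device(at::kCUDA).dtype(at::kComplexFloat);
  auto self = at::randn({8}, opts);
  auto t1 = at::randn({8}, opts);
  auto t2 = at::randn({8}, opts);
  c10::Scalar alpha(c10::complex<double>(2.0, -1.0));

  // addcmul(self, t1, t2, alpha) == self + alpha * t1 * t2, elementwise.
  auto out = at::addcmul(self, t1, t2, alpha);
  auto ref = at::add(self, at::mul(t1, t2), /*alpha=*/alpha);
  TORCH_CHECK(at::allclose(out, ref), "complex addcmul mismatch");
  return 0;
}

Both branches implement the same formula; the jiterator route mainly trades ahead-of-time template instantiation of the complex kernels for runtime compilation.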
Showing 1 changed file with 38 additions and 8 deletions.
46 changes: 38 additions & 8 deletions aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
@@ -3,23 +3,53 @@
 #include <ATen/Context.h>
 #include <ATen/Dispatch.h>
 #include <ATen/native/cuda/Loops.cuh>
+#include <ATen/native/cuda/JitLoops.cuh>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/PointwiseOps.h>
 #include <c10/core/Scalar.h>
 
 namespace at { namespace native {
 
+const char addcmul_name[] = "addcmul";
 void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.dtype(), "addcmul_cuda", [&]() {
-    // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
-    // and do math in fp32 for better accuracy.
-    using accscalar_t = at::acc_type<scalar_t, true>;
-    auto alpha = value.to<accscalar_t>();
-    gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
-      return a + alpha * (static_cast<accscalar_t>(b) * static_cast<accscalar_t>(c));
+  auto dtype = iter.dtype();
+  if (at::isComplexType(dtype)) {
+#if AT_USE_JITERATOR()
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
+      auto alpha = value.to<scalar_t>();
+      static const auto addcmul_string = jiterator_stringify(
+          template <typename T> T addcmul(T a, T b, T c, T alpha) { return a + alpha * (b * c); });
+      jitted_gpu_kernel<
+          /*name=*/addcmul_name,
+          /*return_dtype=*/scalar_t,
+          /*common_dtype=*/scalar_t,
+          /*arity=*/3>(
+          iter,
+          addcmul_string,
+          /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
+          /*scalar_val=*/0,
+          /*extra_args=*/std::make_tuple(alpha));
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
+      auto alpha = value.to<scalar_t>();
+      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
+        return a + alpha * b * c;
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcmul_cuda", [&]() {
+      // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
+      // and do math in fp32 for better accuracy.
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      auto alpha = value.to<accscalar_t>();
+      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
+        return a + alpha * (static_cast<accscalar_t>(b) * static_cast<accscalar_t>(c));
+      });
     });
-  });
+  }
 }
 
 void addcdiv_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
