Skip to content

Commit

Permalink
ufunc codegen (pytorch#65851)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#65851

Design doc: https://docs.google.com/document/d/12rtlHnPUpaJ-I52Iob3L0WA3rKRr_OY7fXqeCvn2MVY/edit

First read the design doc to understand the user syntax. In this PR, we have converted add to use ufunc codegen; most of the cpp changes are deleting the preexisting implementations of add, and ufunc/add.h are the new implementations in the ufunc format.

The bulk of this PR is in the new codegen machinery. Here's the order to read the files:

* `tools/codegen/model.py`
  * Some self-explanatory utility classes: `ScalarType`, `DTYPE_CLASSES`
  * New classes for representing ufunc entries in `native_functions.yaml`: `UfuncKey` and  `UfuncInnerLoop`, as well as parsing logic for these entries. UfuncKey has some unusual entries (e.g., CPUScalar) that don't show up in the documentation, more on these below).
  * A predicate `is_ufunc_dispatch_key` for testing which dispatch keys should get automatically generated when an operator opts into ufuncs (CPU and CUDA, for now!)
* `tools/codegen/api/types.py`
  * More self-explanatory utility stuff: ScalarTypeToCppMapping mapping ScalarType to CppTypes; Binding.rename for changing the name of a binding (used when we assign constructor variables to member variables inside CUDA functors)
  * New VectorizedCType, representing `at::vec::Vectorized<T>`. This is used inside vectorized CPU codegen.
  * New `scalar_t` and `opmath_t` BaseCppTypes, representing template parameters that we work with when doing codegen inside ufunc kernel loops (e.g., where you previously had Tensor, now you have `scalar_t`)
  * `StructuredImplSignature` represents a `TORCH_IMPL_FUNC` definition, and straightforwardly follows from preexisting `tools.codegen.api.structured`
* `tools/codegen/translate.py` - Yes, we use translate a LOT in this PR. I improved some of the documentation, the only substantive changes are adding two new conversions: given a `scalar_t` or a `const Scalar&`, make it convertible to an `opmath_t`
* `tools/codegen/api/ufunc.py`
  * OK, now we're at the meaty stuff. This file represents the calling conventions of three important concepts in ufunc codegen, which we'll describe shortly. All of these APIs are relatively simple, since there aren't any complicated types by the time you get to kernels.
  * stubs are the DispatchStub trampolines that CPU kernels use to get to their vectorized versions. They drop all Tensor arguments (as they are in TensorIterator) but otherwise match the structured calling convention
  * ufuncs are the inner loop template functions that you wrote in ufunc/add.h which do the actual computation in question. Here, all the Tensors and Scalars have been converted into the computation type (`opmath_t` in CUDA, `scalar_t` in CPU)
  * ufunctors are a CUDA-only concept representing functors that take some of their arguments on a host-side constructor, and the rest in the device-side apply. Once again, Tensors and Scalars are converted into the computation type, `opmath_t`, but for clarity all the functions take `scalar_t` as argument (as this is the type that is most salient at the call site). Because the constructor and apply are code generated separately, `ufunctor_arguments` returns a teeny struct `UfunctorBindings`
* `tools/codegen/dest/ufunc.py` - the workhorse. This gets its own section below.
* `tools/codegen/gen.py` - just calling out to the new dest.ufunc implementation to generate UfuncCPU_add.cpp, UFuncCPUKernel_add.cpp and UfuncCUDA_add.cu files per ufunc operator. Each of these files does what you expect (small file that registers kernel and calls stub; CPU implementation; CUDA implementation). There is a new file manager for UFuncCPUKernel files as these need to get replicated by cmake for vectorization. One little trick to avoid recompilation is we directly replicate code generated forward declarations in these files, to reduce the number of headers we depend on (this is codegen, we're just doing the preprocessors job!)
* I'll talk about build system adjustments below.

OK, let's talk about tools/codegen/dest/ufunc.py. This file can be roughly understood in two halves: one for CPU code generation, and the other for CUDA code generation.

**CPU codegen.** Here's roughly what we want to generate:

```
// in UfuncCPU_add.cpp
using add_fn = void (*)(TensorIteratorBase&, const at::Scalar&);
DECLARE_DISPATCH(add_fn, add_stub);
DEFINE_DISPATCH(add_stub);

TORCH_IMPL_FUNC(ufunc_add_CPU)
(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha, const at::Tensor& out) {
  add_stub(device_type(), *this, alpha);
}

// in UfuncCPUKernel_add.cpp
void add_kernel(TensorIteratorBase& iter, const at::Scalar& alpha) {
  at::ScalarType st = iter.common_dtype();
  RECORD_KERNEL_FUNCTION_DTYPE("add_stub", st);
  switch (st) {
    AT_PRIVATE_CASE_TYPE("add_stub", at::ScalarType::Bool, bool, [&]() {
      auto _s_alpha = alpha.to<scalar_t>();
      cpu_kernel(iter, [=](scalar_t self, scalar_t other) {
        return ufunc::add(self, other, _s_alpha);
      });
    })

    AT_PRIVATE_CASE_TYPE(
        "add_stub", at::ScalarType::ComplexFloat, c10::complex<float>, [&]() {
          auto _s_alpha = alpha.to<scalar_t>();
          auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
          cpu_kernel_vec(
              iter,
              [=](scalar_t self, scalar_t other) {
                return ufunc::add(self, other, _s_alpha);
              },
              [=](at::vec::Vectorized<scalar_t> self,
                  at::vec::Vectorized<scalar_t> other) {
                return ufunc::add(self, other, _v_alpha);
              });
        })

   ...
```

The most interesting change about the generated code is what previously was an `AT_DISPATCH` macro invocation is now an unrolled loop. This makes it easier to vary behavior per-dtype (you can see in this example that the entry for bool and float differ) without having to add extra condtionals on top.

Otherwise, to generate this code, we have to hop through several successive API changes:

* In TORCH_IMPL_FUNC(ufunc_add_CPU), go from StructuredImplSignature to StubSignature (call the stub). This is normal argument massaging in the classic translate style.
* In add_kernel, go from StubSignature to UfuncSignature. This is nontrivial, because we must do various conversions outside of the inner kernel loop. These conversions are done by hand, setting up the context appropriately, and then the final ufunc call is done using translate. (BTW, I introduce a new convention here, call on a Signature, for code generating a C++ call, and I think we should try to use this convention elsewhere)

The other piece of nontrivial logic is the reindexing by dtype. This reindexing exists because the native_functions.yaml format is indexed by UfuncKey:

```
   Generic: add (AllAndComplex, BFloat16, Half)
   ScalarOnly: add (Bool)
```

but when we do code generation, we case on dtype first, and then we generate a `cpu_kernel` or `cpu_kernel_vec` call. We also don't care about CUDA code generation (which Generic) hits. Do this, we lower these keys into two low level keys, CPUScalar and CPUVector, which represent the CPU scalar and CPU vectorized ufuncs, respectively (Generic maps to CPUScalar and CPUVector, while ScalarOnly maps to CPUScalar only). Reindexing then gives us:

```
    AllAndComplex:
        CPUScalar: add
        CPUVector: add
    Bool:
        CPUScalar: add
    ...
```

which is a good format for code generation, but too wordy to force native_functions.yaml authors to write. Note that when reindexing, it is possible for there to be a conflicting definition for the same dtype; we just define a precedence order and have one override the other, so that it is easy to specialize on a particular dtype if necessary. Also note that because CPUScalar/CPUVector are part of UfuncKey, technically you can manually specify them in native_functions.yaml, although I don't expect this functionality to be used.

**CUDA codegen.** CUDA code generation has many of the same ideas as CPU codegen, but it needs to know about functors, and stubs are handled slightly differently.  Here is what we want to generate:

```
template <typename scalar_t>
struct CUDAFunctorOnSelf_add {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t other_;
  opmath_t alpha_;
  CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
      : other_(other), alpha_(alpha) {}
  __device__ scalar_t operator()(scalar_t self) {
    return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
  }
};

... two more functors ...

void add_kernel(TensorIteratorBase& iter, const at::Scalar & alpha) {
  TensorIteratorBase& iter = *this;
  at::ScalarType st = iter.common_dtype();
  RECORD_KERNEL_FUNCTION_DTYPE("ufunc_add_CUDA", st);
  switch (st) {
    AT_PRIVATE_CASE_TYPE("ufunc_add_CUDA", at::ScalarType::Bool, bool, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (false) {
      } else if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(
            iter.scalar_value<opmath_t>(1), (alpha).to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(
            iter.scalar_value<opmath_t>(2), (alpha).to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>((alpha).to<opmath_t>()));
      }
    })

   ...

REGISTER_DISPATCH(add_stub, &add_kernel);

TORCH_IMPL_FUNC(ufunc_add_CUDA)
(const at::Tensor& self,
 const at::Tensor& other,
 const at::Scalar& alpha,
 const at::Tensor& out) {
  add_kernel(*this, alpha);
}

```

The functor business is the bulk of the complexity. Like CPU, we decompose CUDA implementation into three low-level keys: CUDAFunctor (normal, all CUDA kernels will have this), and CUDAFunctorOnOther/CUDAFunctorOnScalar (these are to support Tensor-Scalar specializations when the Scalar lives on CPU). Both Generic and ScalarOnly provide ufuncs for CUDAFunctor, but for us to also lift these into Tensor-Scalar specializations, the operator itself must be eligible for Tensor-Scalar specialization. At the moment, this is hardcoded to be all binary operators, but in the future we can use tags in native_functions.yaml to disambiguate (or perhaps expand codegen to handle n-ary operators).

The reindexing process not only reassociates ufuncs by dtype, but it also works out if Tensor-Scalar specializations are needed and codegens the ufunctors necessary for the level of specialization here (`compute_ufunc_cuda_functors`). Generating the actual kernel (`compute_ufunc_cuda_dtype_body`) just consists of, for each specialization, constructing the functor and then passing it off to `gpu_kernel`. Most of the hard work is in functor generation, where we take care to make sure `operator()` has the correct input and output types (which `gpu_kernel` uses to arrange for memory accesses to the actual CUDA tensor; if you get these types wrong, your kernel will still work, it will just run very slowly!)

There is one big subtlety with CUDA codegen: this won't work:

```
   Generic: add (AllAndComplex, BFloat16, Half)
   ScalarOnly: add_bool (Bool)
```

This is because, even though there are separate Generic/ScalarOnly entries, we only generate a single functor to cover ALL dtypes in this case, and the functor has the ufunc name hardcoded into it. You'll get an error if you try to do this; to fix it, just make sure the ufunc is named the same consistently throughout. In the code, you see this because after testing for the short circuit case (when a user provided the functor themselves), we squash all the generic entries together and assert their ufunc names are the same. Hypothetically, if we generated a separate functor per dtype, we could support differently named ufuncs but... why would you do that to yourself. (One piece of nastiness is that the native_functions.yaml syntax doesn't stop you from shooting yourself in the foot.)

A brief word about CUDA stubs: technically, they are not necessary, as there is no CPU/CPUKernel style split for CUDA kernels (so, if you look, structured impl actually calls add_kernel directly). However, there is some code that still makes use of CUDA stubs (in particular, I use the stub to conveniently reimplement sub in terms of add), so we still register it. This might be worth frying some more at a later point in time.

**Build system changes.** If you are at FB, you should review these changes in fbcode, as there are several changes in files that are not exported to ShipIt.

The build system changes in this patch are substantively complicated by the fact that I have to implement these changes five times:

* OSS cmake build
* OSS Bazel build
* FB fbcode Buck build
* FB xplat Buck build (selective build)
* FB ovrsource Buck build

Due to technical limitations in the xplat Buck build related to selective build, it is required that you list every ufunc header manually (this is done in tools/build_variables.bzl)

The OSS cmake changes are entirely in cmake/Codegen.cmake there is a new set of files cpu_vec_generated (corresponding to UfuncCPUKernel files) which is wired up in the same way as other files. These files are different because they need to get compiled multiple times under different vectorization settings. I adjust the codegen, slightly refactoring the inner loop into its own function so I can use different base path calculation depending on if the file is traditional (in the native/cpu folder) or generated (new stuff from this diff.

The Bazel/Buck changes are organized around tools/build_variables.bzl, which contain the canonical list of ufunc headers (aten_ufunc_headers), and tools/ufunc_defs.bzl (added to ShipIt export list in D34465699) which defines a number of functions that compute the generated cpu, cpu kernel and cuda files based on the headers list. For convenience, these functions take a genpattern (a string with a {} for interpolation) which can be used to easily reformat the list of formats in target form, which is commonly needed in the build systems.

The split between build_variables.bzl and ufunc_defs.bzl is required because build_variables.bzl is executed by a conventional Python interpreter as part of the OSS cmake, but we require Skylark features to implement the functions in ufunc_defs.bzl (I did some quick Googling but didn't find a lightweight way to run the Skylark interpreter in open source.)

With these new file lists, the rest of the build changes are mostly inserting references to these files wherever necessary; in particular, cpu kernel files have to be worked into the multiple vectorization build flow (intern_build_aten_ops in OSS Bazel). Most of the subtlety relates to selective build. Selective build requires operator files to be copied per overall selective build; as dhruvbird explains to me, glob expansion happens during the action graph phase, but the selective build handling of TEMPLATE_SOURCE_LIST is referencing the target graph. In other words, we can't use a glob to generate deps for another rule, because we need to copy files from wherever (included generated files) to a staging folder so the rules can pick them up.

It can be somewhat confusing to understand which bzl files are associated with which build. Here are the relevant mappings for files I edited:

* Used by everyone - tools/build_tools.bzl, tools/ufunc_defs.bzl
* OSS Bazel - aten.bzl, BUILD.bazel
* FB fbcode Buck - TARGETS
* FB xplat Buck -BUCK, pt_defs.bzl, pt_template_srcs.bzl
* FB ovrsource Buck - ovrsource_defs.bzl, pt_defs.bzl

Note that pt_defs.bzl is used by both xplat and ovrsource. This leads to the "tiresome" handling for enabled backends, as selective build is CPU only, but ovrsource is CPU and CUDA.

BTW, while I was at it, I beefed up fb/build_arvr.sh to also do a CUDA ovrsource build, which was not triggered previously.

Signed-off-by: Edward Z. Yang <[email protected]>

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31306586

Pulled By: ezyang

fbshipit-source-id: 210258ce83f578f79cf91b77bfaeac34945a00c6
  • Loading branch information
ezyang authored and facebook-github-bot committed Feb 28, 2022
1 parent cd75ec6 commit d65157b
Show file tree
Hide file tree
Showing 22 changed files with 1,326 additions and 137 deletions.
20 changes: 14 additions & 6 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ load("//third_party:substitution.bzl", "header_template_rule")
load("//:tools/build_variables.bzl", "jit_core_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_nvfuser_generated_headers", "libtorch_nvfuser_runtime_sources", "libtorch_python_core_sources", "torch_cpp_srcs")
load("//tools/rules:cu.bzl", "cu_library")
load("//tools/config:defs.bzl", "if_cuda")
load("//:aten.bzl", "intern_build_aten_ops", "generate_aten")
load("//:aten.bzl", "intern_build_aten_ops", "generate_aten", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cuda_sources")

COMMON_COPTS = [
"-DHAVE_MALLOC_USABLE_SIZE=1",
Expand Down Expand Up @@ -94,9 +94,14 @@ generated_cuda_cpp = [
generate_aten(
name = "generated_aten_cpp",
srcs = aten_generation_srcs,
outs = generated_cpu_cpp + generated_cuda_cpp + [
"aten/src/ATen/Declarations.yaml",
],
outs = (
generated_cpu_cpp +
generated_cuda_cpp +
aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") +
aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") +
aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") +
["aten/src/ATen/Declarations.yaml"]
),
generator=":gen",
)

Expand Down Expand Up @@ -301,7 +306,9 @@ filegroup(
"aten/src/ATen/native/cuda/*.cu",
"aten/src/ATen/native/quantized/cuda/*.cu",
"aten/src/ATen/native/sparse/cuda/*.cu",
]),
]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"),
# It's a bit puzzling to me why it's not necessary to declare the
# target that generates these sources...
)

header_template_rule(
Expand Down Expand Up @@ -383,6 +390,7 @@ intern_build_aten_ops(
"@fbgemm",
"@mkl",
],
extra_impls = aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}"),
)

cc_library(
Expand All @@ -400,7 +408,7 @@ cc_library(
":aten_native_sparse_cpp",
":aten_native_xnnpack",
":aten_src_ATen_config",
] + generated_cpu_cpp,
] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),
copts = ATEN_COPTS,
data = if_cuda(
[":libcaffe2_nvrtc.so"],
Expand Down
44 changes: 43 additions & 1 deletion aten.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("//:tools/build_variables.bzl", "aten_ufunc_headers")

CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"]
CAPABILITY_COMPILER_FLAGS = {
Expand All @@ -8,8 +9,9 @@ CAPABILITY_COMPILER_FLAGS = {
}

PREFIX = "aten/src/ATen/native/"
EXTRA_PREFIX = "aten/src/ATen/"

def intern_build_aten_ops(copts, deps):
def intern_build_aten_ops(copts, deps, extra_impls):
for cpu_capability in CPU_CAPABILITY_NAMES:
srcs = []
for impl in native.glob(
Expand All @@ -28,6 +30,17 @@ def intern_build_aten_ops(copts, deps):
)
srcs.append(out)

for impl in extra_impls:
name = impl.replace(EXTRA_PREFIX, "")
out = EXTRA_PREFIX + name + "." + cpu_capability + ".cpp"
native.genrule(
name = name + "_" + cpu_capability + "_cp",
srcs = [impl],
outs = [out],
cmd = "cp $< $@",
)
srcs.append(out)

cc_library(
name = "ATen_CPU_" + cpu_capability,
srcs = srcs,
Expand Down Expand Up @@ -81,3 +94,32 @@ generate_aten = rule(
"srcs": attr.label_list(allow_files = True),
},
)

# copy pasted from ufunc_defs.bzl, as ufuncs_defs.bzl cannot be included
# from BUILD.bazel because it has a directory relative load, and Bazel
# always load from workspace root. The "correct" fix would be to move
# build_variables.bzl to the top level but I don't have time to do this at
# the moment.

aten_ufunc_names = [
paths.split_extension(paths.basename(h))[0]
for h in aten_ufunc_headers
]

def aten_ufunc_generated_cpu_sources(gencode_pattern = "{}"):
return [gencode_pattern.format(name) for name in [
"UfuncCPU_{}.cpp".format(n)
for n in aten_ufunc_names
]]

def aten_ufunc_generated_cpu_kernel_sources(gencode_pattern = "{}"):
return [gencode_pattern.format(name) for name in [
"UfuncCPUKernel_{}.cpp".format(n)
for n in aten_ufunc_names
]]

def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"):
return [gencode_pattern.format(name) for name in [
"UfuncCUDA_{}.cu".format(n)
for n in aten_ufunc_names
]]
12 changes: 2 additions & 10 deletions aten/src/ATen/native/BinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,9 @@ CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(ge);

namespace native {

DEFINE_DISPATCH(add_stub);
DEFINE_DISPATCH(add_clamp_stub);
DEFINE_DISPATCH(sub_stub);
DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(sub_stub);
DEFINE_DISPATCH(div_true_stub);
DEFINE_DISPATCH(div_floor_stub);
DEFINE_DISPATCH(div_trunc_stub);
Expand Down Expand Up @@ -277,17 +276,10 @@ DEFINE_DISPATCH(xlogy_stub);
DEFINE_DISPATCH(xlog1py_stub);
DEFINE_DISPATCH(zeta_stub);

TORCH_IMPL_FUNC(add_out) (
const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result
) {
add_stub(device_type(), *this, alpha);
TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype());
}

TORCH_IMPL_FUNC(sub_out) (
const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result
) {
sub_stub(device_type(), *this, alpha);
add_stub(device_type(), *this, -alpha);
TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype());
}

Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/native/BinaryOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ using binary_fn = void(*)(TensorIterator&);
using binary_clamp_fn_alpha =
void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);

// NB: codegenned
DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);

DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub);
DECLARE_DISPATCH(structured_binary_fn, mul_stub);
Expand Down
29 changes: 0 additions & 29 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,6 @@ namespace {

using namespace vec;

// Note: Undefined behavior when performing addition is intentionally
// ignored.
void add_kernel(TensorIteratorBase& iter, const Scalar& alpha_scalar) {
if (iter.dtype() == ScalarType::Bool) {
using scalar_t = bool;
auto alpha = alpha_scalar.to<scalar_t>();
cpu_kernel(iter,
[=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a + alpha * b; });
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "add_cpu/sub_cpu", [&]() {
auto alpha = alpha_scalar.to<scalar_t>();
auto alpha_vec = Vectorized<scalar_t>(alpha);
cpu_kernel_vec(iter,
[=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a + alpha * b; },
[=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) __ubsan_ignore_undefined__ {
return vec::fmadd(b, alpha_vec, a);
});
});
}
}

void add_clamp_kernel(TensorIterator& iter, const Scalar& alpha_scalar, const Scalar& min_val, const Scalar& max_val) {
AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_clamp_cpu", [&]() {
auto alpha = alpha_scalar.to<scalar_t>();
Expand Down Expand Up @@ -74,12 +53,6 @@ void atan2_kernel(TensorIteratorBase& iter) {
});
}

// Note: Undefined behavior when performing subtraction is intentionally
// ignored.
void sub_kernel(TensorIteratorBase& iter, const Scalar& alpha_scalar) __ubsan_ignore_undefined__ {
add_kernel(iter, -alpha_scalar);
}

void mul_kernel(TensorIteratorBase& iter) {
if (iter.dtype() == ScalarType::Bool) {
cpu_kernel(iter, [=](bool a, bool b) -> bool { return a && b; });
Expand Down Expand Up @@ -1133,9 +1106,7 @@ void zeta_kernel(TensorIteratorBase& iter) {

} // namespace

REGISTER_DISPATCH(add_stub, &add_kernel);
REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel);
REGISTER_DISPATCH(sub_stub, &sub_kernel);
REGISTER_DISPATCH(mul_stub, &mul_kernel);
REGISTER_DISPATCH(div_true_stub, &div_true_kernel);
REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel);
Expand Down
37 changes: 0 additions & 37 deletions aten/src/ATen/native/cuda/BinaryAddSubKernel.cu

This file was deleted.

4 changes: 3 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,10 @@
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
ufunc_inner_loop:
Generic: add (AllAndComplex, BFloat16, Half)
ScalarOnly: add (Bool)
dispatch:
CPU, CUDA: add_out
SparseCPU: add_out_sparse_cpu
SparseCUDA: add_out_sparse_cuda
SparseCsrCPU: add_out_sparse_csr_cpu
Expand Down
27 changes: 27 additions & 0 deletions aten/src/ATen/native/ufunc/add.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include <c10/macros/Macros.h>

#if !defined(__CUDACC__) && !defined(__HIPCC__)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

namespace at {
namespace native {
namespace ufunc {

template <typename T>
C10_HOST_DEVICE C10_ALWAYS_INLINE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
return self + alpha * other;
}

#if !defined(__CUDACC__) && !defined(__HIPCC__)
using vec::Vectorized;
template <typename T>
C10_ALWAYS_INLINE Vectorized<T> add(Vectorized<T> self, Vectorized<T> other, Vectorized<T> alpha) __ubsan_ignore_undefined__ {
return vec::fmadd(other, alpha, self);
}
#endif

}}} // namespace at::native::ufunc
19 changes: 19 additions & 0 deletions aten/src/ATen/templates/UfuncCPU.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#define TORCH_ASSERT_NO_OPERATORS

#include <ATen/native/DispatchStub.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorMeta.h>

namespace at {

// NB: this is explicitly copied here (via codegen) rather than
// included via NativeFunctions.h to avoid recompiling this file when
// NativeFunctions.h changes
namespace meta {
${meta_declaration}
}

namespace native {
${native_declaration}
${native_definitions}
}} // namespace at::native
14 changes: 14 additions & 0 deletions aten/src/ATen/templates/UfuncCPUKernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#define TORCH_ASSERT_NO_OPERATORS

#include <ATen/native/ufunc/${name}.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/Dispatch.h>
#include <c10/core/Scalar.h>

namespace at {
namespace native {
${native_definitions}
}} // namespace at::native
21 changes: 21 additions & 0 deletions aten/src/ATen/templates/UfuncCUDA.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#define TORCH_ASSERT_NO_OPERATORS

#include <ATen/native/ufunc/${name}.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <c10/core/Scalar.h>
${cuda_headers}

namespace at {

// NB: this is explicitly copied here (via codegen) rather than
// included via NativeFunctions.h to avoid recompiling this file when
// NativeFunctions.h changes
namespace meta {
${meta_declaration}
}

namespace native {
${native_declaration}
${native_definitions}
}} // namespace at::native
Loading

0 comments on commit d65157b

Please sign in to comment.