From ce7910ba812e4db30d34142ccb80a11aafad8271 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Mon, 28 Feb 2022 15:46:04 -0800 Subject: [PATCH] ufunc codegen (#65851) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65851 Design doc: https://docs.google.com/document/d/12rtlHnPUpaJ-I52Iob3L0WA3rKRr_OY7fXqeCvn2MVY/edit First read the design doc to understand the user syntax. In this PR, we have converted add to use ufunc codegen; most of the cpp changes are deleting the preexisting implementations of add, and ufunc/add.h are the new implementations in the ufunc format. The bulk of this PR is in the new codegen machinery. Here's the order to read the files: * `tools/codegen/model.py` * Some self-explanatory utility classes: `ScalarType`, `DTYPE_CLASSES` * New classes for representing ufunc entries in `native_functions.yaml`: `UfuncKey` and `UfuncInnerLoop`, as well as parsing logic for these entries. UfuncKey has some unusual entries (e.g., CPUScalar) that don't show up in the documentation, more on these below). * A predicate `is_ufunc_dispatch_key` for testing which dispatch keys should get automatically generated when an operator opts into ufuncs (CPU and CUDA, for now!) * `tools/codegen/api/types.py` * More self-explanatory utility stuff: ScalarTypeToCppMapping mapping ScalarType to CppTypes; Binding.rename for changing the name of a binding (used when we assign constructor variables to member variables inside CUDA functors) * New VectorizedCType, representing `at::vec::Vectorized`. This is used inside vectorized CPU codegen. * New `scalar_t` and `opmath_t` BaseCppTypes, representing template parameters that we work with when doing codegen inside ufunc kernel loops (e.g., where you previously had Tensor, now you have `scalar_t`) * `StructuredImplSignature` represents a `TORCH_IMPL_FUNC` definition, and straightforwardly follows from preexisting `tools.codegen.api.structured` * `tools/codegen/translate.py` - Yes, we use translate a LOT in this PR. I improved some of the documentation, the only substantive changes are adding two new conversions: given a `scalar_t` or a `const Scalar&`, make it convertible to an `opmath_t` * `tools/codegen/api/ufunc.py` * OK, now we're at the meaty stuff. This file represents the calling conventions of three important concepts in ufunc codegen, which we'll describe shortly. All of these APIs are relatively simple, since there aren't any complicated types by the time you get to kernels. * stubs are the DispatchStub trampolines that CPU kernels use to get to their vectorized versions. They drop all Tensor arguments (as they are in TensorIterator) but otherwise match the structured calling convention * ufuncs are the inner loop template functions that you wrote in ufunc/add.h which do the actual computation in question. Here, all the Tensors and Scalars have been converted into the computation type (`opmath_t` in CUDA, `scalar_t` in CPU) * ufunctors are a CUDA-only concept representing functors that take some of their arguments on a host-side constructor, and the rest in the device-side apply. Once again, Tensors and Scalars are converted into the computation type, `opmath_t`, but for clarity all the functions take `scalar_t` as argument (as this is the type that is most salient at the call site). Because the constructor and apply are code generated separately, `ufunctor_arguments` returns a teeny struct `UfunctorBindings` * `tools/codegen/dest/ufunc.py` - the workhorse. This gets its own section below. * `tools/codegen/gen.py` - just calling out to the new dest.ufunc implementation to generate UfuncCPU_add.cpp, UFuncCPUKernel_add.cpp and UfuncCUDA_add.cu files per ufunc operator. Each of these files does what you expect (small file that registers kernel and calls stub; CPU implementation; CUDA implementation). There is a new file manager for UFuncCPUKernel files as these need to get replicated by cmake for vectorization. One little trick to avoid recompilation is we directly replicate code generated forward declarations in these files, to reduce the number of headers we depend on (this is codegen, we're just doing the preprocessors job!) * I'll talk about build system adjustments below. OK, let's talk about tools/codegen/dest/ufunc.py. This file can be roughly understood in two halves: one for CPU code generation, and the other for CUDA code generation. **CPU codegen.** Here's roughly what we want to generate: ``` // in UfuncCPU_add.cpp using add_fn = void (*)(TensorIteratorBase&, const at::Scalar&); DECLARE_DISPATCH(add_fn, add_stub); DEFINE_DISPATCH(add_stub); TORCH_IMPL_FUNC(ufunc_add_CPU) (const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha, const at::Tensor& out) { add_stub(device_type(), *this, alpha); } // in UfuncCPUKernel_add.cpp void add_kernel(TensorIteratorBase& iter, const at::Scalar& alpha) { at::ScalarType st = iter.common_dtype(); RECORD_KERNEL_FUNCTION_DTYPE("add_stub", st); switch (st) { AT_PRIVATE_CASE_TYPE("add_stub", at::ScalarType::Bool, bool, [&]() { auto _s_alpha = alpha.to(); cpu_kernel(iter, [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); }); }) AT_PRIVATE_CASE_TYPE( "add_stub", at::ScalarType::ComplexFloat, c10::complex, [&]() { auto _s_alpha = alpha.to(); auto _v_alpha = at::vec::Vectorized(_s_alpha); cpu_kernel_vec( iter, [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); }, [=](at::vec::Vectorized self, at::vec::Vectorized other) { return ufunc::add(self, other, _v_alpha); }); }) ... ``` The most interesting change about the generated code is what previously was an `AT_DISPATCH` macro invocation is now an unrolled loop. This makes it easier to vary behavior per-dtype (you can see in this example that the entry for bool and float differ) without having to add extra condtionals on top. Otherwise, to generate this code, we have to hop through several successive API changes: * In TORCH_IMPL_FUNC(ufunc_add_CPU), go from StructuredImplSignature to StubSignature (call the stub). This is normal argument massaging in the classic translate style. * In add_kernel, go from StubSignature to UfuncSignature. This is nontrivial, because we must do various conversions outside of the inner kernel loop. These conversions are done by hand, setting up the context appropriately, and then the final ufunc call is done using translate. (BTW, I introduce a new convention here, call on a Signature, for code generating a C++ call, and I think we should try to use this convention elsewhere) The other piece of nontrivial logic is the reindexing by dtype. This reindexing exists because the native_functions.yaml format is indexed by UfuncKey: ``` Generic: add (AllAndComplex, BFloat16, Half) ScalarOnly: add (Bool) ``` but when we do code generation, we case on dtype first, and then we generate a `cpu_kernel` or `cpu_kernel_vec` call. We also don't care about CUDA code generation (which Generic) hits. Do this, we lower these keys into two low level keys, CPUScalar and CPUVector, which represent the CPU scalar and CPU vectorized ufuncs, respectively (Generic maps to CPUScalar and CPUVector, while ScalarOnly maps to CPUScalar only). Reindexing then gives us: ``` AllAndComplex: CPUScalar: add CPUVector: add Bool: CPUScalar: add ... ``` which is a good format for code generation, but too wordy to force native_functions.yaml authors to write. Note that when reindexing, it is possible for there to be a conflicting definition for the same dtype; we just define a precedence order and have one override the other, so that it is easy to specialize on a particular dtype if necessary. Also note that because CPUScalar/CPUVector are part of UfuncKey, technically you can manually specify them in native_functions.yaml, although I don't expect this functionality to be used. **CUDA codegen.** CUDA code generation has many of the same ideas as CPU codegen, but it needs to know about functors, and stubs are handled slightly differently. Here is what we want to generate: ``` template struct CUDAFunctorOnSelf_add { using opmath_t = at::opmath_type; opmath_t other_; opmath_t alpha_; CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {} __device__ scalar_t operator()(scalar_t self) { return ufunc::add(static_cast(self), other_, alpha_); } }; ... two more functors ... void add_kernel(TensorIteratorBase& iter, const at::Scalar & alpha) { TensorIteratorBase& iter = *this; at::ScalarType st = iter.common_dtype(); RECORD_KERNEL_FUNCTION_DTYPE("ufunc_add_CUDA", st); switch (st) { AT_PRIVATE_CASE_TYPE("ufunc_add_CUDA", at::ScalarType::Bool, bool, [&]() { using opmath_t = at::opmath_type; if (false) { } else if (iter.is_cpu_scalar(1)) { CUDAFunctorOnOther_add ufunctor( iter.scalar_value(1), (alpha).to()); iter.remove_operand(1); gpu_kernel(iter, ufunctor); } else if (iter.is_cpu_scalar(2)) { CUDAFunctorOnSelf_add ufunctor( iter.scalar_value(2), (alpha).to()); iter.remove_operand(2); gpu_kernel(iter, ufunctor); } else { gpu_kernel(iter, CUDAFunctor_add((alpha).to())); } }) ... REGISTER_DISPATCH(add_stub, &add_kernel); TORCH_IMPL_FUNC(ufunc_add_CUDA) (const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha, const at::Tensor& out) { add_kernel(*this, alpha); } ``` The functor business is the bulk of the complexity. Like CPU, we decompose CUDA implementation into three low-level keys: CUDAFunctor (normal, all CUDA kernels will have this), and CUDAFunctorOnOther/CUDAFunctorOnScalar (these are to support Tensor-Scalar specializations when the Scalar lives on CPU). Both Generic and ScalarOnly provide ufuncs for CUDAFunctor, but for us to also lift these into Tensor-Scalar specializations, the operator itself must be eligible for Tensor-Scalar specialization. At the moment, this is hardcoded to be all binary operators, but in the future we can use tags in native_functions.yaml to disambiguate (or perhaps expand codegen to handle n-ary operators). The reindexing process not only reassociates ufuncs by dtype, but it also works out if Tensor-Scalar specializations are needed and codegens the ufunctors necessary for the level of specialization here (`compute_ufunc_cuda_functors`). Generating the actual kernel (`compute_ufunc_cuda_dtype_body`) just consists of, for each specialization, constructing the functor and then passing it off to `gpu_kernel`. Most of the hard work is in functor generation, where we take care to make sure `operator()` has the correct input and output types (which `gpu_kernel` uses to arrange for memory accesses to the actual CUDA tensor; if you get these types wrong, your kernel will still work, it will just run very slowly!) There is one big subtlety with CUDA codegen: this won't work: ``` Generic: add (AllAndComplex, BFloat16, Half) ScalarOnly: add_bool (Bool) ``` This is because, even though there are separate Generic/ScalarOnly entries, we only generate a single functor to cover ALL dtypes in this case, and the functor has the ufunc name hardcoded into it. You'll get an error if you try to do this; to fix it, just make sure the ufunc is named the same consistently throughout. In the code, you see this because after testing for the short circuit case (when a user provided the functor themselves), we squash all the generic entries together and assert their ufunc names are the same. Hypothetically, if we generated a separate functor per dtype, we could support differently named ufuncs but... why would you do that to yourself. (One piece of nastiness is that the native_functions.yaml syntax doesn't stop you from shooting yourself in the foot.) A brief word about CUDA stubs: technically, they are not necessary, as there is no CPU/CPUKernel style split for CUDA kernels (so, if you look, structured impl actually calls add_kernel directly). However, there is some code that still makes use of CUDA stubs (in particular, I use the stub to conveniently reimplement sub in terms of add), so we still register it. This might be worth frying some more at a later point in time. **Build system changes.** If you are at FB, you should review these changes in fbcode, as there are several changes in files that are not exported to ShipIt. The build system changes in this patch are substantively complicated by the fact that I have to implement these changes five times: * OSS cmake build * OSS Bazel build * FB fbcode Buck build * FB xplat Buck build (selective build) * FB ovrsource Buck build Due to technical limitations in the xplat Buck build related to selective build, it is required that you list every ufunc header manually (this is done in tools/build_variables.bzl) The OSS cmake changes are entirely in cmake/Codegen.cmake there is a new set of files cpu_vec_generated (corresponding to UfuncCPUKernel files) which is wired up in the same way as other files. These files are different because they need to get compiled multiple times under different vectorization settings. I adjust the codegen, slightly refactoring the inner loop into its own function so I can use different base path calculation depending on if the file is traditional (in the native/cpu folder) or generated (new stuff from this diff. The Bazel/Buck changes are organized around tools/build_variables.bzl, which contain the canonical list of ufunc headers (aten_ufunc_headers), and tools/ufunc_defs.bzl (added to ShipIt export list in D34465699) which defines a number of functions that compute the generated cpu, cpu kernel and cuda files based on the headers list. For convenience, these functions take a genpattern (a string with a {} for interpolation) which can be used to easily reformat the list of formats in target form, which is commonly needed in the build systems. The split between build_variables.bzl and ufunc_defs.bzl is required because build_variables.bzl is executed by a conventional Python interpreter as part of the OSS cmake, but we require Skylark features to implement the functions in ufunc_defs.bzl (I did some quick Googling but didn't find a lightweight way to run the Skylark interpreter in open source.) With these new file lists, the rest of the build changes are mostly inserting references to these files wherever necessary; in particular, cpu kernel files have to be worked into the multiple vectorization build flow (intern_build_aten_ops in OSS Bazel). Most of the subtlety relates to selective build. Selective build requires operator files to be copied per overall selective build; as dhruvbird explains to me, glob expansion happens during the action graph phase, but the selective build handling of TEMPLATE_SOURCE_LIST is referencing the target graph. In other words, we can't use a glob to generate deps for another rule, because we need to copy files from wherever (included generated files) to a staging folder so the rules can pick them up. It can be somewhat confusing to understand which bzl files are associated with which build. Here are the relevant mappings for files I edited: * Used by everyone - tools/build_tools.bzl, tools/ufunc_defs.bzl * OSS Bazel - aten.bzl, BUILD.bazel * FB fbcode Buck - TARGETS * FB xplat Buck -BUCK, pt_defs.bzl, pt_template_srcs.bzl * FB ovrsource Buck - ovrsource_defs.bzl, pt_defs.bzl Note that pt_defs.bzl is used by both xplat and ovrsource. This leads to the "tiresome" handling for enabled backends, as selective build is CPU only, but ovrsource is CPU and CUDA. BTW, while I was at it, I beefed up fb/build_arvr.sh to also do a CUDA ovrsource build, which was not triggered previously. Signed-off-by: Edward Z. Yang Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D31306586 Pulled By: ezyang fbshipit-source-id: 210258ce83f578f79cf91b77bfaeac34945a00c6 (cherry picked from commit d65157b0b894b6701ee062f05a5f57790a06c91c) --- BUILD.bazel | 20 +- aten.bzl | 44 +- aten/src/ATen/native/BinaryOps.cpp | 12 +- aten/src/ATen/native/BinaryOps.h | 2 + aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 29 -- .../ATen/native/cuda/BinaryAddSubKernel.cu | 37 -- aten/src/ATen/native/native_functions.yaml | 4 +- aten/src/ATen/native/ufunc/add.h | 27 + aten/src/ATen/templates/UfuncCPU.cpp | 19 + aten/src/ATen/templates/UfuncCPUKernel.cpp | 14 + aten/src/ATen/templates/UfuncCUDA.cu | 21 + cmake/Codegen.cmake | 20 +- tools/build_variables.bzl | 17 +- tools/codegen/api/translate.py | 39 +- tools/codegen/api/types.py | 68 ++- tools/codegen/api/ufunc.py | 176 +++++++ tools/codegen/dest/__init__.py | 5 + tools/codegen/dest/ufunc.py | 477 ++++++++++++++++++ tools/codegen/gen.py | 141 ++++-- tools/codegen/model.py | 142 +++++- tools/test/test_codegen_model.py | 124 +++++ tools/ufunc_defs.bzl | 25 + 22 files changed, 1326 insertions(+), 137 deletions(-) delete mode 100644 aten/src/ATen/native/cuda/BinaryAddSubKernel.cu create mode 100644 aten/src/ATen/native/ufunc/add.h create mode 100644 aten/src/ATen/templates/UfuncCPU.cpp create mode 100644 aten/src/ATen/templates/UfuncCPUKernel.cpp create mode 100644 aten/src/ATen/templates/UfuncCUDA.cu create mode 100644 tools/codegen/api/ufunc.py create mode 100644 tools/codegen/dest/ufunc.py create mode 100644 tools/test/test_codegen_model.py create mode 100644 tools/ufunc_defs.bzl diff --git a/BUILD.bazel b/BUILD.bazel index d9780aa23c3dd..ba509759adc0b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -6,7 +6,7 @@ load("//third_party:substitution.bzl", "header_template_rule") load("//:tools/build_variables.bzl", "jit_core_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_nvfuser_generated_headers", "libtorch_nvfuser_runtime_sources", "libtorch_python_core_sources", "torch_cpp_srcs") load("//tools/rules:cu.bzl", "cu_library") load("//tools/config:defs.bzl", "if_cuda") -load("//:aten.bzl", "intern_build_aten_ops", "generate_aten") +load("//:aten.bzl", "intern_build_aten_ops", "generate_aten", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cuda_sources") COMMON_COPTS = [ "-DHAVE_MALLOC_USABLE_SIZE=1", @@ -94,9 +94,14 @@ generated_cuda_cpp = [ generate_aten( name = "generated_aten_cpp", srcs = aten_generation_srcs, - outs = generated_cpu_cpp + generated_cuda_cpp + [ - "aten/src/ATen/Declarations.yaml", - ], + outs = ( + generated_cpu_cpp + + generated_cuda_cpp + + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") + + aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") + + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + + ["aten/src/ATen/Declarations.yaml"] + ), generator=":gen", ) @@ -301,7 +306,9 @@ filegroup( "aten/src/ATen/native/cuda/*.cu", "aten/src/ATen/native/quantized/cuda/*.cu", "aten/src/ATen/native/sparse/cuda/*.cu", - ]), + ]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"), + # It's a bit puzzling to me why it's not necessary to declare the + # target that generates these sources... ) header_template_rule( @@ -383,6 +390,7 @@ intern_build_aten_ops( "@fbgemm", "@mkl", ], + extra_impls = aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}"), ) cc_library( @@ -400,7 +408,7 @@ cc_library( ":aten_native_sparse_cpp", ":aten_native_xnnpack", ":aten_src_ATen_config", - ] + generated_cpu_cpp, + ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"), copts = ATEN_COPTS, data = if_cuda( [":libcaffe2_nvrtc.so"], diff --git a/aten.bzl b/aten.bzl index eccdb4b4d0cdd..c97f22284f106 100644 --- a/aten.bzl +++ b/aten.bzl @@ -1,5 +1,6 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("@rules_cc//cc:defs.bzl", "cc_library") +load("//:tools/build_variables.bzl", "aten_ufunc_headers") CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"] CAPABILITY_COMPILER_FLAGS = { @@ -8,8 +9,9 @@ CAPABILITY_COMPILER_FLAGS = { } PREFIX = "aten/src/ATen/native/" +EXTRA_PREFIX = "aten/src/ATen/" -def intern_build_aten_ops(copts, deps): +def intern_build_aten_ops(copts, deps, extra_impls): for cpu_capability in CPU_CAPABILITY_NAMES: srcs = [] for impl in native.glob( @@ -28,6 +30,17 @@ def intern_build_aten_ops(copts, deps): ) srcs.append(out) + for impl in extra_impls: + name = impl.replace(EXTRA_PREFIX, "") + out = EXTRA_PREFIX + name + "." + cpu_capability + ".cpp" + native.genrule( + name = name + "_" + cpu_capability + "_cp", + srcs = [impl], + outs = [out], + cmd = "cp $< $@", + ) + srcs.append(out) + cc_library( name = "ATen_CPU_" + cpu_capability, srcs = srcs, @@ -81,3 +94,32 @@ generate_aten = rule( "srcs": attr.label_list(allow_files = True), }, ) + +# copy pasted from ufunc_defs.bzl, as ufuncs_defs.bzl cannot be included +# from BUILD.bazel because it has a directory relative load, and Bazel +# always load from workspace root. The "correct" fix would be to move +# build_variables.bzl to the top level but I don't have time to do this at +# the moment. + +aten_ufunc_names = [ + paths.split_extension(paths.basename(h))[0] + for h in aten_ufunc_headers +] + +def aten_ufunc_generated_cpu_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncCPU_{}.cpp".format(n) + for n in aten_ufunc_names + ]] + +def aten_ufunc_generated_cpu_kernel_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncCPUKernel_{}.cpp".format(n) + for n in aten_ufunc_names + ]] + +def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncCUDA_{}.cu".format(n) + for n in aten_ufunc_names + ]] diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index bdd6c87403e3b..437835d7a8665 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -232,10 +232,9 @@ CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(ge); namespace native { -DEFINE_DISPATCH(add_stub); DEFINE_DISPATCH(add_clamp_stub); -DEFINE_DISPATCH(sub_stub); DEFINE_DISPATCH(mul_stub); +DEFINE_DISPATCH(sub_stub); DEFINE_DISPATCH(div_true_stub); DEFINE_DISPATCH(div_floor_stub); DEFINE_DISPATCH(div_trunc_stub); @@ -277,17 +276,10 @@ DEFINE_DISPATCH(xlogy_stub); DEFINE_DISPATCH(xlog1py_stub); DEFINE_DISPATCH(zeta_stub); -TORCH_IMPL_FUNC(add_out) ( - const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result -) { - add_stub(device_type(), *this, alpha); - TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype()); -} - TORCH_IMPL_FUNC(sub_out) ( const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result ) { - sub_stub(device_type(), *this, alpha); + add_stub(device_type(), *this, -alpha); TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype()); } diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index 4bdf587f0bdcb..f34f210c4e484 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -50,7 +50,9 @@ using binary_fn = void(*)(TensorIterator&); using binary_clamp_fn_alpha = void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val); +// NB: codegenned DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); + DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub); DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub); DECLARE_DISPATCH(structured_binary_fn, mul_stub); diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index d383849e290ab..0e5db26b069dc 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -21,27 +21,6 @@ namespace { using namespace vec; -// Note: Undefined behavior when performing addition is intentionally -// ignored. -void add_kernel(TensorIteratorBase& iter, const Scalar& alpha_scalar) { - if (iter.dtype() == ScalarType::Bool) { - using scalar_t = bool; - auto alpha = alpha_scalar.to(); - cpu_kernel(iter, - [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a + alpha * b; }); - } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "add_cpu/sub_cpu", [&]() { - auto alpha = alpha_scalar.to(); - auto alpha_vec = Vectorized(alpha); - cpu_kernel_vec(iter, - [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a + alpha * b; }, - [=](Vectorized a, Vectorized b) __ubsan_ignore_undefined__ { - return vec::fmadd(b, alpha_vec, a); - }); - }); - } -} - void add_clamp_kernel(TensorIterator& iter, const Scalar& alpha_scalar, const Scalar& min_val, const Scalar& max_val) { AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_clamp_cpu", [&]() { auto alpha = alpha_scalar.to(); @@ -74,12 +53,6 @@ void atan2_kernel(TensorIteratorBase& iter) { }); } -// Note: Undefined behavior when performing subtraction is intentionally -// ignored. -void sub_kernel(TensorIteratorBase& iter, const Scalar& alpha_scalar) __ubsan_ignore_undefined__ { - add_kernel(iter, -alpha_scalar); -} - void mul_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { cpu_kernel(iter, [=](bool a, bool b) -> bool { return a && b; }); @@ -1133,9 +1106,7 @@ void zeta_kernel(TensorIteratorBase& iter) { } // namespace -REGISTER_DISPATCH(add_stub, &add_kernel); REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel); -REGISTER_DISPATCH(sub_stub, &sub_kernel); REGISTER_DISPATCH(mul_stub, &mul_kernel); REGISTER_DISPATCH(div_true_stub, &div_true_kernel); REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel); diff --git a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu deleted file mode 100644 index 56d6b0acd728a..0000000000000 --- a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu +++ /dev/null @@ -1,37 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include - -// NOTE: CUDA on Windows requires that the enclosing function -// of a __device__ lambda not have internal linkage. - -namespace at { namespace native { - -template -struct AddFunctor { - AddFunctor(T alpha) : alpha_(alpha) {} - T alpha_; - __device__ __forceinline__ T operator()(T a, T b) const __ubsan_ignore_undefined__ { - return a + b * alpha_; - } -}; - -void add_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() { - using opmath_t = at::opmath_type; - opmath_gpu_kernel_with_scalars(iter, AddFunctor(alpha_scalar.to())); - }); -} - -static void sub_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) { - add_kernel_cuda(iter, -alpha_scalar); -} - -REGISTER_DISPATCH(add_stub, &add_kernel_cuda); -REGISTER_DISPATCH(sub_stub, &sub_kernel_cuda); - -}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0b6467775e7ff..2fce3ebaa1137 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -462,8 +462,10 @@ device_check: NoCheck # TensorIterator structured: True structured_inherits: TensorIteratorBase + ufunc_inner_loop: + Generic: add (AllAndComplex, BFloat16, Half) + ScalarOnly: add (Bool) dispatch: - CPU, CUDA: add_out SparseCPU: add_out_sparse_cpu SparseCUDA: add_out_sparse_cuda SparseCsrCPU: add_out_sparse_csr_cpu diff --git a/aten/src/ATen/native/ufunc/add.h b/aten/src/ATen/native/ufunc/add.h new file mode 100644 index 0000000000000..94a776728eadc --- /dev/null +++ b/aten/src/ATen/native/ufunc/add.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#if !defined(__CUDACC__) && !defined(__HIPCC__) +#include +#include +#endif + +namespace at { +namespace native { +namespace ufunc { + +template +C10_HOST_DEVICE C10_ALWAYS_INLINE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ { + return self + alpha * other; +} + +#if !defined(__CUDACC__) && !defined(__HIPCC__) +using vec::Vectorized; +template +C10_ALWAYS_INLINE Vectorized add(Vectorized self, Vectorized other, Vectorized alpha) __ubsan_ignore_undefined__ { + return vec::fmadd(other, alpha, self); +} +#endif + +}}} // namespace at::native::ufunc diff --git a/aten/src/ATen/templates/UfuncCPU.cpp b/aten/src/ATen/templates/UfuncCPU.cpp new file mode 100644 index 0000000000000..6b363a508907c --- /dev/null +++ b/aten/src/ATen/templates/UfuncCPU.cpp @@ -0,0 +1,19 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include + +namespace at { + +// NB: this is explicitly copied here (via codegen) rather than +// included via NativeFunctions.h to avoid recompiling this file when +// NativeFunctions.h changes +namespace meta { +${meta_declaration} +} + +namespace native { +${native_declaration} +${native_definitions} +}} // namespace at::native diff --git a/aten/src/ATen/templates/UfuncCPUKernel.cpp b/aten/src/ATen/templates/UfuncCPUKernel.cpp new file mode 100644 index 0000000000000..0cac55664d612 --- /dev/null +++ b/aten/src/ATen/templates/UfuncCPUKernel.cpp @@ -0,0 +1,14 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +${native_definitions} +}} // namespace at::native diff --git a/aten/src/ATen/templates/UfuncCUDA.cu b/aten/src/ATen/templates/UfuncCUDA.cu new file mode 100644 index 0000000000000..e75d82d9cc84b --- /dev/null +++ b/aten/src/ATen/templates/UfuncCUDA.cu @@ -0,0 +1,21 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include +#include +${cuda_headers} + +namespace at { + +// NB: this is explicitly copied here (via codegen) rather than +// included via NativeFunctions.h to avoid recompiling this file when +// NativeFunctions.h changes +namespace meta { +${meta_declaration} +} + +namespace native { +${native_declaration} +${native_definitions} +}} // namespace at::native diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index bb573fc35cc95..d4db507a98a32 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -150,6 +150,7 @@ if(INTERN_BUILD_ATEN_OPS) include("${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake") + include("${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake") @@ -161,10 +162,12 @@ if(INTERN_BUILD_ATEN_OPS) ${generated_${gen_type}} ${cuda_generated_${gen_type}} ${core_generated_${gen_type}} + ${cpu_vec_generated_${gen_type}} ${ops_generated_${gen_type}} ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake ${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake ${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake + ${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake ${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake COMMAND ${GEN_COMMAND_${gen_type}} DEPENDS ${all_python} ${${gen_type}_templates} @@ -177,8 +180,8 @@ if(INTERN_BUILD_ATEN_OPS) # not tracked correctly in CMake. We make the libATen.so depend explicitly # on building the generated ATen files to workaround. add_custom_target(ATEN_CPU_FILES_GEN_TARGET DEPENDS - ${generated_headers} ${core_generated_headers} ${ops_generated_headers} - ${generated_sources} ${core_generated_sources} ${ops_generated_sources} + ${generated_headers} ${core_generated_headers} ${cpu_vec_generated_headers} ${ops_generated_headers} + ${generated_sources} ${core_generated_sources} ${cpu_vec_generated_sources} ${ops_generated_sources} ${generated_declarations_yaml}) add_custom_target(ATEN_CUDA_FILES_GEN_TARGET DEPENDS ${cuda_generated_headers} ${cuda_generated_sources}) @@ -260,12 +263,11 @@ if(INTERN_BUILD_ATEN_OPS) # The sources list might get reordered later based on the capabilites. # See NOTE [ Linking AVX and non-AVX files ] foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES}) - foreach(IMPL ${cpu_kernel_cpp_in}) - file(RELATIVE_PATH NAME "${PROJECT_SOURCE_DIR}/aten/src/ATen/" "${IMPL}") + function(process_vec NAME) list(GET CPU_CAPABILITY_NAMES ${i} CPU_CAPABILITY) set(NEW_IMPL ${CMAKE_BINARY_DIR}/aten/src/ATen/${NAME}.${CPU_CAPABILITY}.cpp) configure_file("${PROJECT_SOURCE_DIR}/cmake/IncludeSource.cpp.in" ${NEW_IMPL}) - set(cpu_kernel_cpp ${NEW_IMPL} ${cpu_kernel_cpp}) # Create list of copies + set(cpu_kernel_cpp ${NEW_IMPL} ${cpu_kernel_cpp} PARENT_SCOPE) # Create list of copies list(GET CPU_CAPABILITY_FLAGS ${i} FLAGS) if(MSVC) set(EXTRA_FLAGS "/DCPU_CAPABILITY=${CPU_CAPABILITY} /DCPU_CAPABILITY_${CPU_CAPABILITY}") @@ -284,6 +286,14 @@ if(INTERN_BUILD_ATEN_OPS) endif() endif() set_source_files_properties(${NEW_IMPL} PROPERTIES COMPILE_FLAGS "${FLAGS} ${EXTRA_FLAGS}") + endfunction() + foreach(IMPL ${cpu_kernel_cpp_in}) + file(RELATIVE_PATH NAME "${PROJECT_SOURCE_DIR}/aten/src/ATen/" "${IMPL}") + process_vec("${NAME}") + endforeach() + foreach(IMPL ${cpu_vec_generated_sources}) + file(RELATIVE_PATH NAME "${CMAKE_BINARY_DIR}/aten/src/ATen/" "${IMPL}") + process_vec("${NAME}") endforeach() endforeach() list(APPEND ATen_CPU_SRCS ${cpu_kernel_cpp}) diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 8813544588ac0..b1cae2b40f070 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -1,5 +1,16 @@ +# WARNING: the contents of this file must BOTH be valid Starlark (for Buck and + +# Bazel) as well as valid Python (for our cmake build). This means that +# load() directives are not allowed (as they are not recognized by Python). +# If you want to fix this, figure out how run this file from cmake with a proper +# Starlark interpreter as part of the default OSS build process. If you need +# some nontrivial Starlark features, make a separate bzl file (remember that + +# bzl files are not exported via ShipIt by default, so you may also need to +# update PyTorch's ShipIt config) + # In both open-source and fbcode builds, these are generated into -# torch/csrc/{autgrad,jit}/generated.i +# torch/csrc/{autograd,jit}/generated.i GENERATED_CPP = [ "autograd/generated/Functions.cpp", "autograd/generated/VariableType_0.cpp", @@ -1065,6 +1076,10 @@ aten_cpu_source_codegen_list = [ "aten/src/ATen/native/cpu/AdaptiveMaxPoolKernel.cpp", ] +aten_ufunc_headers = [ + "aten/src/ATen/native/ufunc/add.h", +] + # When building lite interpreter in OSS, "aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp" will go through # codegen process. The codegen version of this file, like Activation.cpp.DEFAULT.cpp, will be included # in ${cpu_kernel_cpp} in aten/src/ATen/CMakeLists.txt. As a result, in aten/src/ATen/CMakeLists.txt, diff --git a/tools/codegen/api/translate.py b/tools/codegen/api/translate.py index 591b8d75e3b10..8342e80a53659 100644 --- a/tools/codegen/api/translate.py +++ b/tools/codegen/api/translate.py @@ -5,7 +5,8 @@ memoryFormatT, tensorOptionsT, scalarTypeT, boolT, deviceT, layoutT, optionalTensorRefT, scalarT, optionalScalarRefT, - VectorCType, longT, intArrayRefT) + VectorCType, longT, intArrayRefT, + scalar_t, opmath_t) # This file implements a small program synthesis engine that implements # conversions between one API to another. @@ -92,9 +93,34 @@ def translate( # While we're at it, do some simple forward inference, looking through # constructors. + # + # NB: When should you do forward inference versus backward inference? + # The general idea: + # + # - Backward inference WHEN the goal gets smaller + # - Forward inference WHEN the hypothesis gets smaller + # + # This helps ensure termination: backward inference starts with a goal + # and tries to make it simpler and simpler until it's trivial; if the + # goal can grow in size, we blow up to a really huge goal size. + # Similarly, with forward inference we take hypotheses and decompose + # them into simpler hypotheses; if hypotheses could expand in size, + # we also have potential nontermination. (In the code below, forward + # inference is only ever carried out at a single step, but you could + # imagine repeated application of forward inference being profitable.) + # + # A good starting point in the literature for exploring more about proof + # search are these lecture notes + # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf + # # TODO: My kingdom for a pattern matcher # https://www.python.org/dev/peps/pep-0634/ - # TODO: This could get us in recomputation trouble if b.expr is nontrivial + # + # TODO: This could get us in recomputation trouble if b.expr is nontrivial. + # Fix this by implementing some sort of sharing so that if multiple + # goals share the same expression, we only compute it once. This seems + # to matter in practice as compiler is often unwilling to CSE nontrivial + # expressions like scalar.to() t = b.type if isinstance(t, ConstRefCType) and isinstance(t.elem, OptionalCType) and \ isinstance(t.elem.elem, BaseCType) and str(t.elem.elem.type) == 'at::Tensor': @@ -105,10 +131,16 @@ def translate( ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = \ f'(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())' + if t.type == ConstRefCType(BaseCType(scalarT)): + ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'({b.expr}).to()' + if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))): ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = \ f'({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())' + if t.type == BaseCType(scalar_t): + ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'static_cast({b.expr})' + # Add implicit bindings if the generated code is inside a Tensor method if method: ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = "const_cast(*this)" @@ -129,7 +161,8 @@ def unsat(goal: NamedCType) -> NoReturn: ''') # A shitty backtracking search implementation. It's shitty because it - # doesn't actually do backtracing or search. In particular, if + # does backtracking via stack (bad idea!) and for the most part tries to + # avoid backtracking. In particular, if # direct=True, we won't try to do any fancy synthesis, just trivial # conversions (e.g., "T a" is OK for "const T& a"). So all of the # existing rules in this function simply try to solve immediately, diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py index d269f2c7a3ff7..8a01b49bfb42f 100644 --- a/tools/codegen/api/types.py +++ b/tools/codegen/api/types.py @@ -1,6 +1,6 @@ from tools.codegen.model import (Argument, FunctionSchema, NativeFunction, - BackendIndex, - SelfArgument, TensorOptionsArguments, BaseTy) + BackendIndex, NativeFunctionsGroup, + SelfArgument, TensorOptionsArguments, BaseTy, ScalarType) from dataclasses import dataclass from typing import Optional, Union, Sequence, TypeVar, List, Set, Dict from enum import Enum @@ -68,6 +68,27 @@ def __str__(self) -> str: typeAndSizeT = BaseCppType('torch::autograd::generated', 'TypeAndSize') tensorGeometryT = BaseCppType('at', 'TensorGeometry') +# Types representing template parameters. Technically, we probably shouldn't +# represent them this way in codegen, but it was pretty convenient. +scalar_t = BaseCppType('', 'scalar_t') +opmath_t = BaseCppType('', 'opmath_t') + +ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = { + ScalarType.Byte: byteT, + ScalarType.Char: charT, + ScalarType.Short: shortT, + ScalarType.Int: int32T, + ScalarType.Long: longT, + ScalarType.Half: halfT, + ScalarType.Float: floatT, + ScalarType.Double: doubleT, + ScalarType.ComplexHalf: complexHalfT, + ScalarType.ComplexFloat: complexFloatT, + ScalarType.ComplexDouble: complexDoubleT, + ScalarType.Bool: boolT, + ScalarType.BFloat16: bfloat16T, +} + BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = { BaseTy.int: longT, BaseTy.float: doubleT, @@ -218,6 +239,23 @@ def cpp_type_registration_declarations(self) -> str: def remove_const_ref(self) -> 'CType': return TupleCType([e.remove_const_ref() for e in self.elems]) +@dataclass(frozen=True) +class VectorizedCType: + # This template is explicitly specialized, so the only valid + # elems are those we have specializations for (e.g., float, double, ...) + # scalar_t is also a common argument here (when we are codegen in + # a templated context) + elem: BaseCType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return f'at::vec::Vectorized<{self.elem.cpp_type()}>' + + def cpp_type_registration_declarations(self) -> str: + raise NotImplementedError + + def remove_const_ref(self) -> 'CType': + return self + CType = Union[ BaseCType, OptionalCType, @@ -227,7 +265,8 @@ def remove_const_ref(self) -> 'CType': ArrayRefCType, ArrayCType, VectorCType, - TupleCType + TupleCType, + VectorizedCType ] # A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus @@ -270,6 +309,14 @@ class Binding: # TODO: maybe don't represent default here default: Optional[str] = None + def rename(self, name: str) -> 'Binding': + return Binding( + name=name, + nctype=self.nctype, + argument=self.argument, + default=self.default, + ) + @property def type(self) -> str: return self.nctype.cpp_type() @@ -596,6 +643,19 @@ def from_func(f: NativeFunction, *, functional_op: NativeFunction, is_reverse: b return FunctionalizationLambda(f, functional_op, is_reverse) +@dataclass(frozen=True) +class StructuredImplSignature: + g: NativeFunctionsGroup + name: str + + def defn(self, name: Optional[str] = None) -> str: + args_str = ', '.join(a.defn() for a in self.arguments()) + return f"TORCH_IMPL_FUNC({self.name})({args_str})" + + def arguments(self) -> List[Binding]: + return structured.impl_arguments(self.g) + + # Helper functions def kernel_signature( @@ -615,4 +675,4 @@ def kernel_signature( return NativeSignature(f.func, prefix) # Functions only, no types -from tools.codegen.api import cpp, dispatcher, native, translate, functionalization +from tools.codegen.api import cpp, dispatcher, native, translate, functionalization, structured diff --git a/tools/codegen/api/ufunc.py b/tools/codegen/api/ufunc.py new file mode 100644 index 0000000000000..e6609e0b8888c --- /dev/null +++ b/tools/codegen/api/ufunc.py @@ -0,0 +1,176 @@ +from tools.codegen.model import (Argument, BaseTy, BaseType, FunctionSchema, + NativeFunctionsGroup, Type, DispatchKey) + +import tools.codegen.api.types as api_types +from tools.codegen.api.types import (ArgName, BaseCType, Binding, + ConstRefCType, NamedCType, + scalarT, CType, BaseCppType) + +from tools.codegen.api import cpp, structured + +from dataclasses import dataclass +from typing import List, Optional + +def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str: + assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas" + return f"ufunc_{func.name.name}_{dispatch_key}" + +def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str: + return schema_kernel_name(g.out.func, dispatch_key) + +# Tensors are omitted (as they are stored in TensorIterator), everything else is +# passed along (technically, we can pass tensors along too, it just wastes +# argument registers) +# +# NB: used for CPU only +def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]: + r = cpp.valuetype_type(t, binds=binds) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + elif t == BaseType(BaseTy.Tensor): + return None + else: + raise AssertionError(f"unrecognized type {repr(t)}") + +def opmath_type(scalar_t: BaseCppType) -> BaseCppType: + if scalar_t == api_types.scalar_t: + return api_types.opmath_t + raise NotImplementedError + +# NB: Tensors in constructor are stored in opmath_t, not scalar_t +# because Tensor in constructor = its a scalar tensor partially applied = +# it can be higher precision and we want to compute in that higher precision +# +# NB: CUDA only +def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + +# Only Tensors ever get passed directly to operator() +# +# NB: CUDA only +# (Actually, this works for CPU too) +def ufunctor_apply_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType: + if t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(scalar_t)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + +# The actual ufunc template function the user writes. Everything here +# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t +# in CPU +def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, compute_t) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, compute_t) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + +def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + +def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + +def ufunc_argument(a: Argument, compute_t: CType) -> Binding: + return Binding( + nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t), + name=a.name, + default=None, + argument=a, + ) + +@dataclass(frozen=True) +class UfunctorBindings: + ctor: List[Binding] + apply: List[Binding] + +# ufunctors are a CUDA-only concept representing functors that take some of +# their arguments on a host-side constructor, and the rest in the device-side +# apply. E.g., +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers +# to the operator() definition +def ufunctor_arguments( + g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType +) -> UfunctorBindings: + ctor = [] + apply = [] + for a in g.functional.func.arguments.flat_non_out: + if a.type.is_tensor_like(): + if scalar_tensor_idx == 0: + # put it in the ctor anyway + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + scalar_tensor_idx = None + else: + if scalar_tensor_idx is not None: + scalar_tensor_idx -= 1 + apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t)) + else: + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + assert scalar_tensor_idx is None + return UfunctorBindings(ctor=ctor, apply=apply) + +# ufuncs are the inner loop template functions that you wrote in ufunc/add.h +# which do the actual computation in question. E.g., +# +# template +# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ { +# return self + alpha * other; +# } +# +# In this file, we refer to T as compute_t which is bound by caller +def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]: + return [ufunc_argument(a, compute_t=compute_t) for a in g.functional.func.arguments.flat_non_out] + +# Stubs are the DispatchStub trampolines that CPU kernels use to get to their +# vectorized versions. E.g., +# +# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); +# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); +def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]: + # stubs drop all tensor arguments (they are implicit in the TensorIterator + # argument and keep everything else) + return [ + r + for a in g.out.func.arguments.flat_non_out + if not a.type.is_tensor_like() + for r in structured.argument(a) + ] diff --git a/tools/codegen/dest/__init__.py b/tools/codegen/dest/__init__.py index ce9265adf969a..d191b8361bae8 100644 --- a/tools/codegen/dest/__init__.py +++ b/tools/codegen/dest/__init__.py @@ -7,3 +7,8 @@ gen_registration_headers as gen_registration_headers, ) from .native_functions import compute_native_function_declaration as compute_native_function_declaration +from .ufunc import ( + compute_ufunc_cuda as compute_ufunc_cuda, + compute_ufunc_cpu as compute_ufunc_cpu, + compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel +) diff --git a/tools/codegen/dest/ufunc.py b/tools/codegen/dest/ufunc.py new file mode 100644 index 0000000000000..c8b92bd538e4b --- /dev/null +++ b/tools/codegen/dest/ufunc.py @@ -0,0 +1,477 @@ +from dataclasses import dataclass +from typing import Union, Optional, List, Tuple, Dict, Sequence +from tools.codegen.api.translate import translate +from tools.codegen.model import NativeFunctionsGroup, ScalarType, UfuncKey, DispatchKey, BaseType, BaseTy, Argument +import tools.codegen.api.ufunc as ufunc +from tools.codegen.api.ufunc import UfunctorBindings +from tools.codegen.api.types import ( + StructuredImplSignature, scalar_t, opmath_t, Binding, CType, + BaseCType, Expr, NamedCType, ScalarTypeToCppMapping, VectorizedCType +) +from tools.codegen.context import with_native_function + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CUDA STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# NB: not bothering to generate dispatch stub forward declaration in header, +# we can just paste it whereever necessary + +# TODO: use BackendIndex +# dispatch_key: DispatchKey # only CPU/CUDA right now + + +# Represents functors for implementing CUDA ufuncs. +# Functors are templated by scalar_t because when USERS instantiate functors +# they are templated. A functor looks something like this: +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) +# : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +@dataclass(frozen=True) +class UfunctorSignature: + g: NativeFunctionsGroup + scalar_tensor_idx: Optional[int] + name: str + + def arguments(self) -> UfunctorBindings: + return ufunc.ufunctor_arguments(self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t) + + def fields(self) -> List[Binding]: + # fields are renamed to have a trailing underscore, as is conventional + return [b.rename(f"{b.name}_") for b in self.arguments().ctor] + + def returns_type(self) -> CType: + # TODO: don't hardcode; return type will be inferred based on tags on + # the native function + return BaseCType(scalar_t) + + def decl_fields(self) -> str: + return "\n".join(f"{f.type} {f.name};" for f in self.fields()) + + def inline_defn_ctor(self) -> str: + args_str = ', '.join(a.decl() for a in self.arguments().ctor) + # NB: hypothetically could do this with translate but the + # transition here is very regular + init_str = ', '.join(f"{a.name}_({a.name})" for a in self.arguments().ctor) + return f"{self.name}({args_str}) : {init_str} {{}}" + + def decl_apply(self) -> str: + args_str = ', '.join(a.decl() for a in self.arguments().apply) + return f"{self.returns_type().cpp_type()} operator()({args_str}) const" + + +@dataclass(frozen=True) +class UfuncSignature: + g: NativeFunctionsGroup + name: str + compute_t: CType + + def arguments(self) -> List[Binding]: + return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t) + + def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str: + return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + +# steps: +# 1. take the functional signature +# 2. use api.ufunc to convert it to template signature. this establishes +# the type of the template function +# 3. use api.ufunc (II) to generate a split struct / operator() signature. +# this establish context in which we call the template signature +# +# StructuredImplSignature context +# ~> functor constructor sig +# +# Functor constructor context +# ~> functor fields sig +# +# Functor apply context (functor fields + functor apply sig) +# ~> template sig +# + +def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool: + num_tensors = sum(1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()) + return num_tensors == 2 + +def compute_ufunc_cuda_functors(g: NativeFunctionsGroup) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]: + # First, build the functors. + ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {} + ufunctors: List[str] = [] + loops = g.out.ufunc_inner_loop + scalar_tensor_idx_lookup = { + UfuncKey.CUDAFunctorOnSelf: 1, + UfuncKey.CUDAFunctorOnOther: 0, + UfuncKey.CUDAFunctor: None + } + if eligible_for_binary_scalar_specialization(g): + keys = [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther, UfuncKey.CUDAFunctor] + else: + keys = [UfuncKey.CUDAFunctor] + for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]: + assert k not in loops, f"cannot use {k} on non-binary function" + for k in keys: + # If the key was directly defined, skip functor codegen; we assume the + # user already done it for us + if k in loops: + ufunctor_sig = UfunctorSignature(g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name) + for dtype in loops[k].supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + continue + + # Note [ScalarOnly and Generic must match names for CUDA] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Otherwise, look in ANY of the generic entries. For simplicity of + # codegen, both ScalarOnly and Generic are defined, the ufunc name + # must match (if they didn't match, we'd have to generate distinct + # functors per dtype, which is awful, so we're not going to do it unless + # someone really forces us to) + ufunc_name = None + supported_dtypes = set() + for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]: + if lk not in loops: + continue + if ufunc_name is None: + ufunc_name = loops[lk].name + else: + # See Note [ScalarOnly and Generic must match names for CUDA] + assert ufunc_name == loops[lk].name, "ScalarOnly and Generic must have same ufunc name" + supported_dtypes |= loops[lk].supported_dtypes + assert ufunc_name is not None + + name = f"{k}_{ufunc_name}" + ufunctor_sig = UfunctorSignature(g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name) + for dtype in supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + + ufunc_sig = UfuncSignature(g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)) + apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply + ufunctors.append(f""" +template +struct {ufunctor_sig.name} {{ + using opmath_t = at::opmath_type; + {ufunctor_sig.decl_fields()} + {ufunctor_sig.inline_defn_ctor()} + __device__ {ufunctor_sig.decl_apply()} {{ + return {ufunc_sig.call(apply_ctx)}; + }} +}}; +""") + + return ufunctor_sigs, "\n".join(ufunctors) + +@dataclass(frozen=True) +class BinaryScalarSpecializationConfig: + scalar_idx: int + ctor_tensor: str + ufunc_key: UfuncKey + +BinaryScalarSpecializationConfigs = [ + BinaryScalarSpecializationConfig( + scalar_idx=0, + ctor_tensor='self', + ufunc_key=UfuncKey.CUDAFunctorOnOther, + ), + BinaryScalarSpecializationConfig( + scalar_idx=1, + ctor_tensor='other', + ufunc_key=UfuncKey.CUDAFunctorOnSelf, + ), +] + +def compute_ufunc_cuda_dtype_body( + g: NativeFunctionsGroup, dtype: ScalarType, + inner_loops: Dict[UfuncKey, UfunctorSignature], parent_ctx: Sequence[Binding] +) -> str: + body = "using opmath_t = at::opmath_type;" + body += "if (false) {}\n" # for ease of codegen + for config in BinaryScalarSpecializationConfigs: + if config.ufunc_key not in inner_loops: + continue + ufunctor_sig = inner_loops[config.ufunc_key] + scalar_idx = config.scalar_idx + 1 + # Make a copy and at the same time widen the type (not permissible + # without copy; we don't want to mutate the input argument anyway) + ctx: List[Union[Expr, Binding]] = list(parent_ctx) + ctx.append(Expr( + expr=f"iter.scalar_value({scalar_idx})", + type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)), + )) + ufunctor_ctor_exprs_str = ', '.join(a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)) + + # NB: ufunctor must be allocated before iter.remove_operand is called, + # as it relies on iter + body += f"""\ +else if (iter.is_cpu_scalar({scalar_idx})) {{ + {ufunctor_sig.name} ufunctor({ufunctor_ctor_exprs_str}); + iter.remove_operand({scalar_idx}); + gpu_kernel(iter, ufunctor); +}}""" + + ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor] + ufunctor_ctor_exprs_str = ', '.join(a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)) + body += f""" +else {{ + gpu_kernel(iter, {ufunctor_sig.name}({ufunctor_ctor_exprs_str})); +}} + """ + return body + +@with_native_function +def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str: + # First, build the functors, indexing them by dtype + ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g) + + # Next, build the conditionals + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA)) + dtype_cases = [] + for dtype, inner_ufunctor_sigs in ufunctor_sigs.items(): + dtype_cases.append(f""" +AT_PRIVATE_CASE_TYPE("{sig.name}", at::ScalarType::{dtype}, {ScalarTypeToCppMapping[dtype]}, + [&]() {{ + {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunctor_sigs, sig.arguments())} + }} +) +""") + + dtype_cases_str = "\n".join(dtype_cases) + + stub_sig = StubSignature(g) + + return f""" +{ufunctors} + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; + +{stub_sig.kernel_defn()} {{ + at::ScalarType st = iter.common_dtype(); + RECORD_KERNEL_FUNCTION_DTYPE("{sig.name}", st); + switch (st) {{ + {dtype_cases_str} + default: + TORCH_CHECK(false, "{sig.name}", " not implemented for '", toString(st), "'"); + }} +}} +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); + +{sig.defn()} {{ + {stub_sig.direct_call(sig.arguments())}; +}} +""" + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CPU STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +@dataclass(frozen=True) +class StubSignature: + g: NativeFunctionsGroup + + @property + def name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_stub" + + @property + def kernel_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_kernel" + + @property + def type_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_fn" + + def arguments(self) -> List[Binding]: + return ufunc.stub_arguments(self.g) + + def type(self) -> str: + cpp_args = self.arguments() + return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})" + + def dispatch_decl(self) -> str: + return f"DECLARE_DISPATCH({self.type_name}, {self.name})" + + def dispatch_defn(self) -> str: + return f"DEFINE_DISPATCH({self.name})" + + def kernel_defn(self) -> str: + return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})" + + def type_defn(self) -> str: + return f"using {self.type_name} = {self.type()}" + + # must be called from context where this is TensorIteratorBase* + def call(self, ctx: Sequence[Binding]) -> str: + return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + # used in CUDA to skip the unnecessary dynamic dispatch + def direct_call(self, ctx: Sequence[Binding]) -> str: + return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" + +@with_native_function +def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU)) + + return f""" +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; +{stub_sig.dispatch_defn()}; + +{sig.defn()} {{ + {stub_sig.call(sig.arguments())}; +}} +""" + +def compute_ufunc_cpu_dtype_body( + g: NativeFunctionsGroup, dtype: ScalarType, inner_loops: Dict[UfuncKey, UfuncSignature], + parent_ctx: Sequence[Binding] +) -> str: + assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}" + assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector} + scalar_loop = inner_loops[UfuncKey.CPUScalar] + vec_loop = None + if UfuncKey.CPUVector in inner_loops: + vec_loop = inner_loops[UfuncKey.CPUVector] + + # NB: We DON'T use translate here, because translate is + # incapable of CSE'ing the scalar accesses in case it is also + # used by Vectorized; also, the unpacking here is very simple + # and only affects Scalar; everything else is implicitly captured + # by the lambda + + # Setup scalar in scope + body = [] + ctx = [] + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar): + continue + body.append(f"auto _s_{b.name} = {b.name}.to();") + ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t)))) + if vec_loop is not None: + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar): + continue + body.append(f"auto _v_{b.name} = at::vec::Vectorized(_s_{b.name});") + ctx.append(Expr(f"_v_{b.name}", NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))))) + + # Setup lambda signature + # NB: simplified version of ufunctor_arguments + scalar_bindings = [] + vec_bindings = [] + for a in g.functional.func.arguments.flat_non_out: + if not a.type.is_tensor_like(): + continue + assert a.type == BaseType(BaseTy.Tensor) + scalar_bindings.append(Binding( + name=a.name, + nctype=NamedCType(a.name, BaseCType(scalar_t)), + argument=a, + )) + if vec_loop is not None: + vec_bindings.append(Binding( + name=a.name, + nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))), + argument=a, + )) + + def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]: + r: List[Union[Expr, Binding]] = [] + r.extend(ctx) + r.extend(b) + return r + + body_str = '\n'.join(body) + if vec_loop is not None: + return f""" +{body_str} +cpu_kernel_vec(iter, + [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}, + [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }} +); +""" + else: + return f""" +{body_str} +cpu_kernel(iter, + [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }} +); +""" + +@with_native_function +def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + + # Reindex the ufunc by dtypes; processing generic/scalaronly as well + loops = g.out.ufunc_inner_loop + ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {} + for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]: + lks = [] + # ORDER MATTERS: this specifies overriding precedence + if k in loops: # should happen rarely + lks.append(k) + if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar: + lks.append(UfuncKey.ScalarOnly) + if UfuncKey.Generic in loops: + lks.append(UfuncKey.Generic) + # TODO: don't hardcode ufunc:: namespace here, should be centralized smh + for lk in lks: + for dtype in loops[lk].supported_dtypes: + compute_t: CType + if k is UfuncKey.CPUScalar: + compute_t = BaseCType(scalar_t) + elif k is UfuncKey.CPUVector: + compute_t = VectorizedCType(BaseCType(scalar_t)) + else: + raise AssertionError() + inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {}) + if k not in inner_ufunc_sigs: + inner_ufunc_sigs[k] = UfuncSignature( + g, name=f"ufunc::{loops[lk].name}", + compute_t=compute_t + ) + + # Build the conditionals + dtype_cases = [] + for dtype, inner_ufunc_sigs in ufunc_sigs.items(): + dtype_cases.append(f""" +AT_PRIVATE_CASE_TYPE("{stub_sig.name}", at::ScalarType::{dtype}, {ScalarTypeToCppMapping[dtype]}, + [&]() {{ + {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())} + }} +) +""") + + dtype_cases_str = "\n".join(dtype_cases) + return f""" +namespace {{ + +{stub_sig.kernel_defn()} {{ + at::ScalarType st = iter.common_dtype(); + RECORD_KERNEL_FUNCTION_DTYPE("{stub_sig.name}", st); + switch (st) {{ + {dtype_cases_str} + default: + TORCH_CHECK(false, "{stub_sig.name}", " not implemented for '", toString(st), "'"); + }} +}} + +}} // anonymous namespace + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); +""" diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 846c02f1382f0..51ff8340095dd 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -16,6 +16,7 @@ TensorOptionsArguments, Type, Variant, is_cuda_dispatch_key, is_generic_dispatch_key, + is_ufunc_dispatch_key, Tag, BaseOperatorName) from tools.codegen.api.types import (Binding, CppSignature, CppSignatureGroup, DispatcherSignature, NativeSignature) @@ -111,40 +112,44 @@ def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def] # Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices. ParsedYaml = namedtuple('ParsedYaml', ['native_functions', 'backend_indices']) + +def parse_native_yaml_struct(es: object, path: str = "") -> ParsedYaml: + assert isinstance(es, list) + rs: List[NativeFunction] = [] + bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict) + for e in es: + assert isinstance(e.get('__line__'), int), e + loc = Location(path, e['__line__']) + funcs = e.get('func') + with context(lambda: f'in {loc}:\n {funcs}'): + func, m = NativeFunction.from_yaml(e, loc) + rs.append(func) + BackendIndex.grow_index(bs, m) + error_check_native_functions(rs) + # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet. + indices: Dict[DispatchKey, BackendIndex] = defaultdict(lambda: BackendIndex( + dispatch_key=DispatchKey.Undefined, + use_out_as_primary=True, + external=False, + device_guard=False, + index={})) + for k, v in bs.items(): + # All structured in-tree operators are implemented in terms of their out operator. + indices[k] = BackendIndex( + dispatch_key=k, + use_out_as_primary=True, + external=False, + # Only cuda-like devices in tree require device guards + device_guard=is_cuda_dispatch_key(k), + index=v) + return ParsedYaml(rs, indices) + def parse_native_yaml(path: str) -> ParsedYaml: global _GLOBAL_PARSE_NATIVE_YAML_CACHE if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE: with open(path, 'r') as f: es = yaml.load(f, Loader=LineLoader) - assert isinstance(es, list) - rs: List[NativeFunction] = [] - bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict) - for e in es: - assert isinstance(e.get('__line__'), int), e - loc = Location(path, e['__line__']) - funcs = e.get('func') - with context(lambda: f'in {loc}:\n {funcs}'): - func, m = NativeFunction.from_yaml(e, loc) - rs.append(func) - BackendIndex.grow_index(bs, m) - error_check_native_functions(rs) - # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet. - indices: Dict[DispatchKey, BackendIndex] = defaultdict(lambda: BackendIndex( - dispatch_key=DispatchKey.Undefined, - use_out_as_primary=True, - external=False, - device_guard=False, - index={})) - for k, v in bs.items(): - # All structured in-tree operators are implemented in terms of their out operator. - indices[k] = BackendIndex( - dispatch_key=k, - use_out_as_primary=True, - external=False, - # Only cuda-like devices in tree require device guards - device_guard=is_cuda_dispatch_key(k), - index=v) - _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = ParsedYaml(rs, indices) + _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(es, path=path) return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] @@ -1012,6 +1017,7 @@ def gen_aggregated_headers( *, native_functions: Sequence[NativeFunction], grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + structured_native_functions: Sequence[NativeFunctionsGroup], static_dispatch_idx: Optional[BackendIndex], selector: SelectiveBuilder, backend_indices: Dict[DispatchKey, BackendIndex], @@ -1023,8 +1029,6 @@ def gen_aggregated_headers( ) -> None: # Buck doesn't support dynamic output files, so we aggregate all operator # headers into a single file - structured_native_functions = [g for g in grouped_native_functions - if isinstance(g, NativeFunctionsGroup)] cpu_fm.write('NativeMetaFunctions.h', lambda: { 'NativeMetaFunctions_includes': [], 'NativeMetaFunctions_declarations': list( @@ -1242,6 +1246,7 @@ def gen_headers( *, native_functions: Sequence[NativeFunction], grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + structured_native_functions: Sequence[NativeFunctionsGroup], static_dispatch_idx: Optional[BackendIndex], selector: SelectiveBuilder, backend_indices: Dict[DispatchKey, BackendIndex], @@ -1272,6 +1277,7 @@ def gen_headers( gen_aggregated_headers( native_functions=native_functions, grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, static_dispatch_idx=static_dispatch_idx, selector=selector, backend_indices=backend_indices, @@ -1343,11 +1349,13 @@ def gen_source_files( *, native_functions: Sequence[NativeFunction], grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + structured_native_functions: Sequence[NativeFunctionsGroup], static_dispatch_idx: Optional[BackendIndex], selector: SelectiveBuilder, backend_indices: Dict[DispatchKey, BackendIndex], core_fm: FileManager, cpu_fm: FileManager, + cpu_vec_fm: FileManager, cuda_fm: FileManager, dispatch_keys: Sequence[DispatchKey], functions_keys: Set[DispatchKey], @@ -1373,19 +1381,30 @@ def gen_source_files( if per_operator_headers: def operator_headers() -> List[str]: headers = [] - for fn in native_functions: - is_registered = backend_index.has_kernel(fn) or ( - fn.structured and dispatch_key in - (DispatchKey.Meta, DispatchKey.CompositeExplicitAutograd)) + for g in grouped_native_functions: + is_registered = False + if backend_index.has_kernel(g): + is_registered = True + # The above has_kernel test on a group will only test for + # the existence of out dispatch, because that's how + # structured kernels work. But sometimes functions can be + # grouped but not be structured, and then you need to check + # each individual piece, as they may have manual dispatch + # entries. + elif isinstance(g, NativeFunctionsGroup) and any(backend_index.has_kernel(fn) for fn in g.functions()): + is_registered = True + # TODO: this condition is a bit questionable + elif g.structured and dispatch_key in (DispatchKey.Meta, DispatchKey.CompositeExplicitAutograd): + is_registered = True if not is_registered: continue - headers.append(f"#include ") + headers.append(f"#include ") if dispatch_key == DispatchKey.CompositeExplicitAutograd: - headers.append(f"#include ") + headers.append(f"#include ") if dispatch_key in functions_keys: headers.append( - f"#include ") + f"#include ") return sorted(set(headers)) else: @@ -1439,6 +1458,39 @@ def operator_headers() -> List[str]: )), }) + for g in structured_native_functions: + if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key): + continue + name = g.functional.func.name.name + if dispatch_key is DispatchKey.CPU: + assert fm is cpu_fm + fm.write_with_template(f'UfuncCPU_{name}.cpp', 'UfuncCPU.cpp', lambda: { + 'meta_declaration': compute_meta_function_declaration(g), + 'native_declaration': + dest.compute_native_function_declaration(g, backend_indices[dispatch_key]), + 'native_definitions': dest.compute_ufunc_cpu(g), + }) + cpu_vec_fm.write_with_template(f'UfuncCPUKernel_{name}.cpp', 'UfuncCPUKernel.cpp', lambda: { + 'name': name, + 'native_definitions': dest.compute_ufunc_cpu_kernel(g), + }) + elif dispatch_key is DispatchKey.CUDA: + cuda_headers = "#include " + if rocm: + cuda_headers = "#include " + fm.write_with_template(f'UfuncCUDA_{name}.cu', 'UfuncCUDA.cu', lambda: { + 'name': name, + 'cuda_headers': cuda_headers, + 'meta_declaration': compute_meta_function_declaration(g), + 'native_declaration': + dest.compute_native_function_declaration(g, backend_indices[dispatch_key]), + 'native_definitions': dest.compute_ufunc_cuda(g), + }) + else: + raise AssertionError(f'unrecognized {dispatch_key} for ufunc') + + del fm + # BackendSelect is generated specially def gen_backend_select() -> Dict[str, List[str]]: relevant_fns = [fn for fn in native_functions if needs_backend_select(fn, selector)] @@ -1601,6 +1653,8 @@ def main() -> None: parsed_yaml = parse_native_yaml(native_yaml_path) native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices grouped_native_functions = get_grouped_native_functions(native_functions) + structured_native_functions = [g for g in grouped_native_functions + if isinstance(g, NativeFunctionsGroup)] template_dir = os.path.join(options.source_path, "templates") @@ -1620,10 +1674,15 @@ def main() -> None: pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True) def make_file_manager(install_dir: str) -> FileManager: - return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run) + return FileManager( + install_dir=install_dir, + template_dir=template_dir, + dry_run=options.dry_run + ) core_fm = make_file_manager(core_install_dir) cpu_fm = make_file_manager(options.install_dir) + cpu_vec_fm = make_file_manager(options.install_dir) cuda_fm = make_file_manager(options.install_dir) ops_fm = make_file_manager(ops_install_dir) @@ -1661,11 +1720,13 @@ def make_file_manager(install_dir: str) -> FileManager: gen_source_files( native_functions=native_functions, grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, static_dispatch_idx=static_dispatch_idx, selector=selector, backend_indices=backend_indices, core_fm=core_fm, cpu_fm=cpu_fm, + cpu_vec_fm=cpu_vec_fm, cuda_fm=cuda_fm, dispatch_keys=dispatch_keys, functions_keys=functions_keys, @@ -1678,6 +1739,7 @@ def make_file_manager(install_dir: str) -> FileManager: gen_headers( native_functions=native_functions, grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, static_dispatch_idx=static_dispatch_idx, selector=selector, backend_indices=backend_indices, @@ -1703,6 +1765,7 @@ def make_file_manager(install_dir: str) -> FileManager: for fm, prefix in [ (cpu_fm, ""), + (cpu_vec_fm, "cpu_vec_"), (core_fm, "core_"), (cuda_fm, "cuda_"), (ops_fm, "ops_"), diff --git a/tools/codegen/model.py b/tools/codegen/model.py index fab6ba3affc66..3d92d92dfe010 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -166,6 +166,97 @@ def is_cuda_dispatch_key(dk: DispatchKey) -> bool: def is_structured_dispatch_key(dk: DispatchKey) -> bool: return dk in STRUCTURED_DISPATCH_KEYS +def is_ufunc_dispatch_key(dk: DispatchKey) -> bool: + # For now, ufunc dispatch keys coincide with structured keys + return dk in STRUCTURED_DISPATCH_KEYS + +# This is oddly named ScalarType and not DType for symmetry with C++ +class ScalarType(Enum): + Byte = auto() + Char = auto() + Short = auto() + Int = auto() + Long = auto() + Half = auto() + Float = auto() + Double = auto() + ComplexHalf = auto() + ComplexFloat = auto() + ComplexDouble = auto() + Bool = auto() + BFloat16 = auto() + + def __str__(self) -> str: + return self.name + + @staticmethod + def maybe_parse(value: str) -> Optional['ScalarType']: + for k, v in ScalarType.__members__.items(): + if k == value: + return v + return None + + @staticmethod + def parse(value: str) -> 'ScalarType': + mb_r = ScalarType.maybe_parse(value) + assert mb_r is not None, f'unknown dtype {value}' + return mb_r + + @staticmethod + def parse_set(values: str) -> Set['ScalarType']: + dtypes: Set[ScalarType] = set() + for value in values.split(', '): + if value in DTYPE_CLASSES: + dtypes.update(DTYPE_CLASSES[value]) + else: + dtypes.add(ScalarType.parse(value)) + return dtypes + + +DTYPE_CLASSES: Dict[str, Set[ScalarType]] = {} +# NB: Integral doesn't include boolean +DTYPE_CLASSES["Integral"] = { + ScalarType.Byte, ScalarType.Char, ScalarType.Int, ScalarType.Long, + ScalarType.Short +} +# NB: Floating doesn't include low precision types +DTYPE_CLASSES["Floating"] = {ScalarType.Float, ScalarType.Double} +DTYPE_CLASSES["Complex"] = {ScalarType.ComplexFloat, ScalarType.ComplexDouble} +DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"] +DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"] +DTYPE_CLASSES["FloatingAndComplex"] = DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"] + + +# Represents the valid entries for ufunc_inner_loop in native_functions.yaml. +# NB: if you add a new UfuncKey, you will teach tools.codegen.dest.ufunc how +# to process it. Most logic will ignore keys they don't understand, so your +# new key will get silently ignored until you hook in logic to deal with it. +class UfuncKey(Enum): + # These are low level keys that represent exactly one particular + # instantiation of the kernel produced by codegen + CUDAFunctor = auto() + CUDAFunctorOnOther = auto() + CUDAFunctorOnSelf = auto() + + CPUScalar = auto() + CPUVector = auto() + + # These are the ones users will usually specify, and + # implicitly "fill in" the low level keys + ScalarOnly = auto() # CUDA*, CPUScalar + Generic = auto() # CUDA*, CPU* + + def __str__(self) -> str: + return self.name + + @staticmethod + def parse(value: str) -> 'UfuncKey': + for k, v in UfuncKey.__members__.items(): + if k == value: + return v + raise AssertionError(f'unknown ufunc key {value}') + + class DeviceCheckType(Enum): NoCheck = 0 ExactSame = 1 @@ -239,6 +330,10 @@ class NativeFunction: # defined. This is for conveniently reporting error messages! loc: 'Location' + # If non-empty, this kernel is subject to ufunc codegen. + # Sorted by ufunc_key + ufunc_inner_loop: Dict[UfuncKey, 'UfuncInnerLoop'] + # Whether or not this out functions is a "structured kernel". Structured # kernels are defined a little differently from normal kernels; in # particular, their shape checking logic is defined separately from @@ -413,6 +508,31 @@ def from_yaml( "strictly subsumes the other. If you wanted to provide an explicit autograd " \ "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only" + raw_ufunc_inner_loop = e.pop('ufunc_inner_loop', {}) + ufunc_inner_loop = {} + if isinstance(raw_ufunc_inner_loop, str): + ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(raw_ufunc_inner_loop, UfuncKey.Generic) + elif isinstance(raw_ufunc_inner_loop, dict): + for k, vo in raw_ufunc_inner_loop.items(): + if k == '__line__': + continue + assert isinstance(k, str), f'ufunc_inner_loop key is not a str: {k}' + assert isinstance(vo, str), f'ufunc_inner_loop value is not a str: {v}' + ufunc_key = UfuncKey.parse(k) + ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key) + else: + raise AssertionError(f'ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}') + # Program the BackendIndex for the implicit dispatch entry from ufunc + if ufunc_inner_loop: + assert structured, "ufunc must be structured" + for dispatch_key in STRUCTURED_DISPATCH_KEYS: + assert dispatch_key not in dispatch, \ + f"ufunc should not have explicit dispatch entry for {dispatch_key}" + dispatch[dispatch_key] = BackendMetadata( + kernel=ufunc.schema_kernel_name(func, dispatch_key), + structured=True + ) + if structured_delegate: # Structured functions MUST have a dispatch table is_abstract = True @@ -448,6 +568,7 @@ def from_yaml( structured_delegate=structured_delegate, structured_inherits=structured_inherits, precomputed=precomputed, + ufunc_inner_loop=ufunc_inner_loop, manual_kernel_registration=manual_kernel_registration, manual_cpp_binding=manual_cpp_binding, python_module=python_module, @@ -666,7 +787,24 @@ class BackendMetadata: # in native_functions.yaml. # However, external backends like XLA can indendently toggle which ops are structured. structured: bool - # + +@dataclass(frozen=True) +class UfuncInnerLoop: + name: str + supported_dtypes: Set[ScalarType] + # key is stored here because it affects the semantics of name, + # so its helpful to have them together for further processing + ufunc_key: UfuncKey + + @staticmethod + def parse(value: str, ufunc_key: UfuncKey) -> 'UfuncInnerLoop': + name, supported_dtypes_str = value.split(' ', 1) + assert supported_dtypes_str[0] == '(' + assert supported_dtypes_str[-1] == ')' + supported_dtypes = set() + for k in supported_dtypes_str[1:-1].split(', '): + supported_dtypes |= ScalarType.parse_set(k) + return UfuncInnerLoop(name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key) # BackendIndex represents a backend. @@ -1664,3 +1802,5 @@ def to_list(self) -> List[str]: replace_list.append(f'{kernel_param} -> {replacements}') return replace_list + +import tools.codegen.api.ufunc as ufunc diff --git a/tools/test/test_codegen_model.py b/tools/test/test_codegen_model.py new file mode 100644 index 0000000000000..50ea59575bedd --- /dev/null +++ b/tools/test/test_codegen_model.py @@ -0,0 +1,124 @@ +# Owner(s): ["module: codegen"] + +import expecttest +import unittest +import yaml +import textwrap + +from tools.codegen.model import NativeFunctionsGroup, DispatchKey +import tools.codegen.dest as dest +import tools.codegen.gen as gen +from tools.codegen.gen import LineLoader, parse_native_yaml_struct + +class TestCodegenModel(expecttest.TestCase): + def assertParseErrorInline(self, yaml_str: str, expect: str) -> None: + es = yaml.load(yaml_str, Loader=LineLoader) + try: + parse_native_yaml_struct(es) + except AssertionError as e: + # hack to strip out the context + msg, _ = str(e).split(' in ', 2) + self.assertExpectedInline('\n'.join(textwrap.wrap(msg)), expect, skip=1) + return + self.fail(msg="Did not raise when expected to") + + def assertUfuncErrorInline(self, yaml_str: str, expect: str) -> None: + # parse a single structured group out of the yaml to g + es = yaml.load(yaml_str, Loader=LineLoader) + parsed_yaml = parse_native_yaml_struct(es) + native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices + grouped_native_functions = gen.get_grouped_native_functions(native_functions) + assert len(grouped_native_functions) == 1 + g = grouped_native_functions[0] + assert isinstance(g, NativeFunctionsGroup) + assert g.out.ufunc_inner_loop + # this is not ufunc codegen per se, but it does some basic sanity tests for + # ufunc generation + gen.compute_meta_function_declaration(g) + dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CPU]) + dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CUDA]) + try: + # the real kahuna + dest.compute_ufunc_cpu(g) + dest.compute_ufunc_cpu_kernel(g) + dest.compute_ufunc_cuda(g) + except AssertionError as e: + # hack to strip out the context + msg, _ = str(e).split(' in ', 2) + self.assertExpectedInline('\n'.join(textwrap.wrap(msg)), expect, skip=1) + return + self.fail(msg="Did not raise when expected to") + + # NB: indent is hardcoded to be two here, so format your yaml accordingly + binop_out = 'func: binop.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)' + ti_binop_out = f'''{binop_out} + structured: True + structured_inherits: TensorIteratorBase''' + ti_binop = '''func: binop(Tensor self, Tensor other) -> Tensor + structured_delegate: binop.out +''' + + ti_unop_out = '''func: unop.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase''' + ti_unop = '''func: unop(Tensor self) -> Tensor + structured_delegate: unop.out +''' + + def test_nonstructured_ufunc(self) -> None: + yaml_str = f'''\ +- {self.binop_out} + ufunc_inner_loop: + Generic: binop (Bool) +''' + self.assertParseErrorInline(yaml_str, '''\ +ufunc must be structured''') + + def test_overlapping_ufunc_and_dispatch(self) -> None: + yaml_str = f'''\ +- {self.ti_binop_out} + ufunc_inner_loop: + Generic: binop (Bool) + dispatch: + CPU: binop_cpu +''' + self.assertParseErrorInline(yaml_str, '''\ +ufunc should not have explicit dispatch entry for CPU''') + + # See https://github.com/pytorch/pytorch/pull/65851#discussion_r810238456 + @unittest.expectedFailure + def test_scalaronly_shadowed(self) -> None: + yaml_str = f'''\ +- {self.ti_binop_out} + ufunc_inner_loop: + Generic: binop (Bool) + ScalarOnly: binop (Bool) +''' + self.assertParseErrorInline(yaml_str, '''\ +''') + + def test_conflicting_ufunc(self) -> None: + yaml_str = f'''\ +- {self.ti_binop_out} + ufunc_inner_loop: + Generic: binop (Bool) + ScalarOnly: binop_scalar (Bool) +- {self.ti_binop} +''' + self.assertUfuncErrorInline(yaml_str, '''\ +ScalarOnly and Generic must have same ufunc name''') + + def test_invalid_cudafunctoronself_for_binary_op(self) -> None: + yaml_str = f'''\ +- {self.ti_unop_out} + ufunc_inner_loop: + Generic: unop (All) + CUDAFunctorOnSelf: unop_self_cuda (All) +- {self.ti_unop} +''' + self.assertUfuncErrorInline(yaml_str, '''\ +cannot use CUDAFunctorOnSelf on non-binary function''') + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/ufunc_defs.bzl b/tools/ufunc_defs.bzl new file mode 100644 index 0000000000000..4490f05be0151 --- /dev/null +++ b/tools/ufunc_defs.bzl @@ -0,0 +1,25 @@ +load("@bazel_skylib//lib:paths.bzl", "paths") +load(":build_variables.bzl", "aten_ufunc_headers") + +aten_ufunc_names = [ + paths.split_extension(paths.basename(h))[0] + for h in aten_ufunc_headers +] + +def aten_ufunc_generated_cpu_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncCPU_{}.cpp".format(n) + for n in aten_ufunc_names + ]] + +def aten_ufunc_generated_cpu_kernel_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncCPUKernel_{}.cpp".format(n) + for n in aten_ufunc_names + ]] + +def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncCUDA_{}.cu".format(n) + for n in aten_ufunc_names + ]]