diff --git a/aten/src/ATen/TensorSubclassLikeUtils.h b/aten/src/ATen/TensorSubclassLikeUtils.h index 7f5517bc08114a..e9f5e7d26e112c 100644 --- a/aten/src/ATen/TensorSubclassLikeUtils.h +++ b/aten/src/ATen/TensorSubclassLikeUtils.h @@ -28,8 +28,7 @@ constexpr auto kFunctorchWrappedTensors = DispatchKeySet({ constexpr auto kTensorSubclassLike = kFunctorchWrappedTensors | DispatchKeySet({ DispatchKey::Batched, - DispatchKey::SparseCPU, - DispatchKey::SparseCUDA, + DispatchKey::Sparse, DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCUDA, DispatchKey::Meta, diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp index a930edc2db6328..9180d0d19e6449 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp @@ -6,11 +6,52 @@ namespace c10 { void DispatchKeyExtractor::setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough) { + // (1) update nonFallthroughKeys_ if (has_fallthrough) { nonFallthroughKeys_ = nonFallthroughKeys_.remove(k); } else { nonFallthroughKeys_ = nonFallthroughKeys_.add(k); } + // (2) update nonFallthroughKeysPerBackend_ + if (isPerBackendFunctionalityKey(toFunctionalityKey(k))) { + // This is a per-backend functionality key. + // We need to figure out what the current backend is, + // and only update the bitset for that backend. + // subtracting 1 because the first backend should have index 0 (CPU), + // But the enum starts with BackendComponent::InvalidBit. + auto backend_idx = static_cast(toBackendComponent(k)) - 1; + TORCH_INTERNAL_ASSERT(backend_idx >= 0 && static_cast(backend_idx) < nonFallthroughKeysPerBackend_.size()); + if (has_fallthrough) { + nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].remove(k); + } else { + nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].add(k); + } + + // Set requiresBitsetPerBackend_ accordingly + for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size() - 1)) { + if (nonFallthroughKeysPerBackend_[i] != nonFallthroughKeysPerBackend_[i+1]) { + requiresBitsetPerBackend_ = true; + return; + } + } + requiresBitsetPerBackend_ = false; + return; + } else { + // Otherwise, if a fallthrough is set for a functionality that isn't per backend, + // Then we update the fallthrough bitset for EVERY backend. 
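+    // (e.g. a fallthrough registered to DispatchKey::ADInplaceOrView, which is
+    // not a per-backend functionality, is applied to every backend's bitset)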
+ // TODO: we could probably optimize this by only lazily updating these values + // the first time that we see requiresBitsetPerBackend_ = true + // (which should almost never happen) + if (has_fallthrough) { + for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { + nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].remove(k); + } + } else { + for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { + nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].add(k); + } + } + } } std::string DispatchKeyExtractor::dumpState() const { diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index 53e348d6b99ea9..d5345b28e7149f 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -156,14 +156,24 @@ struct TORCH_API DispatchKeyExtractor final { } }); // Keys that are fallthrough should be skipped - return impl::computeDispatchKeySet(ks, nonFallthroughKeys_); + if (requiresBitsetPerBackend_) { + auto backend_idx = ks.getBackendIndex(); + return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]); + } else { + return impl::computeDispatchKeySet(ks, nonFallthroughKeys_); + } } template DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const { auto ks = detail::multi_dispatch_key_set(args...); // Keys that are fallthrough should be skipped - return impl::computeDispatchKeySet(ks, nonFallthroughKeys_); + if (requiresBitsetPerBackend_) { + auto backend_idx = ks.getBackendIndex(); + return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]); + } else { + return impl::computeDispatchKeySet(ks, nonFallthroughKeys_); + } } void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough); @@ -193,7 +203,12 @@ struct TORCH_API DispatchKeyExtractor final { explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse) : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse) - , nonFallthroughKeys_(DispatchKeySet::FULL) {} + , nonFallthroughKeys_(DispatchKeySet::FULL) + , requiresBitsetPerBackend_(false) { + for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { + nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL; + } + } // this is a bitset that has ones for each argument index which has to be // considered for dispatch. This avoids having to iterate over the stack @@ -205,8 +220,14 @@ struct TORCH_API DispatchKeyExtractor final { // fallthrough c10::utils::bitset dispatch_arg_indices_reverse_; - // Set of keys for which the operator does NOT have fallthrough kernel. + // Set of functionality keys for which the operator does NOT have fallthrough kernel. DispatchKeySet nonFallthroughKeys_; + // Set of functionality keys for which the operator does NOT have fallthrough kernel, defined PER BACKEND. + // This is only needed if we know that the operator has a different set of fallthroughs defined for some backends. 
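+  // The array is indexed by backend, using the same (BackendComponent value - 1)
+  // index computed in setOperatorHasFallthroughForKey(): CPUBit -> 0, CUDABit -> 1, etc.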
+ std::array nonFallthroughKeysPerBackend_; + // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path), + // or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_ + bool requiresBitsetPerBackend_; }; } diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 3dccc4645a824c..f2426f6bb1f1a8 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -267,14 +267,15 @@ void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name) RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) { std::lock_guard lock(mutex_); + auto idx = getDispatchTableIndexForDispatchKey(dispatchKey); TORCH_CHECK( - !backendFallbackKernels_[static_cast(dispatchKey)].kernel.isValid(), + !backendFallbackKernels_[idx].kernel.isValid(), "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ", - backendFallbackKernels_[static_cast(dispatchKey)].debug, ", new registration ", debug + backendFallbackKernels_[idx].debug, ", new registration ", debug ); // NB: inferred function schema is always nullptr for fallbacks, as fallbacks // cannot be unobxed - backendFallbackKernels_[static_cast(dispatchKey)] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug)); + backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug)); for (auto& op : operators_) { op.op.updateFallback(*this, dispatchKey); @@ -288,7 +289,8 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) { std::lock_guard lock(mutex_); - backendFallbackKernels_[static_cast(dispatchKey)] = {}; + auto idx = getDispatchTableIndexForDispatchKey(dispatchKey); + backendFallbackKernels_[idx] = {}; for (auto& op : operators_) { op.op.updateFallback(*this, dispatchKey); diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 14ffa2f94c9c8c..8108c3c1928b81 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -291,7 +291,7 @@ class TORCH_API Dispatcher final { // Map from namespace to debug string (saying, e.g., where the library was defined) ska::flat_hash_map libraries_; - std::array(DispatchKey::NumDispatchKeys)> backendFallbackKernels_; + std::array backendFallbackKernels_; std::unique_ptr listeners_; std::mutex mutex_; @@ -531,8 +531,7 @@ C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorH detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor() .template getDispatchKeySetUnboxed(args...); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId())); - const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId()); + const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING // By default, when there're no high-frequency or non-sampled callbacks, // RecordFunction is pre-sampled as a perf optimization; @@ -553,7 +552,7 @@ template inline Return Dispatcher::redispatch(const TypedOperatorHandle& op, DispatchKeySet currentDispatchKeySet, Args... 
args) const { detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 // do not use RecordFunction on redispatch - const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet.highestPriorityTypeId()); + const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet); return kernel.template call(op, currentDispatchKeySet, std::forward(args)...); } @@ -561,7 +560,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const // note: this doesn't need the mutex because write operations on the list keep iterators intact. const auto& entry = op.operatorDef_->op; auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); - const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId()); + const auto& kernel = entry.lookup(dispatchKeySet); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING bool pre_sampled = false; if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) { @@ -593,7 +592,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const { // note: this doesn't need the mutex because write operations on the list keep iterators intact. const auto& entry = op.operatorDef_->op; - const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId()); + const auto& kernel = entry.lookup(dispatchKeySet); return kernel.callBoxed(op, dispatchKeySet, stack); } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index d4d997fde69aef..ce339abb05d9f9 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -283,7 +283,8 @@ std::pair OperatorEntry::computeDispatchTab } // 3. Backend fallback - auto dispatch_ix = static_cast(dispatch_key); + auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key); + TORCH_INTERNAL_ASSERT(dispatch_ix != -1); if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) { return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"}; } @@ -299,7 +300,7 @@ std::pair OperatorEntry::computeDispatchTab // or alias keys and their associated keysets). // This function should be considered a private helper for updateDispatchTable_() void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) { - const auto dispatch_ix = c10::getDispatchTableIndexForDispatchKey(dispatch_key); + const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key); if (C10_UNLIKELY(dispatch_ix == -1)) { return; } @@ -329,8 +330,12 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp } // Note [Refresh Runtime Autograd entries in dispatchTable_] // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3). + // In theory, we should only have to check if the given runtime key has "dense" functionality, + // e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit). + // However, there are some backends that should be included in this set that don't have the dense key set. + // E.g. DispatchKey::Meta, DispatchKey::ORT. 
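+  // For example, registering a kernel to DispatchKey::CPU also refreshes the
+  // computed entry for DispatchKey::AutogradCPU below, since
+  // getAutogradKeyFromBackend(toBackendComponent(DispatchKey::CPU)) is AutogradCPU.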
if (c10::isBackendDispatchKey(dispatch_key)) { - DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key); + DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key)); updateDispatchTableEntry_(dispatcher, autograd_key); } } @@ -357,8 +362,9 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd // or CompositeImplicitAutograd alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd) // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet. - for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { - updateDispatchTable_(dispatcher, static_cast(iter)); + updateDispatchTable_(dispatcher, DispatchKey::Undefined); + for (auto k : DispatchKeySet(DispatchKeySet::FULL)) { + updateDispatchTable_(dispatcher, k); } } @@ -371,9 +377,13 @@ void OperatorEntry::checkInvariants() const { for (const auto& kv : kernels_) { TORCH_INTERNAL_ASSERT(kv.second.size() > 0, dumpState()); } - for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { - auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), static_cast(iter)); - TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[iter]), + for (auto k : DispatchKeySet(DispatchKeySet::FULL)) { + auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), k); + auto idx = getDispatchTableIndexForDispatchKey(k); + if (C10_UNLIKELY(idx == -1)) { + continue; + } + TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[idx]), "Canonical state\n~~~~~~~~~~~\n", dumpState(), "\n\n" "Computed table:\n~~~~~~~~~~~\n", dumpComputedTable()); } @@ -384,8 +394,9 @@ std::string OperatorEntry::listAllDispatchKeys() const { str << "["; bool has_kernels = false; - for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { - if (!dispatchTable_[iter].isValid()) { + for (auto k : DispatchKeySet(DispatchKeySet::FULL)) { + auto iter = getDispatchTableIndexForDispatchKey(k); + if (iter == -1 || !dispatchTable_[iter].isValid()) { continue; } if (has_kernels) { @@ -443,8 +454,12 @@ void OperatorEntry::reportError(DispatchKey dispatchKey) const { // updateDispatchTableFull_ would update the dispatch table to be) std::string OperatorEntry::dumpComputedTable() const { std::ostringstream oss; - for (uint8_t i = 0; i < static_cast(DispatchKey::NumDispatchKeys); i++) { - auto k = static_cast(i); + // Need to handle Undefined separately, because its a runtime key that can't be represented + // in a DispatchKeySet. 
+ std::vector runtime_keys = {DispatchKey::Undefined}; + for (auto k : DispatchKeySet(DispatchKeySet::FULL)) runtime_keys.push_back(k); + + for (auto k : runtime_keys) { auto kernel_prov = computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k); if (kernel_prov.first.kernel.isValid()) { oss << toString(k) << ": " diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index d98bd6bc69041a..c0f90808280a8e 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -173,10 +173,10 @@ class TORCH_API OperatorEntry final { [[noreturn]] void reportError(DispatchKey dispatchKey) const; - const KernelFunction& lookup(DispatchKey k) const { - const auto idx = getDispatchTableIndexForDispatchKey(k); + const KernelFunction& lookup(DispatchKeySet ks) const { + const auto idx = ks.getDispatchTableIndexForDispatchKeySet(); if (C10_UNLIKELY(idx == -1)) { - reportError(k); + reportError(ks.highestPriorityTypeId()); } const auto& kernel = dispatchTable_[idx]; // A valid kernel *always* has a boxed kernel and *may* have an @@ -187,7 +187,7 @@ class TORCH_API OperatorEntry final { // in the common case. if (C10_UNLIKELY(!kernel.isValidUnboxed())) { if (!kernel.isValid()) { - reportError(k); + reportError(ks.highestPriorityTypeId()); } } return kernel; @@ -211,7 +211,7 @@ class TORCH_API OperatorEntry final { OperatorName name_; c10::optional schema_; - std::array dispatchTable_; + std::array dispatchTable_; DispatchKeyExtractor dispatchKeyExtractor_; // kernels_ stores all registered kernels for the corresponding dispatch key diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index e9d3b1bddaaaf5..05294c25548eb1 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -592,7 +592,7 @@ TEST(OperatorRegistrationTest, AutogradBackendOverridesAutogradKernel) { void LazyBackendsAutogradOverridesAutogradKernel(DispatchKey key) { auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel(c10::getAutogradKeyFromBackend(key)) + .kernel(c10::getAutogradKeyFromBackend(toBackendComponent(key))) .kernel(DispatchKey::Autograd)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); @@ -1791,22 +1791,22 @@ TEST(NewOperatorRegistrationTest, dispatchAutogradPrecedence) { TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool sparsecpu_called, math_called = false; + bool fpga_called, math_called = false; auto m = MAKE_TORCH_LIBRARY(test); - m.def("fn", torch::dispatch(c10::DispatchKey::SparseCPU, [&](const Tensor& x) { sparsecpu_called = true; return x; })); + m.def("fn", torch::dispatch(c10::DispatchKey::FPGA, [&](const Tensor& x) { fpga_called = true; return x; })); m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }); auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); ASSERT_TRUE(op.has_value()); { - callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU)); - ASSERT_TRUE(sparsecpu_called); + callOp(*op, dummyTensor(c10::DispatchKey::FPGA)); + ASSERT_TRUE(fpga_called); } { expectThrows([&] { - callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true)); + callOp(*op, 
dummyTensor(c10::DispatchKey::FPGA, /*requires_grad=*/true)); }, "test::fn has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther."); } } @@ -1849,18 +1849,15 @@ TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) { } { - // TODO(#43908): currently this will fallthrough AutogradPrivateUse1 then call catchall kernel - // at AutogradCPU, while backend extenders are indeed expecting to call PrivateUse1 kernel. - // This confusing behavior is caused by we registering fallthrough as backend fallback for - // Autograd keys. Note users could always work around this by registering the same kernel to - // AutogradPrivateUse1 as shown below until we support it. auto op = Dispatcher::singleton().findOp({"test::fn", ""}); ASSERT_TRUE(op.has_value()); catchall_called = false; + privateuse1_called = false; callOp(*op, dummyTensor(c10::DispatchKey::PrivateUse1, /*requires_grad=*/true), dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true)); - ASSERT_TRUE(catchall_called); + ASSERT_FALSE(catchall_called); + ASSERT_TRUE(privateuse1_called); } m.impl("fn", c10::DispatchKey::AutogradPrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; }); @@ -1876,6 +1873,27 @@ TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) { } } +TEST(NewOperatorRegistrationTest, registerCompositeImplicitAutogradWithCPUKernel_andCallAutogradOtherKernel_callsComposite) { + bool math_called = false; + bool cpu_called = false; + auto m = MAKE_TORCH_LIBRARY(test); + m.def("fn(Tensor dummy) -> Tensor"); + m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; }); + m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }); + + auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); + ASSERT_TRUE(op.has_value()); + + { + math_called = cpu_called = false; + // Meta should redispatch to the AutogradOther backend, + // which the composite kernel should be registered to. 
+ callOp(*op, dummyTensor(c10::DispatchKey::Meta, /*requires_grad=*/true)); + ASSERT_TRUE(math_called); + ASSERT_FALSE(cpu_called); + } +} + TEST(NewOperatorRegistrationTest, dispatchMultiple) { bool cpu_called = false; bool cuda_called = false; diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 6dbcaf88d5db78..ab9f41e58f3e3a 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -1,14 +1,47 @@ #include +#include #include namespace c10 { +const char* toString(BackendComponent t) { + switch (t) { + case BackendComponent::CPUBit: + return "CPUBit"; + case BackendComponent::CUDABit: + return "CUDABit"; + case BackendComponent::HIPBit: + return "HIPBit"; + case BackendComponent::XLABit: + return "XLABit"; + case BackendComponent::LazyBit: + return "LazyBit"; + case BackendComponent::XPUBit: + return "XPUBit"; + case BackendComponent::MLCBit: + return "MLCBit"; + case BackendComponent::HPUBit: + return "HPUBit"; + case BackendComponent::VEBit: + return "VEBit"; + case BackendComponent::PrivateUse1Bit: + return "PrivateUse1Bit"; + case BackendComponent::PrivateUse2Bit: + return "PrivateUse2Bit"; + case BackendComponent::PrivateUse3Bit: + return "PrivateUse3Bit"; + case BackendComponent::InvalidBit: + return "InvalidBit"; + default: + return "UNKNOWN_BACKEND_BIT"; + } +} + const char* toString(DispatchKey t) { switch (t) { case DispatchKey::Undefined: return "Undefined"; - case DispatchKey::CPU: return "CPU"; case DispatchKey::CUDA: @@ -103,8 +136,6 @@ const char* toString(DispatchKey t) { return "AutogradMLC"; case DispatchKey::AutogradHPU: return "AutogradHPU"; - case DispatchKey::AutogradNestedTensor: - return "AutogradNestedTensor"; case DispatchKey::AutogradPrivateUse1: return "AutogradPrivateUse1"; case DispatchKey::AutogradPrivateUse2: @@ -113,6 +144,8 @@ const char* toString(DispatchKey t) { return "AutogradPrivateUse3"; case DispatchKey::AutogradOther: return "AutogradOther"; + case DispatchKey::AutogradNestedTensor: + return "AutogradNestedTensor"; case DispatchKey::ZeroTensor: return "ZeroTensor"; @@ -170,6 +203,15 @@ const char* toString(DispatchKey t) { case DispatchKey::FuncTorchBatched: return "FuncTorchBatched"; + case DispatchKey::Dense: + return "Dense"; + case DispatchKey::Quantized: + return "Quantized"; + case DispatchKey::Sparse: + return "Sparse"; + case DispatchKey::AutogradFunctionality: + return "AutogradFunctionality"; + default: return "UNKNOWN_TENSOR_TYPE_ID"; } @@ -178,76 +220,37 @@ const char* toString(DispatchKey t) { std::ostream& operator<<(std::ostream& str, DispatchKey rhs) { return str << toString(rhs); } +std::ostream& operator<<(std::ostream& str, BackendComponent rhs) { + return str << toString(rhs); +} -// for a given backend key, return the associated autograd key. -// for non-backend keys, return AutogradOther as a default. -// Note: it's convenient and fast to return a default here rather than (say) -// returning an optional, or throwing. But it makes callers -// responsible for either a) enforcing the invariant that only backend keys -// be passed as arguments, or b) interpreting our return value carefully. 
-// -DispatchKey getAutogradKeyFromBackend(DispatchKey t) { - switch (t) { - case DispatchKey::CPU: - return DispatchKey::AutogradCPU; - case DispatchKey::XPU: - return DispatchKey::AutogradXPU; - case DispatchKey::CUDA: - return DispatchKey::AutogradCUDA; - case DispatchKey::XLA: - return DispatchKey::AutogradXLA; - case DispatchKey::Lazy: - return DispatchKey::AutogradLazy; - case DispatchKey::MLC: - return DispatchKey::AutogradMLC; - case DispatchKey::HPU: - return DispatchKey::AutogradHPU; - case DispatchKey::NestedTensor: - return DispatchKey::AutogradNestedTensor; - case DispatchKey::PrivateUse1: - return DispatchKey::AutogradPrivateUse1; - case DispatchKey::PrivateUse2: - return DispatchKey::AutogradPrivateUse2; - case DispatchKey::PrivateUse3: - return DispatchKey::AutogradPrivateUse3; - default: - return DispatchKey::AutogradOther; - } +DispatchKey getAutogradKeyFromBackend(BackendComponent k) { + // We want this to return an autograd key. We're relying on the fact that + // getAutogradRelatedKeySetFromBackend returns an autograd key + + // ADInplaceOrView, and autograd has higher precedence. The core mapping from + // backend -> autograd key lives in `getAutogradRelatedKeySetFromBackend` + // instead of here for performance. `getAutogradRelatedKeySetFromBackend` is a + // hotpath function, and we want to make sure that it doesn't have to + // construct any DispatchKeySets at runtime. + return getAutogradRelatedKeySetFromBackend(k).highestPriorityTypeId(); } c10::DispatchKey parseDispatchKey(const std::string& k) { static std::unordered_map key_map = { {"Undefined", c10::DispatchKey::Undefined}, - {"CPU", c10::DispatchKey::CPU}, - {"CUDA", c10::DispatchKey::CUDA}, - {"HIP", c10::DispatchKey::HIP}, + {"Dense", c10::DispatchKey::Dense}, {"FPGA", c10::DispatchKey::FPGA}, {"ORT", c10::DispatchKey::ORT}, - {"XLA", c10::DispatchKey::XLA}, - {"MLC", c10::DispatchKey::MLC}, {"Vulkan", c10::DispatchKey::Vulkan}, {"Metal", c10::DispatchKey::Metal}, - {"XPU", c10::DispatchKey::XPU}, - {"HPU", c10::DispatchKey::HPU}, {"VE", c10::DispatchKey::VE}, - {"Lazy", c10::DispatchKey::Lazy}, {"Meta", c10::DispatchKey::Meta}, - {"QuantizedCPU", c10::DispatchKey::QuantizedCPU}, - {"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA}, - {"QuantizedXPU", c10::DispatchKey::QuantizedXPU}, + {"Quantized", c10::DispatchKey::Quantized}, {"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId}, {"MkldnnCPU", c10::DispatchKey::MkldnnCPU}, - {"SparseCPU", c10::DispatchKey::SparseCPU}, - {"SparseCUDA", c10::DispatchKey::SparseCUDA}, - {"SparseHIP", c10::DispatchKey::SparseHIP}, - {"SparseXPU", c10::DispatchKey::SparseXPU}, - {"SparseVE", c10::DispatchKey::SparseVE}, + {"Sparse", c10::DispatchKey::Sparse}, {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU}, {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA}, - {"NestedTensor", c10::DispatchKey::NestedTensor}, - {"PrivateUse1", c10::DispatchKey::PrivateUse1}, - {"PrivateUse2", c10::DispatchKey::PrivateUse2}, - {"PrivateUse3", c10::DispatchKey::PrivateUse3}, {"BackendSelect", c10::DispatchKey::BackendSelect}, {"Python", c10::DispatchKey::Python}, {"PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot}, @@ -259,17 +262,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { c10::DispatchKey::FuncTorchDynamicLayerBackMode}, {"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView}, {"AutogradOther", c10::DispatchKey::AutogradOther}, - {"AutogradCPU", c10::DispatchKey::AutogradCPU}, - {"AutogradCUDA", c10::DispatchKey::AutogradCUDA}, - {"AutogradXLA", 
c10::DispatchKey::AutogradXLA}, - {"AutogradLazy", c10::DispatchKey::AutogradLazy}, - {"AutogradXPU", c10::DispatchKey::AutogradXPU}, - {"AutogradMLC", c10::DispatchKey::AutogradMLC}, - {"AutogradHPU", c10::DispatchKey::AutogradHPU}, + {"AutogradFunctionality", c10::DispatchKey::AutogradFunctionality}, {"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor}, - {"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1}, - {"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2}, - {"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3}, {"Tracer", c10::DispatchKey::Tracer}, {"AutocastCPU", c10::DispatchKey::AutocastCPU}, {"AutocastCUDA", c10::DispatchKey::AutocastCUDA}, @@ -283,6 +277,41 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { {"TESTING_ONLY_GenericWrapper", c10::DispatchKey::TESTING_ONLY_GenericWrapper}, {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode}, + + {"CPU", c10::DispatchKey::CPU}, + {"CUDA", c10::DispatchKey::CUDA}, + {"HIP", c10::DispatchKey::HIP}, + {"XLA", c10::DispatchKey::XLA}, + {"MLC", c10::DispatchKey::MLC}, + {"XPU", c10::DispatchKey::XPU}, + {"HPU", c10::DispatchKey::HPU}, + {"Lazy", c10::DispatchKey::Lazy}, + {"NestedTensor", c10::DispatchKey::NestedTensor}, + {"PrivateUse1", c10::DispatchKey::PrivateUse1}, + {"PrivateUse2", c10::DispatchKey::PrivateUse2}, + {"PrivateUse3", c10::DispatchKey::PrivateUse3}, + + {"QuantizedCPU", c10::DispatchKey::QuantizedCPU}, + {"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA}, + {"QuantizedXPU", c10::DispatchKey::QuantizedXPU}, + + {"SparseCPU", c10::DispatchKey::SparseCPU}, + {"SparseCUDA", c10::DispatchKey::SparseCUDA}, + {"SparseHIP", c10::DispatchKey::SparseHIP}, + {"SparseXPU", c10::DispatchKey::SparseXPU}, + {"SparseVE", c10::DispatchKey::SparseVE}, + + {"AutogradCPU", c10::DispatchKey::AutogradCPU}, + {"AutogradCUDA", c10::DispatchKey::AutogradCUDA}, + {"AutogradXLA", c10::DispatchKey::AutogradXLA}, + {"AutogradLazy", c10::DispatchKey::AutogradLazy}, + {"AutogradXPU", c10::DispatchKey::AutogradXPU}, + {"AutogradMLC", c10::DispatchKey::AutogradMLC}, + {"AutogradHPU", c10::DispatchKey::AutogradHPU}, + {"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1}, + {"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2}, + {"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3}, + {"Autograd", c10::DispatchKey::Autograd}, {"CompositeImplicitAutograd", c10::DispatchKey::CompositeImplicitAutograd}, diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 29315051b4177e..75d0865b6efae1 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -9,20 +9,98 @@ namespace c10 { +// Semantically, each value of BackendComponent identifies a "backend" for our +// dispatch. Some functionalities that we may dispatch to are allowed to +// register different handlers for each backend. The BackendComponent is then +// used to figure out which backend implementation to dispatch to. + +// In implementation terms, the backend component identifies a specific "bit" in +// a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom +// ~12 "BackendComponent" bits, while the remaining upper bits are assigned to +// functionalities. When we encounter a functionality bit that is known to be +// customizeable per-backend, then we also look at the lower BackendComponent +// bits and take the highest bit to determine which backend's implementation to +// use. 
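+// A rough sketch of the layout (see Note [DispatchKeySet Internal Representation]):
+// with the 12 backends below, the low 12 bits of a DispatchKeySet are backend bits
+// (CPUBit, CUDABit, ...) and the remaining bits are functionality bits (Dense,
+// Sparse, AutogradFunctionality, ...). For instance, a dense CUDA tensor sets the
+// CUDABit backend bit plus the Dense functionality bit, and the pair
+// (Dense, CUDABit) maps back to the runtime key DispatchKey::CUDA when indexing
+// the operator table.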
+ +enum class BackendComponent : uint8_t { + + // A "backend" is colloquially used to refer to handlers for dispatch + // which actually implement the numerics of an operation in question. + // + // Due to the nature of the enum, these backends are specified in + // an ordered way, but for most backends this order is not semantically + // meaningful (e.g., it's valid to reorder these backends without changing + // semantics). The only situation when backend ordering is meaningful + // is when the backend participates in multiple dispatch with another + // backend; e.g., CPU and CUDA (cuda must have higher priority). + + // These keys don't correspond to individual kernels. + // Instead, they represent the backends that are allowed to override specific + // pieces of functionality: + // - dense kernels (e.g. DispatchKey::CPU) + // - sparse kernels (e.g. DispatchKey::SparseCPU) + // - quantized kernels (e.g. DispatchKey::QuantizedCPU) + // - autograd kernels (e.g. DispatchKey::AutogradCPU) + // We reserve space in the runtime operator table for this full cross product + // of + // [backends in this enum] x [keys below that are explicitly marked as having + // per-backend functionality] + + InvalidBit = 0, + CPUBit, + CUDABit, + HIPBit, + XLABit, + MLCBit, + XPUBit, + HPUBit, + VEBit, + LazyBit, + PrivateUse1Bit, + PrivateUse2Bit, + PrivateUse3Bit, + // Define an alias to represent end of backend dispatch keys. + // If you add new backend keys after PrivateUse3, please also update it here. + // (But you shouldn't: private use keys should have higher precedence than + // all built-in keys) + EndOfBackendKeys = PrivateUse3Bit, +}; + // Semantically, a dispatch key identifies a possible "level" in our -// dispatch, for which a handler may be registered. Traditional -// backends like CPU and CUDA get dispatch keys; however, so do -// "wrapping" layers like Variable (for autograd handling). +// dispatch, for which a handler may be registered. Each handler corresponds +// to a type of functionality. // // In implementation terms, the dispatch key identifies a specific "bit" in a // DispatchKeySet. Higher bit indexes get handled by dispatching first (because // we "count leading zeros" when we extract the highest priority dispatch // key.) // +// Note [DispatchKey Classification] +// This enum actually contains several types of keys, which are explained +// in more detail further down: +// (1) non-customizable backends (e.g. FPGA) +// (2) non-customizable functionalities (e.g. Functionalize) +// (3) functionalized that are customizable per backend (e.g. Dense, Sparse, +// AutogradFunctionality) (4) per-backend instances of customizable +// functionalities (e.g. CPU, SparseCPU, AutogradCPU) (5) alias keys (e.g. +// CompositeImplicitAutograd) +// +// Of the categories above, it's important to note: +// (a) which keys are assigned individual bits in a DispatchKeySet +// (b) which keys are assigned individual slots in the runtime operator table +// ("Runtime keys") +// +// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet. +// (1), (2) and (4) all get their own dedicated slots in the runtime operator +// table. + +// See Note [DispatchKeySet Internal Representation] for more details. 
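+// For example, DispatchKey::SparseCUDA falls into category (4): it owns a slot in
+// the runtime operator table, but no bit of its own; in a DispatchKeySet it is
+// represented as the Sparse functionality bit plus the CUDABit backend bit.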
+// // NOTE: Keep the list in sync with `DispatchKey` in tools/codegen/model.py -enum class DispatchKey : uint8_t { +enum class DispatchKey : uint16_t { + // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // - // This is not a "real" tensor id, but it exists to give us a "nullopt" + // This is not a "real" functionality, but it exists to give us a "nullopt" // element we can return for cases when a DispatchKeySet contains no elements. // You can think a more semantically accurate definition of DispatchKey is: // @@ -38,24 +116,31 @@ enum class DispatchKey : uint8_t { // this will get eliminated, but for now it's convenient) CatchAll = Undefined, - // ~~~~~~~~~~~~~~~~~~~~~~~~~~ BACKENDS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // - // A "backend" is colloquially used to refer to handlers for dispatch - // which actually implement the numerics of an operation in question. + // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ // + // Every value in the enum (up to EndOfFunctionalityKeys) + // corresponds to an individual "functionality" that can be dispatched to. + // This is represented in the DispatchKeySet by assigning each of these enum + // values + // to each of the remaining (64 - len(BackendComponent)) bits. // - // Due to the nature of the enum, these backends are specified in - // an ordered way, but for most backends this order is not semantically - // meaningful (e.g., it's valid to reorder these backends without changing - // semantics). The only situation when backend ordering is meaningful - // is when the backend participates in multiple dispatch with another - // backend; e.g., CPU and SparseCPU (sparse must have - // higher priority). + // Most of these functionalities have a single handler assigned to them, + // making them "runtime keys". + // That map to a single slot in the runtime operator table. + // + // A few functionalities are allowed to be customizable per backend. + // See [Note: Per-Backend Functionality Dispatch Keys] for details. + + // See [Note: Per-Backend Functionality Dispatch Keys] + Dense, + + // Below are non-extensible backends. + // These are backends that currently don't have their own overrides for + // Autograd/Sparse/Quantized kernels, + // and we therefore don't waste space in the runtime operator table allocating + // space for them. + // If any of these backends ever need to customize, e.g., Autograd, then we'll + // need to add a DispatchKey::*Bit for them. - // Here are backends which you think of as traditionally specifying - // how to implement operations on some device. - CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp - CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp - HIP, // NB: I think this is not actually used, due to Note [Masquerading as - // CUDA] FPGA, // Xilinx support lives out of tree at // https://gitlab.com/pytorch-complex/vitis_kernels @@ -67,14 +152,8 @@ enum class DispatchKey : uint8_t { // - aten/src/ATen/test/extension_backend_test.cpp ORT, - XLA, // lives out of tree at https://github.com/pytorch/xla - MLC, // lives out of tree at https://github.com/pytorch/MLCompute Vulkan, Metal, - XPU, // For out of tree Intel's heterogeneous computing plug-in - HPU, // For out of tree & closed source integration of HPU / Habana - VE, // For out of tree & closed source integration of SX-Aurora / NEC - Lazy, // For lazy tensor backends // A meta tensor is a tensor without any data associated with it. 
(They // have also colloquially been referred to as tensors on the "null" device). @@ -83,11 +162,8 @@ enum class DispatchKey : uint8_t { // tensor with the output shape and dtype, but wouldn't actually add anything. Meta, - // Here are backends which specify more specialized operators - // based on the dtype of the tensor. - QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp - QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp - QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in + // See [Note: Per-Backend Functionality Dispatch Keys] + Quantized, // This backend is to support custom RNGs; it lets you go // to a different kernel if you pass in a generator that is not a @@ -106,30 +182,28 @@ enum class DispatchKey : uint8_t { // the corresponding dense tensors, and must be handled before them. MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp // NB: not to be confused with MKLDNN, which is Caffe2 only - SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp - SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp - SparseHIP, // TODO: I think this is not actually used, due to Note - // [Masquerading as CUDA] - SparseXPU, // For out of tree Intel's heterogeneous computing plug-in - SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC + + // See [Note: Per-Backend Functionality Dispatch Keys] + Sparse, SparseCsrCPU, SparseCsrCUDA, - NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor - - // Here are reserved backends for user-defined backends, see Note [Private use - // DispatchKey] - // To see some example about how to use this, check out ORT - PrivateUse1, - PrivateUse2, - PrivateUse3, + // Note [Non-Customizable Backend Keys] + // Every key above here is considered a "non-customizable backend". + // These are backends that will work correctly with autograd, but + // but currently don't require separate implementations + // for autograd sparse or quantized kernels. + // Any new backends that don't need to be customized should go above here. + // If an existing backend needs to e.g. override autograd, then we can + // consider promoting it into the "BackendComponent" enum + // + // For all intents and purposes from the perspective of DispatchKeySet, + // "non-customizable backend" keys are treated the same way + // as other functionality keys + EndOfNonCustomizableBackends = SparseCsrCUDA, - // Define an alias key to represent end of backend dispatch keys. - // If you add new backend keys after PrivateUse3, please also update it here. - // (But you shouldn't: private use keys should have higher precedence than - // all built-in keys) - EndOfBackendKeys = PrivateUse3, + NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor // In some situations, it is not immediately obvious what the correct // backend for function is, because the function in question doesn't @@ -233,20 +307,18 @@ enum class DispatchKey : uint8_t { // AutogradOther key. We can add specific autograd key for those backends // upon request. 
AutogradOther, - AutogradCPU, - AutogradCUDA, - AutogradXLA, - AutogradLazy, - AutogradXPU, - AutogradMLC, - AutogradHPU, - AutogradNestedTensor, // lives out of tree at + + // See [Note: Per-Backend Functionality Dispatch Keys] + AutogradFunctionality, + + // NestedTensor is an example of something that isn't a "real backend" + // (because it mostly consists of redispatching kernels) + // but it would like to override autograd functionality in C++. + // We can handle cases like this by adding an extra functionality key + // exclusively for handling autograd for NestedTensor. + // lives out of tree at // https://github.com/pytorch/nestedtensor - // Here are some reserved pre-autograd keys for user-defined backends, see - // Note [Private use DispatchKey] - AutogradPrivateUse1, - AutogradPrivateUse2, - AutogradPrivateUse3, + AutogradNestedTensor, Tracer, @@ -304,9 +376,100 @@ enum class DispatchKey : uint8_t { TESTING_ONLY_GenericMode, // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // - NumDispatchKeys, // Sentinel, end of runtime keys. + EndOfFunctionalityKeys, // End of functionality keys. + + // ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ // + // Here are backends which you think of as traditionally specifying + // how to implement operations on some device. + + // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] + StartOfDenseBackends, + CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp + CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp + HIP, // NB: I think this is not actually used, due to Note [Masquerading as + // CUDA] + XLA, // lives out of tree at https://github.com/pytorch/xla + MLC, // lives out of tree at https://github.com/pytorch/MLCompute + XPU, // For out of tree Intel's heterogeneous computing plug-in + HPU, // For out of tree & closed source integration of HPU / Habana + VE, // For out of tree & closed source integration of SX-Aurora / NEC + Lazy, // For lazy tensor backends + // Here are reserved backends for user-defined backends, see Note [Private use + // DispatchKey] + // To see some example about how to use this, check out ORT + PrivateUse1, + PrivateUse2, + PrivateUse3, + EndOfDenseBackends = PrivateUse3, + + // ~~~~~~~~~~~~~~ "Quantized" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~ // + // keys starting with an _ are not currently used, + // but are needed to ensure that every backend is indexed correctly. + + // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] + StartOfQuantizedBackends, + QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp + QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp + _QuantizedHIP, + _QuantizedXLA, + _QuantizedMLC, + QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in + _QuantizedHPU, + _QuantizedVE, + _QuantizedLazy, + _QuantizedPrivateUse1, + _QuantizedPrivateUse2, + _QuantizedPrivateUse3, + EndOfQuantizedBackends = _QuantizedPrivateUse3, + + // ~~~~~~~~~~~~~~ "Sparse" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~ // + // keys starting with an _ are not currently used, + // but are needed to ensure that every backend is indexed correctly. + + // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] 
+ StartOfSparseBackends, + SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp + SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp + SparseHIP, // TODO: I think this is not actually used, due to Note + // [Masquerading as CUDA] + _SparseXLA, + _SparseMLC, + SparseXPU, // For out of tree Intel's heterogeneous computing plug-in + _SparseHPU, + SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC + _SparseLazy, + _SparsePrivateUse1, + _SparsePrivateUse2, + _SparsePrivateUse3, + EndOfSparseBackends = _SparsePrivateUse3, + + // ~~~~~~~~~~~~~~ "Autograd" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~ // + // keys starting with an _ are not currently used, + // but are needed to ensure that every backend is indexed correctly. + + // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] + StartOfAutogradBackends, + AutogradCPU, + AutogradCUDA, + _AutogradHIP, + AutogradXLA, + AutogradMLC, + AutogradXPU, + AutogradHPU, + _AutogradVE, + AutogradLazy, + // Here are some reserved pre-autograd keys for user-defined backends, see + // Note [Private use DispatchKey] + AutogradPrivateUse1, + AutogradPrivateUse2, + AutogradPrivateUse3, + EndOfAutogradBackends = AutogradPrivateUse3, + // If we add a new per-backend functionality key that has higher priority + // than Autograd, then this key should be updated. + EndOfRuntimeBackendKeys = EndOfAutogradBackends, // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ // + // Note [Alias Dispatch Keys] // Alias dispatch keys are synthetic dispatch keys which map to multiple // runtime dispatch keys. Alisa keys have precedence, but they are always // lower precedence than runtime keys. You can register a kernel to an @@ -326,6 +489,7 @@ enum class DispatchKey : uint8_t { // Define an alias key to represent end of alias dispatch keys. // If you add new alias keys after Autograd, please also update it here. + StartOfAliasKeys = Autograd, EndOfAliasKeys = CompositeExplicitAutograd, // // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // @@ -365,54 +529,83 @@ enum class DispatchKey : uint8_t { // built-in autograd formulas for operators are not appropriate. static_assert( - static_cast(DispatchKey::NumDispatchKeys) <= 64, - "DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries"); + (static_cast(BackendComponent::EndOfBackendKeys) + + static_cast(DispatchKey::EndOfFunctionalityKeys)) <= 64, + "The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)" + " both map to backend and functionality bits" + " into a 64-bit bitmask; you must have less than 64 total entries between them"); -#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS) -/** - * The method below maps the dispatch key in the enum DispatchKey to an - * integer index in the dispatchTable_ array in OperatorEntry. The array - * is trimmed for mobile to reduce peak memory usage since it's - * unnecessary to reserve additional space for dispatch keys that will - * never be used on mobile. 
- */ -C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) { - switch (dk) { - case DispatchKey::Undefined: - return 0; - case DispatchKey::CPU: - return 1; - case DispatchKey::QuantizedCPU: - return 2; - case DispatchKey::SparseCPU: - return 3; - case DispatchKey::BackendSelect: - return 4; - case DispatchKey::ADInplaceOrView: - return 5; - case DispatchKey::AutogradOther: - return 6; - case DispatchKey::AutogradCPU: - return 7; - case DispatchKey::NumDispatchKeys: // Sentinel, end of runtime keys. - return 8; - default: - return -1; +// Check if a DispatchKey is an alias mapping to other runtime keys. +constexpr bool isAliasDispatchKey(DispatchKey k) { + return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys; +} + +// [Note: Per-Backend Functionality Dispatch Keys] +// Check if a DispatchKey is a per-backend functionality key +// Any functionalities that can be customized per-backend should be added here. +// These keys correspond to functionalities that can be customized indivually +// per backend. While they only take up one bit in the `DispatchKeySet` bitset, +// they map to (# backends) slots in the operator table. +// Each of these keys also has a separate set of "runtime keys" in the dispatch +// key enum, per backend, which *do* map to the individual operator table slots. +// For example, the "Sparse" key maps to an individual bit in the +// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual +// slots in the runtime operator table. + +constexpr bool isPerBackendFunctionalityKey(DispatchKey k) { + if (k == DispatchKey::Dense || k == DispatchKey::Quantized || + k == DispatchKey::Sparse || k == DispatchKey::AutogradFunctionality) { + return true; + } else { + return false; } } -#else -/** - * For the server use-case, make this a simple pass-through. - */ -C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) { - return static_cast(dk); + +// Note that this includes Undefined in the total count. +// BUT EndOfFunctionalityKeys is its own (placeholder) key. +// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3. +// In the above example, there are 3 total functionality keys. +constexpr uint8_t num_functionality_keys = + static_cast(DispatchKey::EndOfFunctionalityKeys); + +constexpr uint8_t num_backends = + static_cast(BackendComponent::EndOfBackendKeys); + +// Note [No More Than 16 Backends] +// Search for this note to find places in the code where the "no more than 16 +// backends" invariant is baked in. +static_assert( + static_cast(BackendComponent::EndOfBackendKeys) <= 16, + "BackendComponent currently only supports <= 16 backends. 
If we really need to extend this, \ +there are a few places where this invariant is baked in"); + +constexpr uint8_t numPerBackendFunctionalityKeys() { + uint8_t count = 0; + for (uint8_t k = 0; k <= num_functionality_keys; ++k) { + if (isPerBackendFunctionalityKey(static_cast(k))) + ++count; + } + return count; } + +#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS) +// See [Note: Trimmed Mobile Dispatch Keys] +constexpr uint16_t num_runtime_entries = 8; +#else +constexpr uint16_t num_runtime_entries = num_functionality_keys + + (numPerBackendFunctionalityKeys() * (num_backends - 1)); #endif +// See Note [No More Than 16 Backends] +constexpr uint16_t full_backend_mask = + (static_cast(1) << num_backends) - 1; + C10_API const char* toString(DispatchKey); +C10_API const char* toString(BackendComponent); C10_API std::ostream& operator<<(std::ostream&, DispatchKey); +C10_API std::ostream& operator<<(std::ostream&, BackendComponent); -C10_API DispatchKey getAutogradKeyFromBackend(DispatchKey t); +C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k); // Parses a string into a dispatch key. // If the string cannot be correctly parsed, throws an exception. @@ -425,10 +618,86 @@ C10_API c10::DispatchKey parseDispatchKey(const std::string& k); // torch::dispatch(torch::kCPU, ...) is also valid. constexpr DispatchKey kAutograd = DispatchKey::Autograd; -// Check if a DispatchKey is an alias mapping to other runtime keys. -inline bool isAliasDispatchKey(DispatchKey k) { - return k > DispatchKey::NumDispatchKeys && k <= DispatchKey::EndOfAliasKeys; +// See Note [The Ordering of Per-Backend Dispatch Keys Matters!] +// This function relies on the invariant that the dispatch keys between +// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend +// in the same order as `BackendComponent`. +constexpr BackendComponent toBackendComponent(DispatchKey k) { + if (k >= DispatchKey::StartOfDenseBackends && + k <= DispatchKey::EndOfDenseBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfDenseBackends)); + } else if ( + k >= DispatchKey::StartOfQuantizedBackends && + k <= DispatchKey::EndOfQuantizedBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfQuantizedBackends)); + } else if ( + k >= DispatchKey::StartOfSparseBackends && + k <= DispatchKey::EndOfSparseBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfSparseBackends)); + } else if ( + k >= DispatchKey::StartOfAutogradBackends && + k <= DispatchKey::EndOfAutogradBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfAutogradBackends)); + } else { + return BackendComponent::InvalidBit; + } +} + +constexpr DispatchKey toFunctionalityKey(DispatchKey k) { + if (k <= DispatchKey::EndOfFunctionalityKeys) { + return k; + } else if (k <= DispatchKey::EndOfDenseBackends) { + return DispatchKey::Dense; + } else if (k <= DispatchKey::EndOfQuantizedBackends) { + return DispatchKey::Quantized; + } else if (k <= DispatchKey::EndOfSparseBackends) { + return DispatchKey::Sparse; + } else if (k <= DispatchKey::EndOfAutogradBackends) { + return DispatchKey::AutogradFunctionality; + } else { + return DispatchKey::Undefined; + } } + +// Given (DispatchKey::Dense, DispatchKey::CUDABit), returns DispatchKey::CUDA +// See Note [The Ordering of Per-Backend Dispatch Keys Matters!] 
+// This function relies on the invariant that the dispatch keys between +// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend +// in the same order as `BackendComponent`. +constexpr DispatchKey toRuntimePerBackendFunctionalityKey( + DispatchKey functionality_k, + BackendComponent backend_k) { + if (functionality_k == DispatchKey::Dense) { + return static_cast( + static_cast(DispatchKey::StartOfDenseBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::Sparse) { + return static_cast( + static_cast(DispatchKey::StartOfSparseBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::Quantized) { + return static_cast( + static_cast(DispatchKey::StartOfQuantizedBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::AutogradFunctionality) { + return static_cast( + static_cast(DispatchKey::StartOfAutogradBackends) + + static_cast(backend_k)); + } + return DispatchKey::Undefined; +} + } // namespace c10 namespace torch { diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 7f85567f886f6b..192b1e0b471858 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -1,37 +1,30 @@ #include +#include +#include namespace c10 { -// backend_dispatch_keyset should include all runtime backend keys. +// backend_dispatch_keyset includes all dispatch keys that map to backends. // Alias key DispatchKey::CompositeExplicitAutograd maps to -// backend_dispatch_keyset NestedTensor has been explicitly removed due to -// incompatibility with some kernels, such as structured kernels, that use the -// DefaultBackend key. -constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | - DispatchKeySet({ - DispatchKey::CPU, - DispatchKey::CUDA, - DispatchKey::XLA, - DispatchKey::Lazy, - DispatchKey::XPU, - DispatchKey::PrivateUse1, - DispatchKey::PrivateUse2, - DispatchKey::PrivateUse3, - DispatchKey::MLC, - DispatchKey::HPU, - DispatchKey::ORT, - DispatchKey::Meta, - }); +// backend_dispatch_keyset +constexpr DispatchKeySet backend_dispatch_keyset = + autogradother_backends | DispatchKeySet(DispatchKey::Dense); bool isBackendDispatchKey(DispatchKey t) { return t != DispatchKey::Undefined // See Note [No Alias Keys in DispatchKeySet] - && !isAliasDispatchKey(t) && backend_dispatch_keyset.has(t); + && !isAliasDispatchKey(t) + // Note [NestedTensor Not Included in Backend Keys] + // NestedTensor has been explicitly removed from the "backend keyset" due + // to incompatibility with some kernels, so we don't want it to be + // included in CompositeImplicitAutograd or CompositeExplicitAutograd + // kernels. + && t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t); } // math_dispatch_keyset contains all keys in backend_dispatch_keyset and // autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd -// maps to math_dispatch_keyset. +// maps to [math_dispatch_keyset x full_backend_mask] constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset; @@ -39,7 +32,12 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); switch (t) { case DispatchKey::Autograd: - return autograd_dispatch_keyset; + // See Note [autograd_dispatch_keyset Does Not Include Backend Bits] + // That's why we OR it with a mask of the backend bits here. 
+ // getRuntimeDispatchKeySet() expects to return a keyset of runtime + // dispatch keys, like AutogradCPU, but that requires having backend bits. + return autograd_dispatch_keyset | + DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); case DispatchKey::CompositeImplicitAutograd: return math_dispatch_keyset; case DispatchKey::CompositeExplicitAutograd: @@ -53,11 +51,13 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) { TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); switch (t) { case DispatchKey::Autograd: - return autograd_dispatch_keyset.has(k); + return autograd_dispatch_keyset.has(toFunctionalityKey(k)); case DispatchKey::CompositeImplicitAutograd: - return math_dispatch_keyset.has(k); + // See Note [NestedTensor Not Included in Backend Keys] + return k != DispatchKey::NestedTensor && math_dispatch_keyset.has(k); case DispatchKey::CompositeExplicitAutograd: - return backend_dispatch_keyset.has(k); + // See Note [NestedTensor Not Included in Backend Keys] + return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k); default: return t == k; } @@ -79,8 +79,6 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { return DispatchKeySet(DispatchKey::MLC); case DispatchKey::AutogradHPU: return DispatchKeySet(DispatchKey::HPU); - case DispatchKey::AutogradNestedTensor: - return DispatchKeySet(DispatchKey::NestedTensor); case DispatchKey::AutogradXPU: return DispatchKeySet(DispatchKey::XPU); case DispatchKey::AutogradPrivateUse1: @@ -96,23 +94,6 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { } } -DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) { - switch (t) { - case DispatchKey::CPU: - return DispatchKeySet(DispatchKey::AutocastCPU); - case DispatchKey::CUDA: - case DispatchKey::XLA: - return DispatchKeySet(DispatchKey::AutocastCUDA); - default: - return DispatchKeySet(); - } -} - -DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) { - return DispatchKeySet( - {DispatchKey::ADInplaceOrView, getAutogradKeyFromBackend(t)}); -} - bool isIncludedInAlias(DispatchKey k, DispatchKey alias) { return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k); } @@ -129,18 +110,135 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) { return os; } os << "DispatchKeySet("; - DispatchKey tid; bool first = true; - while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) { + for (auto k : ts) { if (!first) { os << ", "; } - os << tid; - ts = ts.remove(tid); + os << k; first = false; } os << ")"; return os; } +DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() { + TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val); + TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends, next_backend_); + + // Create a masked version of the set representation to ignore previous + // keys that we've iterated through. 
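+  // (maskTrailingZeros(n) yields a bitmask whose n lowest bits are zero, so the
+  // bits below next_functionality_ / next_backend_ are dropped here.)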
+ uint64_t masked_functionality_bits = + llvm::maskTrailingZeros(next_functionality_) & *data_ptr_; + uint64_t masked_backend_bits = + llvm::maskTrailingZeros(next_backend_) & full_backend_mask & + *data_ptr_; + + uint64_t first_functionality_idx = + llvm::findFirstSet(masked_functionality_bits); + uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits); + + // If there are no keys, set to end iterator value + if (first_functionality_idx == std::numeric_limits::max() || + next_functionality_ == iterator::end_iter_mask_val) { + // Set up state to be the same as end() + next_functionality_ = iterator::end_iter_mask_val; + current_dispatchkey_idx_ = iterator::end_iter_key_val; + next_backend_ = 0; + current_backendcomponent_idx_ = iterator::end_iter_key_val; + return *this; + } + + // The +1 is because of DispatchKey::Undefined and + // BackendComponent::InvalidBit + auto new_next_functionality = first_functionality_idx + 1; + auto new_backendcomponent_idx = first_backendcomponent_idx + 1; + // and the -num_backends is because the first bits in the + // keyset are not Dispatch Keys. + auto next_dispatchkey_idx = new_next_functionality - num_backends; + + // If the current functionality bit is a per-backend bit, we need special + // handling + if (isPerBackendFunctionalityKey( + static_cast(next_dispatchkey_idx))) { + // case 1: if the current backend is undefined, then there is no valid + // backend instance of this functionality key so we can skip it. + if (first_backendcomponent_idx == std::numeric_limits::max()) { + // increment the functionality mask so we skip the current functionality + // bit on the next increment. + next_functionality_ = new_next_functionality; + ++(*this); + return *this; + } + + // Otherwise, at this point we know what the current backend and + // functionality bits are. + current_dispatchkey_idx_ = next_dispatchkey_idx; + current_backendcomponent_idx_ = new_backendcomponent_idx; + + // Next, we need to set up the masks for the next increment. + uint64_t next_backendcomponent_bits = + llvm::maskTrailingZeros(first_backendcomponent_idx + 1) & + full_backend_mask & *data_ptr_; + uint64_t next_backendcomponent_idx = + llvm::findFirstSet(next_backendcomponent_bits); + if (next_backendcomponent_idx == std::numeric_limits::max()) { + // case 2: the current backend is valid, but there is not another backend + // in the keyset. In this case, we need to bump the functionality mask and + // reset the backend mask for the next increment + next_functionality_ = new_next_functionality; + next_backend_ = 0; + } else { + // case 3: we have another backend to iterate over. We want to iterate + // over the same functionality bit next time, but a different backend bit. + next_backend_ = first_backendcomponent_idx + 1; + } + } else { + // Functionality bits that aren't per backend are simpler to handle. We can + // ignore the backend bits. + TORCH_INTERNAL_ASSERT(next_backend_ == 0); + current_dispatchkey_idx_ = next_dispatchkey_idx; + next_functionality_ = new_next_functionality; + } + return *this; +} + +std::array +initializeFunctionalityOffsetsAndMasks() { + std::array + offsets_and_masks; + // manualy set the first entry, which corresponds to Undefined. + offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0); + // loop through every functionality key (aside from Undefined). + for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) { + // functionality_idx should be Dense -> 1, ... 
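+    // As a rough illustration: Dense is per-backend, so it gets offset 1 and a
+    // full backend mask, and whatever functionality comes next starts at
+    // offset 1 + num_backends (Dense reserves one runtime slot per backend).
+    // A non-per-backend functionality only advances the offset by 1, since it
+    // owns a single slot in the runtime operator table.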
+ auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1]; + auto k = static_cast(functionality_idx); + + // If the previous functionality was not per-backend, then we can just + // increment the previous offset. Otherwise, the next offset = + // previous_offset + num_backends. + auto next_offset = prev_offset_and_mask.offset + + (prev_offset_and_mask.mask == 0 ? 1 : num_backends); + // the mask is used in the runtime index calculation to find the offset of + // the backend. For non-per-backend functionalities, this offset should + // always be 0. Otherwise, we need to get the index of the backend (which we + // can do using a backend mask). + auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0; + offsets_and_masks[functionality_idx] = + FunctionalityOffsetAndMask(next_offset, next_mask); + } + // Sanity check that the computed offset index of the last functionality key + // is correct. This assumes that the highest priority functionality key is not + // per backend. + TORCH_INTERNAL_ASSERT( + offsets_and_masks[num_functionality_keys - 1].offset == + (num_runtime_entries - 1), + "num_runtime_entries: ", + num_runtime_entries, + "last_offset: ", + offsets_and_masks[num_functionality_keys - 1].offset); + return offsets_and_masks; +} + } // namespace c10 diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 79d39652219b51..821b701022728e 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -1,5 +1,4 @@ #pragma once - #include #include #include @@ -8,29 +7,147 @@ namespace c10 { +struct FunctionalityOffsetAndMask { + // empty constructor shouldn't be used; only needed to initialize + // the array before populating it. + FunctionalityOffsetAndMask() {} + FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask) + : offset(offset), mask(mask) {} + // This needs to big enough to cover the size of the operator table. + uint16_t offset; + // See Note [No More Than 16 Backends] + // This mask needs to be big enough to mask all of the backend bits. + // We probably don't ever want to have more than 16 backend bits, so uint16_t + // should be enough. + uint16_t mask; +}; +static_assert( + c10::num_runtime_entries < 65536, + "The dispatcher currently only supports up to 2^16 runtime entries"); + +C10_API std::array +initializeFunctionalityOffsetsAndMasks(); + +C10_ALWAYS_INLINE static const std:: + array& + offsetsAndMasks() { + static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks(); + return offsets_and_masks_; +} + +// A representation of a set of DispatchKeys. A DispatchKeySet contains both +// "functionality" bits and "backend bits", and every tensor holds its own +// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the +// keyset on every input tensor, or’ing them together, and dispatching to a +// specific piece of functionality. The functionality bits are *ordered*. When +// multiple functionality bits are set, we use the highest priority +// functionality. Similarly, multiple backend bits can theoretically be set if +// you call an operator with multiple tensors from difference devices (e.g. CPU +// and CUDA), although support for mixed device dispatch is limited (the only +// kernels that gracefully handle mixed device inputs for now are cuda kernels +// that take in a scalar cpu tensor). + // A representation of a set of DispatchKeys. 
A tensor may have multiple
 // tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
 // DispatchKeySet specifies what type ids apply. The internal representation is
 // as a 64-bit bit set (this means only 64 tensor type ids are supported).
 //
-// Note that DispatchKeys are ordered; thus, we can ask questions like "what is
-// the highest priority DispatchKey in the set"? (The set itself is not
-// ordered; two sets with the same ids will always have the ids ordered in the
-// same way.)
+// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
+// "what is the highest priority DispatchKey in the set"? (The set itself is
+// not ordered; two sets with the same ids will always have the ids ordered in
+// the same way.)
+//
+// Note [DispatchKeySet Internal Representation]
+// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
+// that get passed around at runtime.
+// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
+// and individual dispatch keys.
+//
+// First: why do we have this distinction, and why not map every dispatch key
+// directly to a bit? This is mostly because we have several types of
+// functionalities that different backends would like to customize. For example,
+// we have:
+// - "Dense": CPU, CUDA, XLA, ... (~12 keys)
+// - "Sparse": SparseCPU, SparseCUDA, ...
+// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
+// - "Autograd": AutogradCPU, AutogradCUDA, AutogradXLA, ...
+// The problem is that the total number of keys grows quadratically with [#
+// backends] x [# functionalities], making it very difficult to map each key
+// directly to a bit in a bitset without dramatically increasing the size of the
+// bitset over time.
+//
+// The two enums (BackendComponent and DispatchKey) can be divided roughly into
+// 5 categories.
+//
+// (1) "Building block" keys
+//    (a) backends: everything in the BackendComponent enum (e.g. CPUBit,
+//    CUDABit)
+//    (b) functionalities: (per-backend) functionality-bit DispatchKeys
+//    (e.g. AutogradFunctionality, Sparse, Dense)
+// (2) "Runtime" keys
+//    (a) "non-customizable backends" (e.g. FPGA)
+//    (b) "non-customizable functionalities" (e.g. Functionalize)
+//    (c) "per-backend instances of customizable functionalities" (e.g. CPU,
+//    SparseCPU, AutogradCPU)
+// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
+//
+// (1) Building block keys always correspond to individual bits in a
+// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
+// runtime keys. e.g.
+//     auto dense_cpu_ks = DispatchKeySet(DispatchKey::Dense) |
+//         DispatchKeySet(BackendComponent::CPUBit);
+//     // The keyset has the runtime dense-cpu key.
+//     dense_cpu_ks.has(DispatchKey::CPU);
+//     // And it contains the building block keys too.
+//     dense_cpu_ks.has_backend(BackendComponent::CPUBit);
+//     dense_cpu_ks.has(DispatchKey::Dense);
 //
-// At the moment, there are no nontrivial uses of this set; tensors are always
-// singletons. In the near future, this set will represent variable? + tensor
-// type id. In the far future, it will be requires grad? + profiling? +
-// tracing? + lazy? + tensor type id.
+// Not every backend and not every functionality counts as a "building block
+// key". This is mostly to give us more levers to pull in the design space.
+// Backend keys and functionality keys that count as "building blocks" will
+// contribute to a full cross product of functionality that can be overridden. 
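+//
+// As a rough sketch of that cross product, a runtime key decomposes into (and
+// can be rebuilt from) its two building blocks, using the helpers this patch
+// relies on (toFunctionalityKey / toBackendComponent /
+// toRuntimePerBackendFunctionalityKey):
+//     auto functionality_k = toFunctionalityKey(DispatchKey::SparseCUDA);
+//     // functionality_k == DispatchKey::Sparse
+//     auto backend_k = toBackendComponent(DispatchKey::SparseCUDA);
+//     // backend_k == BackendComponent::CUDABit
+//     auto runtime_k =
+//         toRuntimePerBackendFunctionalityKey(functionality_k, backend_k);
+//     // runtime_k == DispatchKey::SparseCUDA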
 //
-// (The difference between variable and requires grad, is that
-// there are currently three states a tensor can be:
-// 1. Not a variable
-// 2. Variable with requires_grad=False
-// 3. Variable with requires_grad=True
-// Eventually, we want to kill state (1), and only dispatch to autograd
-// handling code if one of the inputs requires grad.)
+// For example, right now we have at least 12 "backend" building blocks (CPU,
+// CUDA, XLA, ...) and at least 4 "functionality" building blocks (Dense,
+// Sparse, Quantized, AutogradFunctionality, ...). These keys together allow
+// every dispatcher operator to be customized in up to 12*4 different ways. Each
+// of those requires a slot in the operator table of every dispatcher operator.
+// Not every piece of functionality necessarily needs to be customizable
+// per-backend, and not every backend necessarily needs to be able to customize
+// every type of functionality.
 //
+//
+// (2) Every runtime key corresponds directly to a slot in an operator's runtime
+// dispatch table, and you can directly register kernels to a runtime dispatch
+// key.
+//
+// For per-backend functionalities like "Dense" or "AutogradFunctionality",
+// you can think of the corresponding runtime dispatch keys as "instances" of
+// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
+// runtime instances of the "Dense" building block key.
+
+// (2a) and (2b) are represented identically in the DispatchKeySet logic:
+// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
+//   customizable per backend.
+//   In order to do so, we'd need to promote it to a per-backend functionality
+//   "building block" key.
+// - non-customizable backends (e.g. FPGA) can NOT customize existing
+//   functionality like Sparse, Autograd, etc.
+//   In order to do so, we'd need to promote it to a backend "building block"
+//   key.
+//
+// In both cases, these keys directly correspond to runtime slots in the
+// operator table.
+//
+//
+// (3) "Alias" keys
+// See Note [Alias Dispatch Keys]
+//
+// Final note: for anyone making future changes to the Dispatcher +
+// DispatchKeySet internals, there's a closed PR with a basic
+// Python implementation of the Dispatcher that might be useful in quickly
+// testing out and validating changes. See it at
+// https://github.com/pytorch/pytorch/pull/68743
+
 // An undefined tensor is one with an empty tensor type set.
 class DispatchKeySet final {
  public:
@@ -41,29 +158,146 @@ class DispatchKeySet final {
   // NB: default constructor representation as zero is MANDATORY as
   // use of DispatchKeySet in TLS requires this.
   constexpr DispatchKeySet() : repr_(0) {}
+
   constexpr DispatchKeySet(Full)
-      : repr_(std::numeric_limits::max()) {}
+      : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}
+
   constexpr DispatchKeySet(FullAfter, DispatchKey t)
   // LSB after t are OK, but not t itself.
-      : repr_((1ULL << (static_cast(t) - 1)) - 1) {}
+      // "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
+      // Quantized > Dense). But backends don't really have an ordering.
+      // Therefore, we're enforcing that FullAfter can only be used on
+      // "functionality" keys.
+      : repr_(
+            (1ULL
+             << (num_backends + static_cast(toFunctionalityKey(t)) -
+                 1)) -
+            1) {}
+
+  // Public version of DispatchKeySet(uint64_t) API; external users
+  // must be explicit when they do this! 
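+  // e.g. DispatchKeySet(DispatchKeySet::RAW, full_backend_mask) builds a set
+  // with every backend bit and no functionality bits; this is how
+  // autogradother_backends below pulls in all of the backend bits at once.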
constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {} - explicit constexpr DispatchKeySet(DispatchKey t) - : repr_( - t == DispatchKey::Undefined - ? 0 - : 1ULL << (static_cast(t) - 1)) {} - explicit constexpr DispatchKeySet(std::initializer_list ks) - : repr_(0) { + + constexpr explicit DispatchKeySet(BackendComponent k) { + if (k == BackendComponent::InvalidBit) { + repr_ = 0; + } else { + repr_ = 1ULL << (static_cast(k) - 1); + } + } + + constexpr explicit DispatchKeySet(DispatchKey k) { + if (k == DispatchKey::Undefined) { + // Case 1: handle Undefined specifically + repr_ = 0; + } else if (k <= DispatchKey::EndOfFunctionalityKeys) { + // Case 2: handle "functionality-only" keys + // These keys have a functionality bit set, but no backend bits + // These can technically be either: + // - valid runtime keys (e.g. DispatchKey::AutogradOther, + // DispatchKey::FuncTorchBatched, etc) + // - "building block" keys that aren't actual runtime keys (e.g. + // DispatchKey::Dense or Sparse) + uint64_t functionality_val = 1ULL + << (num_backends + static_cast(k) - 1); + repr_ = functionality_val; + } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) { + // Case 3: "runtime" keys that have a functionality bit AND a backend bit. + // First compute which bit to flip for the functionality. + auto functionality_k = toFunctionalityKey(k); + // The - 1 is because Undefined is technically a "functionality" that + // doesn't show up in the bitset. So e.g. Dense is technically the second + // functionality, but the lowest functionality bit. + uint64_t functionality_val = 1ULL + << (num_backends + static_cast(functionality_k) - 1); + + // then compute which bit to flip for the backend + // Case 4a: handle the runtime instances of "per-backend functionality" + // keys For example, given DispatchKey::CPU, we should set: + // - the Dense functionality bit + // - the CPUBit backend bit + // first compute which bit to flip for the backend + auto backend_k = toBackendComponent(k); + uint64_t backend_val = backend_k == BackendComponent::InvalidBit + ? 0 + : 1ULL << (static_cast(backend_k) - 1); + repr_ = functionality_val + backend_val; + } else { + // At this point, we should have covered every case except for alias keys. + // Technically it would be possible to add alias dispatch keys to a + // DispatchKeySet, but the semantics are a little confusing and this + // currently isn't needed anywhere. + repr_ = 0; + } + } + + constexpr uint64_t keys_to_repr(std::initializer_list ks) { + uint64_t repr = 0; + for (auto k : ks) { + repr |= DispatchKeySet(k).repr_; + } + return repr; + } + + constexpr uint64_t backend_bits_to_repr( + std::initializer_list ks) { + uint64_t repr = 0; for (auto k : ks) { - repr_ |= DispatchKeySet(k).repr_; + repr |= DispatchKeySet(k).repr_; } + return repr; } + + explicit constexpr DispatchKeySet(std::initializer_list ks) + : repr_(keys_to_repr(ks)) {} + + explicit constexpr DispatchKeySet(std::initializer_list ks) + // Note: for some reason, putting this logic directly in the constructor + // appears to fail to compile on CUDA 10.1. 
+ // See an example internal failure at + // https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr + : repr_(backend_bits_to_repr(ks)) {} + // Test if a DispatchKey is in the set - bool inline has(DispatchKey t) const { + inline bool has(DispatchKey t) const { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined); - return static_cast(repr_ & DispatchKeySet(t).repr_); + return has_all(DispatchKeySet(t)); + } + constexpr bool has_backend(BackendComponent t) const { + return has_all(DispatchKeySet(t)); + } + + // Test if a DispatchKey is in the set + // Given a DispatchKeySet of functionality keys and (potentially) backend + // keys, tests if all of them are in the current set. + constexpr bool has_all(DispatchKeySet ks) const { + return static_cast((repr_ & ks.repr_) == ks.repr_); + } + + // Given a DispatchKeySet of functionality keys and (potentially) backend + // keys, tests if any of them are in the current set. This could technically + // be pretty easily implemented using has(). It is strictly a perf + // optimization though. There are many places in the code base where we want + // to test for multiple functionality keys together. HOWEVER, runtime + // per-backend functionality keys aren't allowed to be used with this + // function, because you can end up with weird results. e.g. + // DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU)) + // would return true. + inline bool has_any(DispatchKeySet ks) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + // Either there are no backend bits in the input keyset + ((ks.repr_ & full_backend_mask) == 0) || + // or there are no per-backend-functionality bits + // See [Note: Per-Backend Functionality Dispatch Keys] + ((ks & + DispatchKeySet({ + DispatchKey::Dense, + DispatchKey::Quantized, + DispatchKey::Sparse, + DispatchKey::AutogradFunctionality, + }) + .repr_) == 0)); + return static_cast((repr_ & ks.repr_) != 0); } // Test if DispatchKeySet is a superset of ks. bool isSupersetOf(DispatchKeySet ks) const { @@ -74,31 +308,64 @@ class DispatchKeySet final { return DispatchKeySet(repr_ | other.repr_); } // Perform set intersection - DispatchKeySet operator&(DispatchKeySet other) const { + constexpr DispatchKeySet operator&(DispatchKeySet other) const { return DispatchKeySet(repr_ & other.repr_); } - // Compute the set difference self - other + // Compute the set difference self - other, + // but ONLY for the functionality keys. + // Any backend bits set on self will remain unchanged. + // See Note [Removing keys from DispatchKeySet Only Affects Functionality + // Keys] DispatchKeySet operator-(DispatchKeySet other) const { - return DispatchKeySet(repr_ & ~other.repr_); + return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_)); } + // Compute self ^ other constexpr DispatchKeySet operator^(DispatchKeySet other) const { return DispatchKeySet(repr_ ^ other.repr_); } - // Perform set equality bool operator==(DispatchKeySet other) const { return repr_ == other.repr_; } + bool operator!=(DispatchKeySet other) const { + return repr_ != other.repr_; + } // Add a DispatchKey to the DispatchKey set. Does NOT mutate, // returns the extended DispatchKeySet! C10_NODISCARD DispatchKeySet add(DispatchKey t) const { return *this | DispatchKeySet(t); } - // Remove a DispatchKey from the DispatchKey set. 
This is - // generally not an operation you should be doing (it's - // used to implement operator<<) - C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const { - return DispatchKeySet(repr_ & ~DispatchKeySet(t).repr_); + C10_NODISCARD DispatchKeySet add(DispatchKeySet ks) const { + return *this | ks; + } + + // Remove a DispatchKey from the DispatchKey set. + // This is generally not an operation you should be doing + // (it's used to implement the printing overload, operator<<) + // + // Note [Removing keys from DispatchKeySet Only Affects Functionality Keys] + // Only functionality bits are allowed to be removed from a keyset. + // For now, we're only allowing removal of "functionality bits" from the + // keyset, which is specifically needed by the fallthrough key calculation + // logic. Why is removing backend bits problematic? Consider this example: + // + // DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA, + // DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA) + // DispatchKeySet([DispatchKey.CPU, + // DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA) + // + // What do we want to happen? + // Technically, we'd like it to be true that after removal, + // the first keyset still has the CUDA dispatch key while the second doesn't. + // Unfortunately there's no way to represent that, because the two keysets are + // represented the same way internally: functionality bits: Autograd, Dense + // backend bits: CPU, CUDA + // + // Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd" + // bit from the bitset. + constexpr DispatchKeySet remove(DispatchKey t) const { + return DispatchKeySet( + repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask)); } // Is the set empty? (AKA undefined tensor) bool empty() const { @@ -107,22 +374,112 @@ class DispatchKeySet final { uint64_t raw_repr() { return repr_; } - // Return the type id in this set with the highest priority (i.e., - // is the largest in the DispatchKey enum). Intuitively, this - // type id is the one that should handle dispatch (assuming there - // aren't any further exclusions or inclusions). + + DispatchKey highestFunctionalityKey() const { + auto functionality_idx = indexOfHighestBit(); + // This means that none of the functionality bits were set. + if (functionality_idx < num_backends) + return DispatchKey::Undefined; + // The first num_backend bits in the keyset don't correspond to real + // dispatch keys. + return static_cast(functionality_idx - num_backends); + } + + // This is similar like toBackendComponent(DispatchKey), but less restrictive. + // toBackendComponent() errors out if the key that it was passed has no + // backend bits, which is useful for error checking. We need a version of that + // here that can also handle "fake" backends like FPGA, because they need to + // map to the AutogradOther key. For those backends, we return + // BackendComponent::InvalidBit. + BackendComponent highestBackendKey() const { + // mask to mask out functionality bits + auto backend_idx = + DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit(); + // all zeros across the backend bits means that no backend bits are set. + if (backend_idx == 0) + return BackendComponent::InvalidBit; + return static_cast(backend_idx); + } + + // returns the DispatchKey of highest priority in the set. DispatchKey highestPriorityTypeId() const { - // TODO: If I put Undefined as entry 64 and then adjust the - // singleton constructor to shift from the right, we can get rid of the - // subtraction here. 
It's modestly more complicated to get right so I - // didn't do it for now. - return static_cast(64 - llvm::countLeadingZeros(repr_)); + auto functionality_k = highestFunctionalityKey(); + if (isPerBackendFunctionalityKey(functionality_k)) { + return toRuntimePerBackendFunctionalityKey( + functionality_k, highestBackendKey()); + } + return functionality_k; + } + + // Returns the index of the most-significant bit in the keyset. + // This is used to as part of the calculation into the operator table to get: + // - the highest "functionality" bit in the keyset. + // - the highest "backend" bit in the keyset. + uint8_t indexOfHighestBit() const { + return 64 - llvm::countLeadingZeros(repr_); } - DispatchKey highestPriorityBackendTypeId() const { - return (*this & - ((1ULL << static_cast(DispatchKey::EndOfBackendKeys)) - 1)) - .highestPriorityTypeId(); +#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS) + // [Note: Trimmed Mobile Dispatch Keys] + /** + * The method below maps the dispatch key in the enum DispatchKey to an + * integer index in the dispatchTable_ array in OperatorEntry. The array + * is trimmed for mobile to reduce peak memory usage since it's + * unnecessary to reserve additional space for dispatch keys that will + * never be used on mobile. + */ + int getDispatchTableIndexForDispatchKeySet() const { + auto dk = highestPriorityTypeId(); + switch (dk) { + case DispatchKey::Undefined: + return 0; + case DispatchKey::CPU: + return 1; + case DispatchKey::QuantizedCPU: + return 2; + case DispatchKey::SparseCPU: + return 3; + case DispatchKey::BackendSelect: + return 4; + case DispatchKey::ADInplaceOrView: + return 5; + case DispatchKey::AutogradOther: + return 6; + case DispatchKey::AutogradCPU: + return 7; + default: + return -1; + } + } +#else + // returns the index in the operator table of highest priority key in the the + // keyset Note that we could in theory implement this using + // highestPriorityTypeId(), but this code is very hotpath and we can do it + // faster without it. + int getDispatchTableIndexForDispatchKeySet() const { + auto functionality_idx = + DispatchKeySet(repr_ >> num_backends).indexOfHighestBit(); + auto offset_and_mask = offsetsAndMasks()[functionality_idx]; + // Mask the functionality bits out first, then right-shift by 1. + // right-shifting by 1 because everything is zero-indexed. + // E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should + // give us an offset of 1, etc. + auto backend_idx = + DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit(); + return offset_and_mask.offset + backend_idx; + } +#endif + + // returns the "index" of the highest priority backend in the keyset. + // This is pretty similar to getBackendKey(), but: + // - It's hotpath code (part of the runtime bitset calculation) + // - I's returns an integer index, not an enum value + // - Everything is shifted to the right by 1. + // BackendComponent::InvalidBit is technically the lowest enum value, + // but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2, + // etc. + uint64_t getBackendIndex() const { + return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit(); } private: @@ -130,42 +487,53 @@ class DispatchKeySet final { uint64_t repr_ = 0; public: - // STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the - // set. The iterator is only invalidated by the destruction of the underlying - // DispatchKeySet as the iterator stores a pointer to the raw representation - // of the DispatchKeySet. 
+ // STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys + // in the set. The iterator is only invalidated by the destruction of the + // underlying DispatchKeySet as the iterator stores a pointer to the raw + // representation of the DispatchKeySet. Note: When we encounter a per-backend + // functionality (e.g. Dense or Sparse), we will iterate through EVERY backend + // in the keyset, for that functionality. For example, if the next + // functionality key to iterate over is Autograd, and the backend bits in the + // keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit], + // then the next two keys we return will be DispatchKey::AutogradCPU, + // DispatchKey::AutogradCUDA (CPU first because it has lower precedence than + // CUDA in DispatchKey.h). class iterator { public: using self_type = iterator; using iterator_category = std::input_iterator_tag; using value_type = DispatchKey; using difference_type = ptrdiff_t; - - explicit iterator(const uint64_t* data_ptr, uint8_t i = 0) - : data_ptr_(data_ptr), i_(i) { + // final mask value should mask out the entire keyset + static const uint8_t end_iter_mask_val = + num_backends + num_functionality_keys; + // final key value should be the last DispatchKey + static const uint8_t end_iter_key_val = num_functionality_keys; + + // current_dispatchkey_idx_ will iterate through all functionality bits. + // current_backendcomponent_idx_ will iterate through all backend bits. + explicit iterator( + const uint64_t* data_ptr, + uint8_t next_functionality = num_backends, + uint8_t next_backend = 0) + : data_ptr_(data_ptr), + next_functionality_(next_functionality), + next_backend_(next_backend), + // These are in an invalid state at construction time, and set by the + // first increment call + current_dispatchkey_idx_(end_iter_key_val), + current_backendcomponent_idx_(end_iter_key_val) { // Go to the first key in the set + TORCH_INTERNAL_ASSERT( + next_functionality_ >= num_backends, + "num_backends=", + static_cast(num_backends), + "next_functionality_=", + static_cast(next_functionality_)); ++(*this); } - self_type& operator++() { - TORCH_INTERNAL_ASSERT( - i_ <= static_cast(DispatchKey::NumDispatchKeys)); - - // Create a masked version of the set representation to ignore previous - // keys that we've iterated through. 
- uint64_t masked_data = llvm::maskTrailingZeros(i_) & *data_ptr_; - uint64_t firstKeyIndex = llvm::findFirstSet(masked_data); - - // If there are no keys, set to end iterator value - if (firstKeyIndex == std::numeric_limits::max() || - i_ == static_cast(DispatchKey::NumDispatchKeys)) { - i_ = static_cast(DispatchKey::NumDispatchKeys); - return *this; - } - - i_ = static_cast(firstKeyIndex) + 1; - return *this; - } + C10_API self_type& operator++(); self_type operator++(int) { self_type previous_iterator = *this; @@ -174,18 +542,50 @@ class DispatchKeySet final { } bool operator==(const self_type& rhs) const { - return i_ == rhs.i_; + return next_functionality_ == rhs.next_functionality_ && + current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ && + next_backend_ == rhs.next_backend_ && + current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_; } bool operator!=(const self_type& rhs) const { - return i_ != rhs.i_; + return next_functionality_ != rhs.next_functionality_ || + current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ || + next_backend_ != rhs.next_backend_ || + current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_; } DispatchKey operator*() const { - return static_cast(i_); + auto functionality_key = + static_cast(current_dispatchkey_idx_); + if (isPerBackendFunctionalityKey(functionality_key)) { + auto next_key = toRuntimePerBackendFunctionalityKey( + functionality_key, + static_cast(current_backendcomponent_idx_)); + // We expect all of the Dense, Sparse, Quantized, and Autograd keys to + // be ordered the same way with respect to their backends + TORCH_INTERNAL_ASSERT( + toBackendComponent(next_key) == + static_cast(current_backendcomponent_idx_), + "Tried to map functionality key ", + toString(functionality_key), + " and backend bit ", + toString( + static_cast(current_backendcomponent_idx_)), + " to a runtime key, but ended up with ", + toString(next_key), + ". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.", + " Please double check that enum for inconsistencies."); + return next_key; + } else { + return functionality_key; + } } private: const uint64_t* data_ptr_; - uint8_t i_; + uint8_t next_functionality_; + uint8_t next_backend_; + uint8_t current_dispatchkey_idx_; + uint8_t current_backendcomponent_idx_; }; public: @@ -195,31 +595,35 @@ class DispatchKeySet final { return iterator(&repr_); } - // We do not need to iterate beyond NumDispatchKeys so we will treat this as - // the end iterator. NumDispatchKeys will always be strictly less than 64. + // We do not need to iterate beyond EndOfFunctionalityKeys so we will treat + // this as the end iterator. iterator end() const { - return iterator(&repr_, static_cast(DispatchKey::NumDispatchKeys)); + return iterator(&repr_, iterator::end_iter_mask_val); } }; C10_API std::string toString(DispatchKeySet); C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet); -// autograd_dispatch_keyset should include all runtime autograd keys. -// Alias key DispatchKey::Autograd maps to autograd_dispatch_keyset. 
+C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) { + return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet(); +} + +// Alias key DispatchKey::Autograd maps to +// (autograd_dispatch_keyset x full_backend_mask) // NB: keys in this set also get associated with CompositeImplicitAutograd +// +// Note [autograd_dispatch_keyset Does Not Include Backend Bits] +// We don't want to include any backend bits (BackendComponent::CPUBit, etc) +// directly in autograd_dispatch_keyset. +// Why? keysets like autograd_dispatch_keyset are commonly used to remove +// autograd keys from a DispatchKeySet throughout the code base. However, you +// are only allowed to remove functionality bits from a keyset, not backend +// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality +// Keys] for details. To be consistent and avoid confusion, we're explicitly +// setting up autograd_dispatch_keyset to not have any backend bits. constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ - DispatchKey::AutogradCPU, - DispatchKey::AutogradCUDA, - DispatchKey::AutogradXLA, - DispatchKey::AutogradLazy, - DispatchKey::AutogradNestedTensor, - DispatchKey::AutogradMLC, - DispatchKey::AutogradHPU, - DispatchKey::AutogradXPU, - DispatchKey::AutogradPrivateUse1, - DispatchKey::AutogradPrivateUse2, - DispatchKey::AutogradPrivateUse3, + DispatchKey::AutogradFunctionality, DispatchKey::AutogradOther, }); @@ -244,25 +648,28 @@ constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView = // backend dispatch keys that map to DispatchKey::AutogradOther // NB: keys in this set also get associated with CompositeImplicitAutograd -constexpr DispatchKeySet autogradother_backends = DispatchKeySet( - {DispatchKey::HIP, - DispatchKey::VE, - DispatchKey::FPGA, - DispatchKey::ORT, - DispatchKey::Vulkan, - DispatchKey::Metal, - DispatchKey::QuantizedCPU, - DispatchKey::QuantizedCUDA, - DispatchKey::CustomRNGKeyId, - DispatchKey::MkldnnCPU, - DispatchKey::SparseCPU, - DispatchKey::SparseCUDA, - DispatchKey::SparseHIP, - DispatchKey::SparseVE, - DispatchKey::SparseXPU, - DispatchKey::SparseCsrCPU, - DispatchKey::SparseCsrCUDA, - DispatchKey::Meta}); +constexpr DispatchKeySet autogradother_backends = + DispatchKeySet( + // HIP and VE aren't in this list: they now have their own backend bits + // which means that they can now have their own Autograd keys. + // Technically, HIP will now redispatch to its own custom AutogradHIP + // slot in the runtime table. + {DispatchKey::FPGA, + DispatchKey::ORT, + DispatchKey::Vulkan, + DispatchKey::Metal, + DispatchKey::SparseCsrCPU, + DispatchKey::SparseCsrCUDA, + DispatchKey::CustomRNGKeyId, + DispatchKey::MkldnnCPU, + DispatchKey::Meta, + // Sparse and Quantized backends also live here. + DispatchKey::Sparse, + DispatchKey::Quantized}) + // Including the backend bits because this keyset is used during op + // registration, which requires looping over all runtime autogradother + // backend keys. + | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); // The set of dispatch keys that come after autograd // n.b. this relies on the fact that AutogradOther is currently the lowest @@ -292,6 +699,36 @@ constexpr DispatchKeySet after_func_keyset = // away with it by explicitly removing the key here. 
c10::DispatchKey::ADInplaceOrView); +constexpr DispatchKeySet backend_bitset_mask = + DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1); + +constexpr auto inplace_or_view_ks = + DispatchKeySet(DispatchKey::ADInplaceOrView); +constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU); +constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU); +constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA); +constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA); +constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy); +constexpr auto autograd_mlc_ks = DispatchKeySet(DispatchKey::AutogradMLC); +constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU); +constexpr auto autograd_privateuse1_ks = + DispatchKeySet(DispatchKey::AutogradPrivateUse1); +constexpr auto autograd_privateuse2_ks = + DispatchKeySet(DispatchKey::AutogradPrivateUse2); +constexpr auto autograd_privateuse3_ks = + DispatchKeySet(DispatchKey::AutogradPrivateUse3); +constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther); + +struct OpTableOffsetAndMask { + uint16_t offset; + uint16_t backend_mask; +}; + +static_assert( + num_backends <= 16, + "Right now we expect the number of backends not to exceed 16. In the (unlikely) event" + " that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too."); + // true if t is a backend dispatch key C10_API bool isBackendDispatchKey(DispatchKey t); @@ -307,10 +744,53 @@ C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k); C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t); // Returns a DispatchKeySet of autograd related keys mapped to backend. -C10_API DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t); +// for a given backend key, use the associated autograd key. +// for non-backend keys, use AutogradOther as a default. +// Note: it's convenient and fast to return a default here rather than (say) +// returning an optional, or throwing. But it makes callers +// responsible for either a) enforcing the invariant that only backend keys +// be passed as arguments, or b) interpreting our return value carefully. +inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) { + switch (t) { + case BackendComponent::CPUBit: + return inplace_or_view_ks | autograd_cpu_ks; + case BackendComponent::XPUBit: + return inplace_or_view_ks | autograd_xpu_ks; + case BackendComponent::CUDABit: + return inplace_or_view_ks | autograd_cuda_ks; + case BackendComponent::XLABit: + return inplace_or_view_ks | autograd_xla_ks; + case BackendComponent::LazyBit: + return inplace_or_view_ks | autograd_lazy_ks; + case BackendComponent::MLCBit: + return inplace_or_view_ks | autograd_mlc_ks; + case BackendComponent::HPUBit: + return inplace_or_view_ks | autograd_hpu_ks; + case BackendComponent::PrivateUse1Bit: + return inplace_or_view_ks | autograd_privateuse1_ks; + case BackendComponent::PrivateUse2Bit: + return inplace_or_view_ks | autograd_privateuse2_ks; + case BackendComponent::PrivateUse3Bit: + return inplace_or_view_ks | autograd_privateuse3_ks; + default: + return inplace_or_view_ks | autograd_other_ks; + } +} // Returns a DispatchKeySet of autocast related keys mapped to backend. 
-C10_API DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t); +inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { + constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU); + constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA); + switch (t) { + case BackendComponent::CPUBit: + return autocast_cpu_ks; + case BackendComponent::CUDABit: + case BackendComponent::XLABit: + return autocast_cuda_ks; + default: + return DispatchKeySet(); + } +} // This API exists because we have a use case for checking // getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined) diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index bed86ca257cc00..158dfc590e496f 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -192,7 +192,7 @@ TensorImpl::TensorImpl( // TODO: be more explicit about the full key set at call sites so we // don't have to keep recomputing it here - DispatchKey k = key_set.highestPriorityBackendTypeId(); + auto k = key_set.highestBackendKey(); key_set = key_set | getAutocastRelatedKeySetFromBackend(k); diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 508665583186ec..bec47f3acba714 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -854,10 +854,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { bool is_sparse() const { // NB: This method is not virtual and avoid dispatches for performance // reasons. - return key_set_.has(DispatchKey::SparseCPU) || - key_set_.has(DispatchKey::SparseCUDA) || - key_set_.has(DispatchKey::SparseHIP) || - key_set_.has(DispatchKey::SparseXPU); + return key_set_.has(DispatchKey::Sparse); } // Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR @@ -870,9 +867,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { bool is_quantized() const { // NB: This method is not virtual and avoid dispatches for performance // reasons. - return key_set_.has(DispatchKey::QuantizedCPU) || - key_set_.has(DispatchKey::QuantizedCUDA) || - key_set_.has(DispatchKey::QuantizedXPU); + return key_set_.has(DispatchKey::Quantized); } bool is_meta() const { @@ -884,53 +879,46 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { bool is_cpu() const { // NB: This method is not virtual and avoid dispatches for performance // reasons. - return key_set_.has(DispatchKey::CPU) || - key_set_.has(DispatchKey::SparseCPU) || + return key_set_.has_backend(BackendComponent::CPUBit) || key_set_.has(DispatchKey::SparseCsrCPU) || - key_set_.has(DispatchKey::QuantizedCPU) || key_set_.has(DispatchKey::MkldnnCPU); } bool is_cuda() const { // NB: This method is not virtual and avoid dispatches for performance // reasons. - return key_set_.has(DispatchKey::CUDA) || - key_set_.has(DispatchKey::SparseCUDA) || - key_set_.has(DispatchKey::SparseCsrCUDA) || - key_set_.has(DispatchKey::QuantizedCUDA); + return key_set_.has_backend(BackendComponent::CUDABit) || + key_set_.has(DispatchKey::SparseCsrCUDA); } bool is_xpu() const { // NB: This method is not virtual and avoid dispatches for performance // reasons. 
- return key_set_.has(DispatchKey::XPU) || - key_set_.has(DispatchKey::SparseXPU) || - key_set_.has(DispatchKey::QuantizedXPU); + return key_set_.has_backend(BackendComponent::XPUBit); } bool is_xla() const { - return key_set_.has(DispatchKey::XLA); + return key_set_.has_backend(BackendComponent::XLABit); } bool is_hpu() const { - return key_set_.has(DispatchKey::HPU); + return key_set_.has_backend(BackendComponent::HPUBit); } bool is_lazy() const { - return key_set_.has(DispatchKey::Lazy); + return key_set_.has_backend(BackendComponent::LazyBit); } bool is_hip() const { // NB: This method is not virtual and avoid dispatches for performance // reasons. - return key_set_.has(DispatchKey::HIP) || - key_set_.has(DispatchKey::SparseHIP); + return key_set_.has_backend(BackendComponent::HIPBit); } bool is_ve() const { // NB: This method is not virtual and avoid dispatches for performance // reasons. - return key_set_.has(DispatchKey::VE) || key_set_.has(DispatchKey::SparseVE); + return key_set_.has_backend(BackendComponent::VEBit); } bool is_mkldnn() const { @@ -1570,13 +1558,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { auto is_dense = [](DispatchKeySet ts) { - return ts.has(DispatchKey::CPU) || ts.has(DispatchKey::CUDA) || - ts.has(DispatchKey::HIP) || ts.has(DispatchKey::XPU); + constexpr auto dense_backends = DispatchKeySet( + {BackendComponent::CPUBit, + BackendComponent::CUDABit, + BackendComponent::HIPBit, + BackendComponent::XPUBit}); + constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense); + return ts.has_any(dense_k) && ts.has_any(dense_backends); }; auto is_sparse = [](DispatchKeySet ts) { - return ts.has(DispatchKey::SparseCPU) || - ts.has(DispatchKey::SparseCUDA) || ts.has(DispatchKey::SparseHIP) || - ts.has(DispatchKey::SparseXPU); + constexpr auto sparse_backends = DispatchKeySet( + {BackendComponent::CPUBit, + BackendComponent::CUDABit, + BackendComponent::HIPBit, + BackendComponent::XPUBit}); + constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse); + return ts.has_any(sparse_k) && ts.has_any(sparse_backends); }; return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); diff --git a/c10/test/core/DispatchKeySet_test.cpp b/c10/test/core/DispatchKeySet_test.cpp index 43b06c110e5bac..2c0de14405d0b8 100644 --- a/c10/test/core/DispatchKeySet_test.cpp +++ b/c10/test/core/DispatchKeySet_test.cpp @@ -3,25 +3,163 @@ #include #include +#include using namespace c10; +// This test exists not to be comprehensive, but to more clearly show +// what the semantics of DispatchKeySet are. +TEST(DispatchKeySet, ShowSemantics) { + // the "CPU" dispatch key is an instance of a per-backend-functionality key. + // It corresponds to "dense" functionality, "CPU" backend. + // This means that it gets a dense functionality bit, and a cpu backend bit + // set. 
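+  // In other words, DispatchKeySet(DispatchKey::CPU) should behave the same as
+  // DispatchKeySet(DispatchKey::Dense) | DispatchKeySet(BackendComponent::CPUBit),
+  // which the assertions below spell out.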
+  auto undefined_set = DispatchKeySet();
+  auto dense_cpu_set = DispatchKeySet(DispatchKey::CPU);
+  ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
+  ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
+  ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
+
+  auto dense_lazy_set = DispatchKeySet(DispatchKey::Lazy);
+  ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Dense));
+  ASSERT_TRUE(dense_lazy_set.has_backend(BackendComponent::LazyBit));
+  ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Lazy));
+
+  // You can think of "Dense/Sparse", and "CPUBit/CUDABit", as "building block"
+  // dispatch keys. You are allowed to directly create keysets out of them!
+  auto dense_cpu_set_from_building_blocks = DispatchKeySet(DispatchKey::Dense) |
+      DispatchKeySet(BackendComponent::CPUBit);
+  ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
+  ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
+  ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
+  ASSERT_EQ(dense_cpu_set, dense_cpu_set_from_building_blocks);
+
+  // Similarly, the AutogradCUDA key gets 2 bits in the keyset:
+  // the "Autograd" functionality bit, and the "CUDA" backend bit.
+  auto autograd_cuda = DispatchKeySet(DispatchKey::AutogradCUDA);
+  ASSERT_TRUE(autograd_cuda.has(DispatchKey::AutogradFunctionality));
+  ASSERT_TRUE(autograd_cuda.has_backend(BackendComponent::CUDABit));
+
+  // Because DispatchKeySet uses a condensed internal representation, you cannot
+  // use it to represent the FULL cross product of backends and functionalities,
+  // for example:
+  auto autograd_dense_cpu_cuda = DispatchKeySet(
+      {DispatchKey::AutogradFunctionality,
+       DispatchKey::Dense,
+       DispatchKey::CUDA,
+       DispatchKey::CPU});
+  auto fpga = DispatchKeySet(DispatchKey::FPGA);
+  auto fpga_and_cpu = DispatchKeySet({DispatchKey::FPGA, DispatchKey::CPU});
+  // this keyset has all of the building block keys:
+  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradFunctionality));
+  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::Dense));
+  ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CUDABit));
+  ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CPUBit));
+
+  // and it also has the "runtime" keys that correspond to the full
+  // cross-product of functionality
+  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU));
+  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCUDA));
+  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CPU));
+  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CUDA));
+
+  // This means that there's no way to represent a keyset with, say, only
+  // Autograd CUDA + Dense CPU. Instead, you should think of a keyset as
+  // inheriting the full set of functionalities + backends of its keys. This
+  // means that the below keysets are all indistinguishable from each other. 
+ ASSERT_EQ( + autograd_dense_cpu_cuda, + DispatchKeySet( + {DispatchKey::AutogradCUDA, + DispatchKey::AutogradCPU, + DispatchKey::CUDA, + DispatchKey::CPU})); + ASSERT_EQ( + autograd_dense_cpu_cuda, + DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CPU})); + ASSERT_EQ( + autograd_dense_cpu_cuda, + DispatchKeySet({DispatchKey::CUDA, DispatchKey::AutogradCPU})); + + // ~~~~~~~~~~ DispatchKeySet iterators ~~~~~~~~~~~ + + // Iterators allow you to iterate individually through the DispatchKey's in a + // DispatchKeySet + auto empty_set = DispatchKeySet(); + auto t1 = empty_set.begin(); + auto t2 = empty_set.end(); + ASSERT_EQ(*empty_set.begin(), *empty_set.end()); + + // However, only keys that correspond to actual runtime indices of kernels in + // the operator table show up when you iterate through a keyset. i.e. + // DispatchKey::Dense, and BackendComponent::CPUBit won't show up in an + // iterator. + auto dense_cpu_iter = dense_cpu_set.begin(); + ASSERT_EQ(*dense_cpu_iter++, DispatchKey::CPU); + ASSERT_EQ(*dense_cpu_iter, *dense_cpu_set.end()); + + auto autograd_dense_cpu_cuda_iter = autograd_dense_cpu_cuda.begin(); + ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CPU); + ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CUDA); + ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCPU); + ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCUDA); + ASSERT_EQ(*autograd_dense_cpu_cuda_iter, *autograd_dense_cpu_cuda.end()); + + // But other "functionality bits" that are not defined per-backend DO get + // their own slots in the operator table. + auto mixed_keyset = DispatchKeySet(BackendComponent::CPUBit) | + DispatchKeySet( + {DispatchKey::FPGA, // runtime key + DispatchKey::Functionalize, // runtime key + DispatchKey::Dense}); // NOT a runtime key + auto mixed_iter = mixed_keyset.begin(); + ASSERT_EQ(*mixed_iter++, DispatchKey::CPU); + ASSERT_EQ(*mixed_iter++, DispatchKey::FPGA); + ASSERT_EQ(*mixed_iter++, DispatchKey::Functionalize); + ASSERT_EQ(*mixed_iter, *mixed_keyset.end()); +} + TEST(DispatchKeySet, Empty) { DispatchKeySet empty_set; - for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); + for (uint8_t i = 0; + i <= static_cast(DispatchKey::EndOfRuntimeBackendKeys); i++) { auto tid = static_cast(i); + if (tid == DispatchKey::Undefined) + continue; ASSERT_FALSE(empty_set.has(tid)); } ASSERT_TRUE(empty_set.empty()); DispatchKeySet empty_set2; ASSERT_TRUE(empty_set == empty_set2); - ASSERT_EQ(empty_set.highestPriorityTypeId(), DispatchKey::Undefined); } -TEST(DispatchKeySet, Singleton) { - for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); - i++) { +// This covers all keys that correspond to a single backend bit, e.g. +// BackendComponent::CPUBit. Even though these are NOT runtime keys, we still +// allow adding them directly to a keyset +TEST(DispatchKeySet, SingletonBackendComponent) { + for (const auto i : c10::irange(1, num_backends)) { + auto tid = static_cast(i); + DispatchKeySet sing(tid); + ASSERT_EQ(sing, sing); + ASSERT_EQ(sing, DispatchKeySet().add(tid)); + ASSERT_EQ(sing, sing.add(tid)); + ASSERT_EQ(sing, sing | sing); + ASSERT_FALSE(sing.empty()); + ASSERT_TRUE(sing.has(tid)); + } +} + +// This covers all keys that correspond to a single functionality bit: +// - runtime, not-per-backend functionality keys, e.g. +// DispatchKey::FuncTorchBatched +// - runtime, "fake backend" keys, e.g. DispatchKey::FPGA +// - NOT-runtime, per-backend functionality keys, e.g. 
DispatchKey::Dense +// Even though it's not a runtime key, we still allow adding it directly to a +// keyset. +// DispatchKey:: +TEST(DispatchKeySet, SingletonFunctionalityKeys) { + for (const auto i : c10::irange(1, num_functionality_keys)) { auto tid = static_cast(i); DispatchKeySet sing(tid); ASSERT_EQ(sing, sing); @@ -30,47 +168,145 @@ TEST(DispatchKeySet, Singleton) { ASSERT_EQ(sing, sing | sing); ASSERT_FALSE(sing.empty()); ASSERT_TRUE(sing.has(tid)); - ASSERT_EQ(sing.highestPriorityTypeId(), tid); ASSERT_EQ(sing.remove(tid), DispatchKeySet()); } } -TEST(DispatchKeySet, Doubleton) { - for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); +// This covers runtime keys that are per-backend, +// and take up more than one bit in a DispatchKeySet. They take up one +// functionality bit + one backend bit. e.g. CPU, CUDA, SparseCPU, SparseCUDA, +// AutogradCPU, AutogradCUDA +TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) { + for (uint8_t i = static_cast(DispatchKey::StartOfDenseBackends); + i <= static_cast(DispatchKey::EndOfRuntimeBackendKeys); + i++) { + auto tid = static_cast(i); + // Skip these because they aren't real keys. + if (tid == DispatchKey::StartOfDenseBackends || + tid == DispatchKey::StartOfSparseBackends || + tid == DispatchKey::StartOfQuantizedBackends || + tid == DispatchKey::StartOfAutogradBackends) { + continue; + } + DispatchKeySet sing(tid); + ASSERT_EQ(sing, sing); + ASSERT_EQ(sing, DispatchKeySet().add(tid)); + ASSERT_EQ(sing, sing.add(tid)); + ASSERT_EQ(sing, sing | sing); + ASSERT_FALSE(sing.empty()); + ASSERT_TRUE(sing.has(tid)); + + auto functionality_key = toFunctionalityKey(tid); + auto backend_key = toBackendComponent(tid); + // These two sets should be equivalent: + // DispatchKeySet(DispatchKey::CPU) + // DispatchKeySet({DispatchKey::Dense, BackendComponent::CPUBit}) + auto expected_ks = + DispatchKeySet(functionality_key) | DispatchKeySet(backend_key); + ASSERT_EQ(sing, expected_ks); + // These two sets should be equivalent: + // DispatchKeySet(DispatchKey::CPU).remove(DispatchKey::Dense) + // DispatchKeySet(BackendComponent::CPUBit) + expected_ks = DispatchKeySet(toBackendComponent(tid)); + ASSERT_EQ(sing.remove(tid), expected_ks); + } +} + +TEST(DispatchKeySet, DoubletonPerBackend) { + for (uint8_t i = static_cast(DispatchKey::StartOfDenseBackends); + i <= static_cast(DispatchKey::EndOfRuntimeBackendKeys); i++) { for (uint8_t j = i + 1; - j < static_cast(DispatchKey::NumDispatchKeys); + j <= static_cast(DispatchKey::EndOfRuntimeBackendKeys); j++) { ASSERT_LT(i, j); auto tid1 = static_cast(i); auto tid2 = static_cast(j); - auto doub = DispatchKeySet(tid1).add(tid2); - ASSERT_EQ(doub, DispatchKeySet(tid1) | DispatchKeySet(tid2)); - ASSERT_TRUE(doub.has(tid1)); - ASSERT_TRUE(doub.has(tid2)); - ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j + + // Skip these because they aren't real keys. 
+      if (tid1 == DispatchKey::StartOfDenseBackends ||
+          tid1 == DispatchKey::StartOfSparseBackends ||
+          tid1 == DispatchKey::StartOfQuantizedBackends ||
+          tid1 == DispatchKey::StartOfAutogradBackends)
+        continue;
+      if (tid2 == DispatchKey::StartOfDenseBackends ||
+          tid2 == DispatchKey::StartOfSparseBackends ||
+          tid2 == DispatchKey::StartOfQuantizedBackends ||
+          tid2 == DispatchKey::StartOfAutogradBackends)
+        continue;
+
+      auto backend1 = toBackendComponent(tid1);
+      auto backend2 = toBackendComponent(tid2);
+      auto functionality1 = toFunctionalityKey(tid1);
+      auto functionality2 = toFunctionalityKey(tid2);
+
+      auto combined = DispatchKeySet({tid1, tid2});
+      // The combined set has the backend bits
+      ASSERT_TRUE(combined.has_backend(backend1));
+      ASSERT_TRUE(combined.has_backend(backend2));
+      // and it has the functionality bits
+      ASSERT_TRUE(combined.has(functionality1));
+      ASSERT_TRUE(combined.has(functionality2));
+      // and it has the original two runtime keys
+      ASSERT_TRUE(combined.has(tid1));
+      ASSERT_TRUE(combined.has(tid2));
+
+      // Add all of the keys in the keyset to a real set
+      std::unordered_set visited_keys;
+      auto iter = combined.begin();
+      while (*iter != *combined.end()) {
+        visited_keys.insert(*iter);
+        ++iter;
+      }
+      std::unordered_set expected_keys;
+      expected_keys.insert(
+          toRuntimePerBackendFunctionalityKey(functionality1, backend1));
+      expected_keys.insert(
+          toRuntimePerBackendFunctionalityKey(functionality1, backend2));
+      expected_keys.insert(
+          toRuntimePerBackendFunctionalityKey(functionality2, backend1));
+      expected_keys.insert(
+          toRuntimePerBackendFunctionalityKey(functionality2, backend2));
+      ASSERT_EQ(expected_keys, visited_keys);
+
+      if (backend1 == backend2 || functionality1 == functionality2) {
+        // We have two runtime keys, with either the same backend or the same
+        // per-backend functionalities. E.g. {AutogradCUDA, CUDA} or
+        // {AutogradCPU, AutogradCUDA}. There should be 2 total runtime keys in
+        // this set.
+        ASSERT_EQ(2, visited_keys.size());
+      } else {
+        // since i and j are different keys, they should not have the same
+        // functionality and backend
+        ASSERT_TRUE(backend1 != backend2 && functionality1 != functionality2);
+        // We have two runtime keys that have different backends + per-backend
+        // functionalities. So we should expect the full cross product of
+        // runtime keys to be in the set. e.g. 
+        // then combined = {AutogradCUDA, AutogradCPU, CUDA, CPU}
+        ASSERT_EQ(4, visited_keys.size());
+      }
     }
   }
 }
 
 TEST(DispatchKeySet, Full) {
   DispatchKeySet full(DispatchKeySet::FULL);
-  for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
-       i++) {
+  for (const auto i : c10::irange(1, num_functionality_keys)) {
     auto tid = static_cast<DispatchKey>(i);
     ASSERT_TRUE(full.has(tid));
   }
+  ASSERT_FALSE(full.has(DispatchKey::EndOfFunctionalityKeys));
 }
 
 TEST(DispatchKeySet, IteratorBasicOps) {
   DispatchKeySet empty_set;
   DispatchKeySet full_set(DispatchKeySet::FULL);
-  DispatchKeySet mutated_set = empty_set.add(static_cast<DispatchKey>(1));
+  DispatchKeySet mutated_set = empty_set.add(DispatchKey::CPU);
 
   // Constructor + Comparison
-  ASSERT_EQ(*empty_set.begin(), DispatchKey::NumDispatchKeys);
-  ASSERT_EQ(*empty_set.end(), DispatchKey::NumDispatchKeys);
-  ASSERT_EQ(*mutated_set.begin(), static_cast<DispatchKey>(1));
+  ASSERT_EQ(*empty_set.begin(), DispatchKey::EndOfFunctionalityKeys);
+  ASSERT_EQ(*empty_set.end(), DispatchKey::EndOfFunctionalityKeys);
+  ASSERT_EQ(*mutated_set.begin(), DispatchKey::CPU);
 
   ASSERT_TRUE(empty_set.begin() == empty_set.end());
   ASSERT_TRUE(full_set.begin() != full_set.end());
@@ -90,16 +326,37 @@ TEST(DispatchKeySet, IteratorEmpty) {
   ASSERT_EQ(i, 0);
 }
 
+TEST(DispatchKeySet, IteratorCrossProduct) {
+  // The iterator should return all runtime keys in the set,
+  // including the cross product of {backends} x {functionalities}
+  auto ks =
+      DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) |
+      DispatchKeySet(
+          {DispatchKey::Dense,
+           DispatchKey::FPGA,
+           DispatchKey::AutogradFunctionality});
+
+  auto iter = ks.begin();
+  // iterate through dense backends first.
+  ASSERT_EQ(DispatchKey::CPU, *(iter++));
+  ASSERT_EQ(DispatchKey::CUDA, *(iter++));
+  // FPGA doesn't have a backend bit, so it isn't included in the cross product.
+  ASSERT_EQ(DispatchKey::FPGA, *(iter++));
+  // iterate through the autograd keys last.
+  ASSERT_EQ(DispatchKey::AutogradCPU, *(iter++));
+  ASSERT_EQ(DispatchKey::AutogradCUDA, *(iter++));
+}
+
 TEST(DispatchKeySet, IteratorFull) {
   DispatchKeySet full_set(DispatchKeySet::FULL);
   uint8_t i = 0;
 
   for (const auto& it : full_set) {
     i++;
-    ASSERT_TRUE(it == static_cast<DispatchKey>(i));
-    ASSERT_TRUE(it != DispatchKey::NumDispatchKeys);
   }
-  ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1);
+  // Total # of runtime entries includes an entry for DispatchKey::Undefined,
+  // which is not included when iterating through the DispatchKeySet.
+  ASSERT_EQ(i, num_runtime_entries - 1);
 }
 
 TEST(DispatchKeySet, IteratorRangeFull) {
@@ -108,41 +365,61 @@ TEST(DispatchKeySet, IteratorRangeFull) {
 
   for (DispatchKey dispatch_key : full_set) {
     i++;
-    ASSERT_TRUE(dispatch_key == static_cast<DispatchKey>(i));
-  }
-
-  ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1);
-}
-
-TEST(DispatchKeySet, SpecificKeys) {
-  DispatchKeySet keyset({
-      static_cast<DispatchKey>(0), // Undefined should be ignored
-      static_cast<DispatchKey>(4),
-      static_cast<DispatchKey>(10),
-      static_cast<DispatchKey>(15),
-  });
-  std::unordered_set<DispatchKey> visited_keys;
-
-  for (DispatchKey key : keyset) {
-    visited_keys.insert(key);
   }
-  ASSERT_EQ(visited_keys.size(), 3);
-  ASSERT_TRUE(
-      visited_keys.find(static_cast<DispatchKey>(4)) != visited_keys.end());
-  ASSERT_TRUE(
-      visited_keys.find(static_cast<DispatchKey>(10)) != visited_keys.end());
-  ASSERT_TRUE(
-      visited_keys.find(static_cast<DispatchKey>(15)) != visited_keys.end());
+  // Total # of runtime entries includes an entry for DispatchKey::Undefined,
+  // which is not included when iterating through the DispatchKeySet.
+  ASSERT_EQ(i, num_runtime_entries - 1);
 }
 
 TEST(DispatchKeySet, FailAtEndIterator) {
   DispatchKeySet full_set(DispatchKeySet::FULL);
   uint64_t raw_repr = full_set.raw_repr();
 
+  // doesn't throw
+  DispatchKeySet::iterator(&raw_repr, num_backends + num_functionality_keys);
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
   EXPECT_THROW(
       DispatchKeySet::iterator(
-          &raw_repr, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) + 1),
+          &raw_repr, num_backends + num_functionality_keys + 1),
       c10::Error);
 }
+
+TEST(DispatchKeySet, TestKeyOrderingInvariants) {
+  for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
+       i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
+       i++) {
+    auto k = static_cast<DispatchKey>(i);
+    // Note [The Ordering of Per-Backend Dispatch Keys Matters!]
+    // The DispatchKey enum includes all of the runtime keys for
+    // Dense/Sparse/Quantized/Autograd (e.g. CPU, CUDA, SparseCPU, SparseCUDA,
+    // AutogradCPU, AutogradCUDA, etc). And we expect the ordering of those keys
+    // to be the same as the ordering of the backends in the `BackendComponent`
+    // enum. This makes several utilities in `DispatchKey.h` and
+    // `DispatchKeySet.h` significantly easier to implement. The purpose of the
+    // test is to assert (through CI) that this invariant is maintained.
+    //
+    // The only way that we can really check this invariant is by
+    // comparing the string names of each enum.
+    // We only really care about the ordering for "real" keys that are actually
+    // used, which we expect to be able to print properly. This saves us from
+    // having to enumerate the full set of possible runtime keys in
+    // DispatchKey::toString(). It also relies on toString() being implemented
+    // correctly.
+    auto functionality_str = std::string(toString(k));
+    if (functionality_str == "UNKNOWN_TENSOR_TYPE_ID")
+      continue;
+
+    auto computed_backend_k = toBackendComponent(k);
+    auto computed_backend_str = std::string(toString(computed_backend_k));
+    // Strip, e.g., the "Bit" from "CPUBit"
+    computed_backend_str =
+        computed_backend_str.substr(0, computed_backend_str.size() - 3);
+
+    ASSERT_TRUE(
+        functionality_str.find(computed_backend_str) != std::string::npos)
+        << "DispatchKey invariant broken! Found a key that is not ordered correctly"
+        << " with its backend bit. 
key = " << toString(k) << ", " << k + << ", computed backend = " << toString(computed_backend_k); + } +} diff --git a/test/test_dispatch.py b/test/test_dispatch.py index 37a6054f9151e6..c97e9e382fc766 100644 --- a/test/test_dispatch.py +++ b/test/test_dispatch.py @@ -532,8 +532,8 @@ def test_computed_table_with_ambiguous_autogradother(self): lambda m: m.def_("foo(Tensor x) -> Tensor"), # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math"), - # m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"), + # m.impl("foo", torch::kFPGA, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "FPGA", debug="fn_fpga"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -541,12 +541,12 @@ def test_computed_table_with_ambiguous_autogradother(self): schema: test::foo(Tensor x) -> (Tensor) debug: registered at /dev/null:0 alias analysis kind: FROM_SCHEMA -QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +FPGA: fn_fpga :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. - extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',)) + extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('FPGA',)) self.assertExpectedInline(extracted_table, '''\ Undefined: fn_math [math kernel] @@ -557,7 +557,7 @@ def test_computed_table_with_ambiguous_autogradother(self): AutogradCPU: fn_math [math kernel] AutogradCUDA: fn_math [math kernel] AutogradXLA: fn_math [math kernel] -QuantizedCPU: fn_quantizedcpu [kernel] +FPGA: fn_fpga [kernel] ''') def test_computed_table_with_cpu_defaultbackend(self): @@ -616,7 +616,7 @@ def test_computed_table_with_cpu_autograd_defaultbackend(self): ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. 
- extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',)) + extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('FPGA',)) self.assertExpectedInline(extracted_table, '''\ Undefined: fn_defaultbackend [default backend kernel] @@ -627,7 +627,7 @@ def test_computed_table_with_cpu_autograd_defaultbackend(self): AutogradCPU: fn_autograd [autograd kernel] AutogradCUDA: fn_autograd [autograd kernel] AutogradXLA: fn_autograd [autograd kernel] -QuantizedCPU: fn_defaultbackend [default backend kernel] +FPGA: fn_defaultbackend [default backend kernel] ''') def test_computed_table_with_cpu_autograd_math_defaultbackend(self): @@ -808,7 +808,7 @@ def test_basic(self): CPU fn_CPU [kernel] XLA fn_XLA [kernel] Lazy fn_Lazy [kernel] -QuantizedCPU fn_CompositeImplicitAutograd [math kernel] +FPGA fn_CompositeImplicitAutograd [math kernel] AutogradOther fn_CompositeImplicitAutograd [math kernel] AutogradCPU fallthrough [backend fallback] AutogradXLA fallthrough [backend fallback] @@ -829,7 +829,7 @@ def test_math_autogradcpu(self): CPU fn_CPU [kernel] XLA fn_XLA [kernel] Lazy fn_Lazy [kernel] -QuantizedCPU fn_CompositeImplicitAutograd [math kernel] +FPGA fn_CompositeImplicitAutograd [math kernel] AutogradOther fn_CompositeImplicitAutograd [math kernel] AutogradCPU fn_AutogradCPU [kernel] AutogradXLA fallthrough [backend fallback] @@ -864,7 +864,7 @@ def test_defaultbackend_autogradcpu(self): CPU fn_CPU [kernel] XLA fn_XLA [kernel] Lazy fn_Lazy [kernel] -QuantizedCPU fn_CompositeExplicitAutograd [default backend kernel] +FPGA fn_CompositeExplicitAutograd [default backend kernel] AutogradOther fallthrough [backend fallback] AutogradCPU fn_AutogradCPU [kernel] AutogradXLA fallthrough [backend fallback] @@ -889,7 +889,7 @@ def test_defaultbackend_autogradcpu(self): def test_autogradother(self): dispatcher = PythonDispatcher() - dispatcher.register(["CPU", "QuantizedCPU", "CompositeImplicitAutograd"]) + dispatcher.register(["CPU", "FPGA", "CompositeImplicitAutograd"]) self.assertExpectedInline( dispatcher.dispatchTable(), '''\ @@ -900,7 +900,7 @@ def test_autogradother(self): CPU fn_CPU [kernel] XLA fn_CompositeImplicitAutograd [math kernel] Lazy fn_CompositeImplicitAutograd [math kernel] -QuantizedCPU fn_QuantizedCPU [kernel] +FPGA fn_FPGA [kernel] AutogradOther ambiguous_autogradother [ambiguous autogradother] AutogradCPU fallthrough [backend fallback] AutogradXLA fn_CompositeImplicitAutograd [math kernel] @@ -915,8 +915,8 @@ def test_autogradother(self): Registered Kernels key kernel --------------------------- +FPGA fn_FPGA CPU fn_CPU -QuantizedCPU fn_QuantizedCPU CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd ''' ) diff --git a/test/test_sparse.py b/test/test_sparse.py index a50d493cdac635..99c90f48633c3d 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -3402,21 +3402,21 @@ class TestSparseOneOff(TestCase): def test_cuda_from_cpu(self): with self.assertRaisesRegex( RuntimeError, - "backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"): + "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"): torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), torch.randn(4, 4, 4), [3, 4, 4]) with self.assertRaisesRegex( RuntimeError, - "backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"): + "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"): torch.sparse.FloatTensor(torch.zeros(1, 
4).long().cuda(), torch.randn(4, 4, 4, 0), [3, 4, 4, 0]) with self.assertRaisesRegex( RuntimeError, - "backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"): + "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"): torch.sparse.FloatTensor(torch.LongTensor(1, 0).cuda(), torch.randn(0, 4, 4, 0), [0, 4, 4, 0]) diff --git a/tools/codegen/model.py b/tools/codegen/model.py index 1c61517a3e52b2..22dc925f301131 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -48,58 +48,66 @@ class DispatchKey(Enum): Undefined = 0 CatchAll = Undefined - CPU = auto() - CUDA = auto() - HIP = auto() + Dense = auto() FPGA = auto() ORT = auto() - XLA = auto() - Lazy = auto() Vulkan = auto() Metal = auto() - XPU = auto() MKLDNN = auto() OpenGL = auto() OpenCL = auto() IDEEP = auto() - QuantizedCPU = auto() - QuantizedCUDA = auto() - QuantizedXPU = auto() + Quantized = auto() CustomRNGKeyId = auto() MkldnnCPU = auto() - SparseCPU = auto() - SparseCUDA = auto() + Sparse = auto() SparseCsrCPU = auto() SparseCsrCUDA = auto() - SparseHIP = auto() - SparseXPU = auto() - NestedTensor = auto() - PrivateUse1 = auto() - PrivateUse2 = auto() - PrivateUse3 = auto() - EndOfBackendKeys = PrivateUse3 ZeroTensor = auto() Meta = auto() BackendSelect = auto() Named = auto() AutogradOther = auto() + AutogradFunctionality = auto() + AutogradNestedTensor = auto() + Tracer = auto() + Autocast = auto() + Batched = auto() + VmapMode = auto() + TESTING_ONLY_GenericWrapper = auto() + TESTING_ONLY_GenericMode = auto() + EndOfFunctionalityKeys = TESTING_ONLY_GenericMode + + CPU = auto() + CUDA = auto() + HIP = auto() + XLA = auto() + Lazy = auto() + XPU = auto() + NestedTensor = auto() + PrivateUse1 = auto() + PrivateUse2 = auto() + PrivateUse3 = auto() + + QuantizedCPU = auto() + QuantizedCUDA = auto() + QuantizedXPU = auto() + + SparseCPU = auto() + SparseCUDA = auto() + SparseHIP = auto() + SparseXPU = auto() + AutogradCPU = auto() AutogradCUDA = auto() AutogradXLA = auto() AutogradLazy = auto() - AutogradNestedTensor = auto() AutogradXPU = auto() AutogradPrivateUse1 = auto() AutogradPrivateUse2 = auto() AutogradPrivateUse3 = auto() - Tracer = auto() - Autocast = auto() - Batched = auto() - VmapMode = auto() - TESTING_ONLY_GenericWrapper = auto() - TESTING_ONLY_GenericMode = auto() - NumDispatchKeys = auto() + Autograd = auto() CompositeImplicitAutograd = auto() CompositeExplicitAutograd = auto() diff --git a/torch/_python_dispatcher.py b/torch/_python_dispatcher.py index aa19a18efb3b56..fe0c6253fdd34a 100644 --- a/torch/_python_dispatcher.py +++ b/torch/_python_dispatcher.py @@ -15,9 +15,9 @@ - CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference & autograd kernel in pytorch core library. E.g. CPU, CUDA -- QuantizedCPU/AutogradOther: represents in-tree backends which we usually have backend specific +- FPGA/AutogradOther: represents in-tree backends which we usually have backend specific inference kernels, but they share the same autograd kernel specified in AutogradOther. - E.g. QuantizedCPU, QuantizedCUDA + E.g. FPGA, SparseCsrCPU - XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd kernel defined in pytorch core library. Backend owner is responsible for registering both inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support. 
@@ -53,7 +53,7 @@ class PythonDispatcher: name = "foo" runtime_keys = [ "CPU", "AutogradCPU", - "QuantizedCPU", "AutogradOther", + "FPGA", "AutogradOther", "XLA", "AutogradXLA", "Lazy", "AutogradLazy", ]
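
To make the new keyset semantics concrete, the following standalone sketch (not part of the patch) mirrors what the dispatch_key_set_test.cpp changes above assert: a per-backend runtime key such as DispatchKey::CPU decomposes into a functionality bit (Dense) plus a backend bit (CPUBit), and iterating a DispatchKeySet yields the cross product of backend bits and per-backend functionality bits. The include path and the use of TORCH_INTERNAL_ASSERT in place of the gtest macros are assumptions; only APIs exercised by the tests above are used.

// A minimal sketch, assuming <c10/core/DispatchKeySet.h> is the right include
// and that TORCH_INTERNAL_ASSERT is available in this translation unit.
#include <c10/core/DispatchKeySet.h>
#include <iostream>

using namespace c10;

int main() {
  // A per-backend runtime key such as CPU occupies two bits:
  // the Dense functionality bit plus the CPU backend bit.
  TORCH_INTERNAL_ASSERT(
      toFunctionalityKey(DispatchKey::CPU) == DispatchKey::Dense);
  TORCH_INTERNAL_ASSERT(
      toBackendComponent(DispatchKey::CPU) == BackendComponent::CPUBit);
  DispatchKeySet cpu(DispatchKey::CPU);
  TORCH_INTERNAL_ASSERT(
      cpu ==
      (DispatchKeySet(DispatchKey::Dense) |
       DispatchKeySet(BackendComponent::CPUBit)));

  // Iteration materializes the cross product of backend bits and per-backend
  // functionality bits as runtime keys.
  auto ks =
      DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) |
      DispatchKeySet({DispatchKey::Dense, DispatchKey::AutogradFunctionality});
  for (DispatchKey k : ks) {
    std::cout << toString(k) << "\n"; // CPU, CUDA, AutogradCPU, AutogradCUDA
  }
  return 0;
}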