Skip to content

Commit

Permalink
free up dispatch key space (in C++) (pytorch#72827)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#72827

Reland of D34034848 (pytorch@6690256)
ghstack-source-id: 152161452

Test Plan: Confirm that Milan tests are passing

Reviewed By: ezyang

Differential Revision: D34227616

fbshipit-source-id: 6d1dd0fd8144dfbd9e194cd7564cce017e7db968
(cherry picked from commit e5c1b29)
  • Loading branch information
bdhirsh authored and pytorchmergebot committed Mar 25, 2022
1 parent 7c747c7 commit 2cbddc0
Show file tree
Hide file tree
Showing 19 changed files with 1,759 additions and 506 deletions.
3 changes: 1 addition & 2 deletions aten/src/ATen/TensorSubclassLikeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ constexpr auto kFunctorchWrappedTensors = DispatchKeySet({

constexpr auto kTensorSubclassLike = kFunctorchWrappedTensors | DispatchKeySet({
DispatchKey::Batched,
DispatchKey::SparseCPU,
DispatchKey::SparseCUDA,
DispatchKey::Sparse,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
DispatchKey::Meta,
Expand Down
41 changes: 41 additions & 0 deletions aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,52 @@
namespace c10 {

// Records whether this operator has a fallthrough kernel registered for
// dispatch key `k`, keeping the cached non-fallthrough masks consistent.
//
// k:                the dispatch key whose fallthrough status changed.
// has_fallthrough:  true if a fallthrough kernel is now registered for `k`
//                   (so `k` must be removed from the non-fallthrough masks),
//                   false otherwise (so `k` must be added back).
void DispatchKeyExtractor::setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough) {
  // Step 1: keep the backend-agnostic mask nonFallthroughKeys_ in sync.
  nonFallthroughKeys_ = has_fallthrough
      ? nonFallthroughKeys_.remove(k)
      : nonFallthroughKeys_.add(k);

  // Step 2: keep the per-backend masks nonFallthroughKeysPerBackend_ in sync.
  if (!isPerBackendFunctionalityKey(toFunctionalityKey(k))) {
    // `k` is not a per-backend functionality key, so the fallthrough applies
    // uniformly: update the bitset for EVERY backend.
    // TODO: we could probably optimize this by only lazily updating these values
    // the first time that we see requiresBitsetPerBackend_ = true
    // (which should almost never happen)
    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
      nonFallthroughKeysPerBackend_[i] = has_fallthrough
          ? nonFallthroughKeysPerBackend_[i].remove(k)
          : nonFallthroughKeysPerBackend_[i].add(k);
    }
    return;
  }

  // `k` is a per-backend functionality key: only the bitset for k's backend
  // changes. Subtract 1 because the first real backend (CPU) has index 0,
  // while the BackendComponent enum starts with BackendComponent::InvalidBit.
  auto backend_idx = static_cast<uint8_t>(toBackendComponent(k)) - 1;
  TORCH_INTERNAL_ASSERT(backend_idx >= 0 && static_cast<uint8_t>(backend_idx) < nonFallthroughKeysPerBackend_.size());
  nonFallthroughKeysPerBackend_[backend_idx] = has_fallthrough
      ? nonFallthroughKeysPerBackend_[backend_idx].remove(k)
      : nonFallthroughKeysPerBackend_[backend_idx].add(k);

  // requiresBitsetPerBackend_ is set exactly when the per-backend masks
  // disagree with each other; only then does dispatch need to consult the
  // slower per-backend path instead of the single shared mask.
  bool masks_differ = false;
  for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size() - 1)) {
    if (nonFallthroughKeysPerBackend_[i] != nonFallthroughKeysPerBackend_[i + 1]) {
      masks_differ = true;
      break;
    }
  }
  requiresBitsetPerBackend_ = masks_differ;
}

std::string DispatchKeyExtractor::dumpState() const {
Expand Down
29 changes: 25 additions & 4 deletions aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,24 @@ struct TORCH_API DispatchKeyExtractor final {
}
});
// Keys that are fallthrough should be skipped
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
if (requiresBitsetPerBackend_) {
auto backend_idx = ks.getBackendIndex();
return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
} else {
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
}
}

template<class... Args>
DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
auto ks = detail::multi_dispatch_key_set(args...);
// Keys that are fallthrough should be skipped
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
if (requiresBitsetPerBackend_) {
auto backend_idx = ks.getBackendIndex();
return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
} else {
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
}
}

void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);
Expand Down Expand Up @@ -193,7 +203,12 @@ struct TORCH_API DispatchKeyExtractor final {

explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
: dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
, nonFallthroughKeys_(DispatchKeySet::FULL) {}
, nonFallthroughKeys_(DispatchKeySet::FULL)
, requiresBitsetPerBackend_(false) {
for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
}
}

// this is a bitset that has ones for each argument index which has to be
// considered for dispatch. This avoids having to iterate over the stack
Expand All @@ -205,8 +220,14 @@ struct TORCH_API DispatchKeyExtractor final {
// fallthrough
c10::utils::bitset dispatch_arg_indices_reverse_;

// Set of keys for which the operator does NOT have fallthrough kernel.
// Set of functionality keys for which the operator does NOT have fallthrough kernel.
DispatchKeySet nonFallthroughKeys_;
// Set of functionality keys for which the operator does NOT have fallthrough kernel, defined PER BACKEND.
// This is only needed if we know that the operator has a different set of fallthroughs defined for some backends.
std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
// Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path),
// or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_
bool requiresBitsetPerBackend_;
};

}
10 changes: 6 additions & 4 deletions aten/src/ATen/core/dispatch/Dispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,15 @@ void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name)
RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) {
std::lock_guard<std::mutex> lock(mutex_);

auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
TORCH_CHECK(
!backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].kernel.isValid(),
!backendFallbackKernels_[idx].kernel.isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].debug, ", new registration ", debug
backendFallbackKernels_[idx].debug, ", new registration ", debug
);
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
// cannot be unboxed
backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));

for (auto& op : operators_) {
op.op.updateFallback(*this, dispatchKey);
Expand All @@ -288,7 +289,8 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) {
std::lock_guard<std::mutex> lock(mutex_);

backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = {};
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
backendFallbackKernels_[idx] = {};

for (auto& op : operators_) {
op.op.updateFallback(*this, dispatchKey);
Expand Down
11 changes: 5 additions & 6 deletions aten/src/ATen/core/dispatch/Dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ class TORCH_API Dispatcher final {
// Map from namespace to debug string (saying, e.g., where the library was defined)
ska::flat_hash_map<std::string, std::string> libraries_;

std::array<impl::AnnotatedKernel, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> backendFallbackKernels_;
std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;

std::unique_ptr<detail::RegistrationListenerList> listeners_;
std::mutex mutex_;
Expand Down Expand Up @@ -531,8 +531,7 @@ C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorH
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
.template getDispatchKeySetUnboxed<Args...>(args...);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// By default, when there're no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
Expand All @@ -553,15 +552,15 @@ template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
// do not use RecordFunction on redispatch
const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet.highestPriorityTypeId());
const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}

inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const {
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
const auto& entry = op.operatorDef_->op;
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
const auto& kernel = entry.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
Expand Down Expand Up @@ -593,7 +592,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
const auto& entry = op.operatorDef_->op;
const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
const auto& kernel = entry.lookup(dispatchKeySet);
return kernel.callBoxed(op, dispatchKeySet, stack);
}

Expand Down
39 changes: 27 additions & 12 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
}

// 3. Backend fallback
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
TORCH_INTERNAL_ASSERT(dispatch_ix != -1);
if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
}
Expand All @@ -299,7 +300,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// or alias keys and their associated keysets).
// This function should be considered a private helper for updateDispatchTable_()
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
const auto dispatch_ix = c10::getDispatchTableIndexForDispatchKey(dispatch_key);
const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
if (C10_UNLIKELY(dispatch_ix == -1)) {
return;
}
Expand Down Expand Up @@ -329,8 +330,12 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
}
// Note [Refresh Runtime Autograd entries in dispatchTable_]
// Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
// In theory, we should only have to check if the given runtime key has "dense" functionality,
// e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
// However, there are some backends that should be included in this set that don't have the dense key set.
// E.g. DispatchKey::Meta, DispatchKey::ORT.
if (c10::isBackendDispatchKey(dispatch_key)) {
DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key);
DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key));
updateDispatchTableEntry_(dispatcher, autograd_key);
}
}
Expand All @@ -357,8 +362,9 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher)
// catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd
// or CompositeImplicitAutograd alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd)
// should return true, it returns false because Undefined cannot be represented in a DispatchKeySet.
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
updateDispatchTable_(dispatcher, DispatchKey::Undefined);
for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
updateDispatchTable_(dispatcher, k);
}
}

Expand All @@ -371,9 +377,13 @@ void OperatorEntry::checkInvariants() const {
for (const auto& kv : kernels_) {
TORCH_INTERNAL_ASSERT(kv.second.size() > 0, dumpState());
}
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), static_cast<DispatchKey>(iter));
TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[iter]),
for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), k);
auto idx = getDispatchTableIndexForDispatchKey(k);
if (C10_UNLIKELY(idx == -1)) {
continue;
}
TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[idx]),
"Canonical state\n~~~~~~~~~~~\n", dumpState(), "\n\n"
"Computed table:\n~~~~~~~~~~~\n", dumpComputedTable());
}
Expand All @@ -384,8 +394,9 @@ std::string OperatorEntry::listAllDispatchKeys() const {
str << "[";

bool has_kernels = false;
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
if (!dispatchTable_[iter].isValid()) {
for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
auto iter = getDispatchTableIndexForDispatchKey(k);
if (iter == -1 || !dispatchTable_[iter].isValid()) {
continue;
}
if (has_kernels) {
Expand Down Expand Up @@ -443,8 +454,12 @@ void OperatorEntry::reportError(DispatchKey dispatchKey) const {
// updateDispatchTableFull_ would update the dispatch table to be)
std::string OperatorEntry::dumpComputedTable() const {
std::ostringstream oss;
for (uint8_t i = 0; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
auto k = static_cast<DispatchKey>(i);
// Need to handle Undefined separately, because it's a runtime key that can't be represented
// in a DispatchKeySet.
std::vector<DispatchKey> runtime_keys = {DispatchKey::Undefined};
for (auto k : DispatchKeySet(DispatchKeySet::FULL)) runtime_keys.push_back(k);

for (auto k : runtime_keys) {
auto kernel_prov = computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
if (kernel_prov.first.kernel.isValid()) {
oss << toString(k) << ": "
Expand Down
10 changes: 5 additions & 5 deletions aten/src/ATen/core/dispatch/OperatorEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ class TORCH_API OperatorEntry final {

[[noreturn]] void reportError(DispatchKey dispatchKey) const;

const KernelFunction& lookup(DispatchKey k) const {
const auto idx = getDispatchTableIndexForDispatchKey(k);
const KernelFunction& lookup(DispatchKeySet ks) const {
const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
if (C10_UNLIKELY(idx == -1)) {
reportError(k);
reportError(ks.highestPriorityTypeId());
}
const auto& kernel = dispatchTable_[idx];
// A valid kernel *always* has a boxed kernel and *may* have an
Expand All @@ -187,7 +187,7 @@ class TORCH_API OperatorEntry final {
// in the common case.
if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
if (!kernel.isValid()) {
reportError(k);
reportError(ks.highestPriorityTypeId());
}
}
return kernel;
Expand All @@ -211,7 +211,7 @@ class TORCH_API OperatorEntry final {
OperatorName name_;
c10::optional<AnnotatedSchema> schema_;

std::array<KernelFunction, c10::getDispatchTableIndexForDispatchKey(DispatchKey::NumDispatchKeys)> dispatchTable_;
std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;

// kernels_ stores all registered kernels for the corresponding dispatch key
Expand Down
Loading

0 comments on commit 2cbddc0

Please sign in to comment.