DispatchKeySet perf improvements (pytorch#72828)
Summary:
Pull Request resolved: pytorch#72828

Reland of D34034847 (pytorch@8aa3620)
ghstack-source-id: 152161453

Test Plan: confirm that Milan tests are passing

Reviewed By: ezyang, albanD

Differential Revision: D34227615

fbshipit-source-id: c7695e16dba3076e8ab9df8654327c5d57e92c77
(cherry picked from commit 940717d)
bdhirsh authored and pytorchmergebot committed Mar 25, 2022
1 parent 2cbddc0 commit c0491c9
Showing 4 changed files with 73 additions and 44 deletions.
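
The common thread across all four files: DispatchKeySet membership tests that used one has() call per key, and key removal that chained one remove() per key, are replaced with whole-set operations (has_all, has_any, operator|, operator-) against constexpr keysets. Since a keyset is a bitmask, each such query or update effectively compiles down to a couple of bitwise instructions regardless of how many keys are involved. Below is a minimal sketch of that idea; it is a toy stand-in, not the actual c10 implementation (the real DispatchKeySet also packs backend bits and functionality bits into its representation):

// Toy keyset: one bit per key in a 64-bit word. Illustrative only.
#include <cstdint>

enum class Key : uint8_t {
  CPU,
  QuantizedCPU,
  Sparse,
  SparseCsrCPU,
  SparseCsrCUDA,
  MkldnnCPU,
  Python,
  PythonTLSSnapshot,
};

struct KeySet {
  uint64_t repr = 0;

  constexpr KeySet() = default;
  constexpr explicit KeySet(Key k)
      : repr(uint64_t{1} << static_cast<uint8_t>(k)) {}

  // Set union.
  constexpr KeySet operator|(KeySet other) const {
    KeySet out;
    out.repr = repr | other.repr;
    return out;
  }
  // Set subtraction: clears every bit of `other` in one operation.
  constexpr KeySet operator-(KeySet other) const {
    KeySet out;
    out.repr = repr & ~other.repr;
    return out;
  }
  // Every key in `ks` present: one AND plus one compare.
  constexpr bool has_all(KeySet ks) const {
    return (repr & ks.repr) == ks.repr;
  }
  // At least one key in `ks` present: one AND plus one test against zero.
  constexpr bool has_any(KeySet ks) const {
    return (repr & ks.repr) != 0;
  }
};

With masks like these, a two-key query costs the same as a one-key query, which is what the per-file changes below exploit.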
3 changes: 2 additions & 1 deletion aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp

@@ -160,9 +160,10 @@ Tensor MakeStridedQTensorCPU(
       allocator->allocate(size_bytes),
       allocator,
       /* resizable = */ true);
+  constexpr auto quantized_cpu_ks = at::DispatchKeySet(at::DispatchKey::QuantizedCPU);
   auto tensor = detail::make_tensor<QTensorImpl>(
       storage,
-      at::DispatchKeySet(at::DispatchKey::QuantizedCPU),
+      quantized_cpu_ks,
       dtype,
       quantizer);
   get_qtensorimpl(tensor)->set_sizes_and_strides(sizes, strides);
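
The change to this file only hoists the keyset argument into a named constexpr local, guaranteeing the mask is a compile-time constant and matching the style of the named keysets introduced in DispatchKeySet.h below. In terms of the toy KeySet above, that guarantee is directly checkable (hypothetical snippet, reusing the sketch definitions):

// Assumes the Key/KeySet sketch above is in scope.
constexpr KeySet quantized_cpu_ks{Key::QuantizedCPU};
static_assert(quantized_cpu_ks.has_all(KeySet(Key::QuantizedCPU)),
              "the mask is a constant expression, built at compile time");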
12 changes: 12 additions & 0 deletions c10/core/DispatchKeySet.h

@@ -646,6 +646,18 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
 constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
     autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);
 
+constexpr DispatchKeySet python_ks = DispatchKeySet({
+    DispatchKey::Python,
+    DispatchKey::PythonTLSSnapshot,
+});
+
+constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);
+
+constexpr DispatchKeySet sparse_csr_ks =
+    DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCUDA});
+
+constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);
+
 // backend dispatch keys that map to DispatchKey::AutogradOther
 // NB: keys in this set also get associated with CompositeImplicitAutograd
 constexpr DispatchKeySet autogradother_backends =
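
The new named keysets let hot call sites test or strip several related keys with one mask; python_ks in particular bundles the two keys that must always be added and removed together. A sketch of the intended usage, reusing the toy KeySet from above (the real constants live in the c10 namespace and are built from DispatchKey values):

// Assumes the Key/KeySet sketch above is in scope.
constexpr KeySet python_ks =
    KeySet(Key::Python) | KeySet(Key::PythonTLSSnapshot);

constexpr KeySet sparse_csr_ks =
    KeySet(Key::SparseCsrCPU) | KeySet(Key::SparseCsrCUDA);

// One bit test replaces two has() calls:
bool is_sparse_csr(KeySet key_set) {
  return key_set.has_any(sparse_csr_ks);
}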
11 changes: 3 additions & 8 deletions c10/core/TensorImpl.cpp

@@ -148,10 +148,7 @@ TensorImpl::TensorImpl(
       numel_(0),
       data_type_(data_type),
       device_opt_(storage_.device()),
-      key_set_(
-          key_set.remove(DispatchKey::Python)
-              .remove(DispatchKey::PythonTLSSnapshot)) { // See [Note: Python
-                                                         // key removal]
+      key_set_(key_set - c10::python_ks) { // See [Note: Python key removal]
   init_bitfields();
   // Inference tensor doesn't have version counter.
   if (!is_inference()) {
@@ -196,10 +193,8 @@ TensorImpl::TensorImpl(
 
   key_set = key_set | getAutocastRelatedKeySetFromBackend(k);
 
-  key_set =
-      key_set.remove(DispatchKey::Python)
-          .remove(
-              DispatchKey::PythonTLSSnapshot); // See [Note: Python key removal]
+  // See [Note: Python key removal]
+  key_set = key_set - c10::python_ks;
 
   // Inference tensor doesn't have autograd related keys.
   if (inference_mode) {
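
Both constructors now strip the two Python keys with a single set subtraction instead of two chained remove() calls. In terms of the toy sketch, key_set - c10::python_ks is one `repr & ~mask` (illustrative snippet, reusing the sketch definitions above):

// Assumes the Key/KeySet sketch and python_ks from the snippets above.
#include <cassert>

int main() {
  KeySet key_set = KeySet(Key::CPU) | python_ks;
  key_set = key_set - python_ks;  // clears both Python bits in one op
  assert(!key_set.has_any(python_ks));
  assert(key_set.has_all(KeySet(Key::CPU)));
  return 0;
}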
91 changes: 56 additions & 35 deletions c10/core/TensorImpl.h

@@ -854,91 +854,103 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   bool is_sparse() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Sparse);
+    return key_set_.has_all(c10::sparse_ks);
   }
 
   // Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR
   // format.
   bool is_sparse_csr() const {
-    return key_set_.has(DispatchKey::SparseCsrCPU) ||
-        key_set_.has(DispatchKey::SparseCsrCUDA);
+    return key_set_.has_any(c10::sparse_csr_ks);
   }
 
   bool is_quantized() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Quantized);
+    constexpr auto quantized_ks = DispatchKeySet(DispatchKey::Quantized);
+    return key_set_.has_all(quantized_ks);
   }
 
   bool is_meta() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Meta);
+    constexpr auto meta_ks = DispatchKeySet(DispatchKey::Meta);
+    return key_set_.has_all(meta_ks);
   }
 
   bool is_cpu() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::CPUBit) ||
-        key_set_.has(DispatchKey::SparseCsrCPU) ||
-        key_set_.has(DispatchKey::MkldnnCPU);
+    constexpr auto cpu_bits_ks = DispatchKeySet(BackendComponent::CPUBit) |
+        DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::MkldnnCPU});
+    return key_set_.has_any(cpu_bits_ks);
   }
 
   bool is_cuda() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::CUDABit) ||
-        key_set_.has(DispatchKey::SparseCsrCUDA);
+    constexpr auto cuda_bits_ks = DispatchKeySet(BackendComponent::CUDABit) |
+        DispatchKeySet(DispatchKey::SparseCsrCUDA);
+    return key_set_.has_any(cuda_bits_ks);
   }
 
   bool is_xpu() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::XPUBit);
+    constexpr auto xpu_ks = DispatchKeySet(BackendComponent::XPUBit);
+    return key_set_.has_all(xpu_ks);
   }
 
   bool is_xla() const {
-    return key_set_.has_backend(BackendComponent::XLABit);
+    constexpr auto xla_ks = DispatchKeySet(BackendComponent::XLABit);
+    return key_set_.has_all(xla_ks);
   }
 
   bool is_hpu() const {
-    return key_set_.has_backend(BackendComponent::HPUBit);
+    constexpr auto hpu_ks = DispatchKeySet(BackendComponent::HPUBit);
+    return key_set_.has_all(hpu_ks);
   }
 
   bool is_lazy() const {
-    return key_set_.has_backend(BackendComponent::LazyBit);
+    constexpr auto lazy_ks = DispatchKeySet(BackendComponent::LazyBit);
+    return key_set_.has_all(lazy_ks);
   }
 
   bool is_hip() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::HIPBit);
+    constexpr auto hip_ks = DispatchKeySet(BackendComponent::HIPBit);
+    return key_set_.has_all(hip_ks);
   }
 
   bool is_ve() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::VEBit);
+    constexpr auto ve_ks = DispatchKeySet(BackendComponent::VEBit);
+    return key_set_.has_all(ve_ks);
  }
 
   bool is_mkldnn() const {
-    return key_set_.has(DispatchKey::MkldnnCPU);
+    return key_set_.has_all(c10::mkldnn_ks);
   }
 
   bool is_vulkan() const {
-    return key_set_.has(DispatchKey::Vulkan);
+    constexpr auto vulkan_ks = DispatchKeySet(DispatchKey::Vulkan);
+    return key_set_.has_all(vulkan_ks);
   }
 
   bool is_metal() const {
-    return key_set_.has(DispatchKey::Metal);
+    constexpr auto metal_ks = DispatchKeySet(DispatchKey::Metal);
+    return key_set_.has_all(metal_ks);
   }
 
   bool is_mlc() const {
-    return key_set_.has(DispatchKey::MLC);
+    constexpr auto mls_ks = DispatchKeySet(DispatchKey::MLC);
+    return key_set_.has_all(mls_ks);
   }
 
   bool is_ort() const {
-    return key_set_.has(DispatchKey::ORT);
+    constexpr auto ort_ks = DispatchKeySet(DispatchKey::ORT);
+    return key_set_.has_all(ort_ks);
   }
 
   bool is_nested() const {
@@ -958,8 +970,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   // Invariant:
   //   Inference tensor has version_counter_.enabled() == false
   bool is_inference() {
-    bool no_ADInplaceOrView = !key_set_.has(c10::DispatchKey::ADInplaceOrView);
-    bool no_Autograd = (key_set_ & c10::autograd_dispatch_keyset).empty();
+    bool no_ADInplaceOrView = !key_set_.has_any(c10::inplace_or_view_ks);
+    bool no_Autograd = !key_set_.has_any(c10::autograd_dispatch_keyset);
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
         no_ADInplaceOrView == no_Autograd,
         "ADInplaceOrView and Autograd keys must be on/off at the same time.");
@@ -980,14 +992,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
 
   Layout layout() const {
     // NB: This method is not virtual and avoid dispatches for perf.
-    if (is_sparse()) {
+    // strided is also the most common layout type, so we check for
+    // strided case first.
+    // This keyset must also be kept in sync with the logic in
+    // is_sparse() / is_sparse_csr() / is_mkldnn()
+    constexpr auto sparse_and_sparsecsr_and_mkldnn_ks =
+        c10::sparse_ks | c10::sparse_csr_ks | c10::mkldnn_ks;
+    if (!key_set_.has_any(sparse_and_sparsecsr_and_mkldnn_ks)) {
+      return kStrided;
+    } else if (is_sparse()) {
       return kSparse;
     } else if (is_sparse_csr()) {
       return kSparseCsr;
-    } else if (is_mkldnn()) {
-      return kMkldnn;
     } else {
-      return kStrided;
+      TORCH_INTERNAL_ASSERT(
+          is_mkldnn(), "There is an error in the layout calculation logic.");
+      return kMkldnn;
     }
   }
 
@@ -1073,7 +1093,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the imaginary part of the tensor should be negated
    */
   inline bool is_conj() const {
-    return key_set_.has(DispatchKey::Conjugate);
+    constexpr auto conjugate_ks = DispatchKeySet(DispatchKey::Conjugate);
+    return key_set_.has_all(conjugate_ks);
   }
 
   /**
@@ -1093,7 +1114,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   * Whether or not the tensor is a zerotensor
   */
   inline bool _is_zerotensor() const {
-    return key_set_.has(DispatchKey::ZeroTensor);
+    constexpr auto zerotensor_ks = DispatchKeySet(DispatchKey::ZeroTensor);
+    return key_set_.has_all(zerotensor_ks);
   }
 
   /**
@@ -1113,7 +1135,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   * Whether or not the tensor should be negated
   */
   inline bool is_neg() const {
-    return key_set_.has(DispatchKey::Negative);
+    constexpr auto negative_ks = DispatchKeySet(DispatchKey::Negative);
+    return key_set_.has_all(negative_ks);
   }
 
   /**
@@ -1484,16 +1507,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
 
   void set_python_dispatch(bool k) {
     if (k) {
-      key_set_ =
-          key_set_.add(DispatchKey::Python).add(DispatchKey::PythonTLSSnapshot);
+      key_set_ = key_set_.add(c10::python_ks);
     } else {
-      key_set_ = key_set_.remove(DispatchKey::Python)
-                     .remove(DispatchKey::PythonTLSSnapshot);
+      key_set_ = key_set_ - c10::python_ks;
     }
   }
 
   bool is_python_dispatch() const {
-    return key_set_.has(DispatchKey::Python);
+    return key_set_.has_all(c10::python_ks);
   }
 
   /**
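
Two patterns run through this file: the predicate methods switch from has()/has_backend() to has_all/has_any against constexpr masks, and layout() gains a fast path that rejects all three non-strided layouts with a single has_any before falling back to the per-layout checks. A sketch of that fast path with the toy KeySet from above (kStrided and friends stand in for c10's Layout values; sparse_csr_ks is the mask defined in the DispatchKeySet.h snippet):

// Assumes the Key/KeySet sketch and sparse_csr_ks from the snippets above.
enum Layout { kStrided, kSparse, kSparseCsr, kMkldnn };

constexpr KeySet sparse_ks{Key::Sparse};
constexpr KeySet mkldnn_ks{Key::MkldnnCPU};

Layout layout_of(KeySet key_set) {
  // Strided is the common case: one has_any over the union of the three
  // uncommon keysets lets strided tensors return after a single branch.
  constexpr KeySet non_strided_ks = sparse_ks | sparse_csr_ks | mkldnn_ks;
  if (!key_set.has_any(non_strided_ks)) {
    return kStrided;
  } else if (key_set.has_all(sparse_ks)) {
    return kSparse;
  } else if (key_set.has_any(sparse_csr_ks)) {
    return kSparseCsr;
  } else {
    return kMkldnn;
  }
}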
