diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
index 369ee6744624cc..0a8334b96f7071 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
@@ -160,9 +160,10 @@ Tensor MakeStridedQTensorCPU(
       allocator->allocate(size_bytes),
       allocator,
       /* resizable = */ true);
+  constexpr auto quantized_cpu_ks = at::DispatchKeySet(at::DispatchKey::QuantizedCPU);
   auto tensor = detail::make_tensor<QTensorImpl>(
       storage,
-      at::DispatchKeySet(at::DispatchKey::QuantizedCPU),
+      quantized_cpu_ks,
       dtype,
       quantizer);
   get_qtensorimpl(tensor)->set_sizes_and_strides(sizes, strides);
diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h
index 821b701022728e..059c0e450ff82c 100644
--- a/c10/core/DispatchKeySet.h
+++ b/c10/core/DispatchKeySet.h
@@ -646,6 +646,18 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
 constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
     autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);
 
+constexpr DispatchKeySet python_ks = DispatchKeySet({
+    DispatchKey::Python,
+    DispatchKey::PythonTLSSnapshot,
+});
+
+constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);
+
+constexpr DispatchKeySet sparse_csr_ks =
+    DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCUDA});
+
+constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);
+
 // backend dispatch keys that map to DispatchKey::AutogradOther
 // NB: keys in this set also get associated with CompositeImplicitAutograd
 constexpr DispatchKeySet autogradother_backends =
diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp
index 158dfc590e496f..ebe96ee9412fb4 100644
--- a/c10/core/TensorImpl.cpp
+++ b/c10/core/TensorImpl.cpp
@@ -148,10 +148,7 @@ TensorImpl::TensorImpl(
       numel_(0),
       data_type_(data_type),
       device_opt_(storage_.device()),
-      key_set_(
-          key_set.remove(DispatchKey::Python)
-              .remove(DispatchKey::PythonTLSSnapshot)) { // See [Note: Python
-                                                         // key removal]
+      key_set_(key_set - c10::python_ks) { // See [Note: Python key removal]
   init_bitfields();
   // Inference tensor doesn't have version counter.
   if (!is_inference()) {
@@ -196,10 +193,8 @@ TensorImpl::TensorImpl(
 
   key_set = key_set | getAutocastRelatedKeySetFromBackend(k);
 
-  key_set =
-      key_set.remove(DispatchKey::Python)
-          .remove(
-              DispatchKey::PythonTLSSnapshot); // See [Note: Python key removal]
+  // See [Note: Python key removal]
+  key_set = key_set - c10::python_ks;
 
   // Inference tensor doesn't have autograd related keys.
   if (inference_mode) {
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index bec47f3acba714..a6fc1eade770de 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -854,91 +854,103 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   bool is_sparse() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Sparse);
+    return key_set_.has_all(c10::sparse_ks);
   }
 
   // Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR
   // format.
   bool is_sparse_csr() const {
-    return key_set_.has(DispatchKey::SparseCsrCPU) ||
-        key_set_.has(DispatchKey::SparseCsrCUDA);
+    return key_set_.has_any(c10::sparse_csr_ks);
   }
 
   bool is_quantized() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Quantized);
+    constexpr auto quantized_ks = DispatchKeySet(DispatchKey::Quantized);
+    return key_set_.has_all(quantized_ks);
   }
 
   bool is_meta() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Meta);
+    constexpr auto meta_ks = DispatchKeySet(DispatchKey::Meta);
+    return key_set_.has_all(meta_ks);
   }
 
   bool is_cpu() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::CPUBit) ||
-        key_set_.has(DispatchKey::SparseCsrCPU) ||
-        key_set_.has(DispatchKey::MkldnnCPU);
+    constexpr auto cpu_bits_ks = DispatchKeySet(BackendComponent::CPUBit) |
+        DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::MkldnnCPU});
+    return key_set_.has_any(cpu_bits_ks);
   }
 
   bool is_cuda() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::CUDABit) ||
-        key_set_.has(DispatchKey::SparseCsrCUDA);
+    constexpr auto cuda_bits_ks = DispatchKeySet(BackendComponent::CUDABit) |
+        DispatchKeySet(DispatchKey::SparseCsrCUDA);
+    return key_set_.has_any(cuda_bits_ks);
   }
 
   bool is_xpu() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::XPUBit);
+    constexpr auto xpu_ks = DispatchKeySet(BackendComponent::XPUBit);
+    return key_set_.has_all(xpu_ks);
   }
 
   bool is_xla() const {
-    return key_set_.has_backend(BackendComponent::XLABit);
+    constexpr auto xla_ks = DispatchKeySet(BackendComponent::XLABit);
+    return key_set_.has_all(xla_ks);
   }
 
   bool is_hpu() const {
-    return key_set_.has_backend(BackendComponent::HPUBit);
+    constexpr auto hpu_ks = DispatchKeySet(BackendComponent::HPUBit);
+    return key_set_.has_all(hpu_ks);
   }
 
   bool is_lazy() const {
-    return key_set_.has_backend(BackendComponent::LazyBit);
+    constexpr auto lazy_ks = DispatchKeySet(BackendComponent::LazyBit);
+    return key_set_.has_all(lazy_ks);
  }
 
   bool is_hip() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::HIPBit);
+    constexpr auto hip_ks = DispatchKeySet(BackendComponent::HIPBit);
+    return key_set_.has_all(hip_ks);
   }
 
   bool is_ve() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::VEBit);
+    constexpr auto ve_ks = DispatchKeySet(BackendComponent::VEBit);
+    return key_set_.has_all(ve_ks);
   }
 
   bool is_mkldnn() const {
-    return key_set_.has(DispatchKey::MkldnnCPU);
+    return key_set_.has_all(c10::mkldnn_ks);
   }
 
   bool is_vulkan() const {
-    return key_set_.has(DispatchKey::Vulkan);
+    constexpr auto vulkan_ks = DispatchKeySet(DispatchKey::Vulkan);
+    return key_set_.has_all(vulkan_ks);
   }
 
   bool is_metal() const {
-    return key_set_.has(DispatchKey::Metal);
+    constexpr auto metal_ks = DispatchKeySet(DispatchKey::Metal);
+    return key_set_.has_all(metal_ks);
   }
 
   bool is_mlc() const {
-    return key_set_.has(DispatchKey::MLC);
+    constexpr auto mls_ks = DispatchKeySet(DispatchKey::MLC);
+    return key_set_.has_all(mls_ks);
   }
 
   bool is_ort() const {
-    return key_set_.has(DispatchKey::ORT);
+    constexpr auto ort_ks = DispatchKeySet(DispatchKey::ORT);
+    return key_set_.has_all(ort_ks);
   }
 
   bool is_nested() const {
@@ -958,8 +970,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   // Invariant:
   //   Inference tensor has version_counter_.enabled() == false
   bool is_inference() {
-    bool no_ADInplaceOrView = !key_set_.has(c10::DispatchKey::ADInplaceOrView);
-    bool no_Autograd = (key_set_ & c10::autograd_dispatch_keyset).empty();
+    bool no_ADInplaceOrView = !key_set_.has_any(c10::inplace_or_view_ks);
+    bool no_Autograd = !key_set_.has_any(c10::autograd_dispatch_keyset);
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
         no_ADInplaceOrView == no_Autograd,
         "ADInplaceOrView and Autograd keys must be on/off at the same time.");
@@ -980,14 +992,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
 
   Layout layout() const {
     // NB: This method is not virtual and avoid dispatches for perf.
-    if (is_sparse()) {
+    // strided is also the most common layout type, so we check for
+    // strided case first.
+    // This keyset must also be kept in sync with the logic in
+    // is_sparse() / is_sparse_csr() / is_mkldnn()
+    constexpr auto sparse_and_sparsecsr_and_mkldnn_ks =
+        c10::sparse_ks | c10::sparse_csr_ks | c10::mkldnn_ks;
+    if (!key_set_.has_any(sparse_and_sparsecsr_and_mkldnn_ks)) {
+      return kStrided;
+    } else if (is_sparse()) {
       return kSparse;
     } else if (is_sparse_csr()) {
       return kSparseCsr;
-    } else if (is_mkldnn()) {
-      return kMkldnn;
     } else {
-      return kStrided;
+      TORCH_INTERNAL_ASSERT(
+          is_mkldnn(), "There is an error in the layout calculation logic.");
+      return kMkldnn;
     }
   }
 
@@ -1073,7 +1093,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the imaginary part of the tensor should be negated
    */
   inline bool is_conj() const {
-    return key_set_.has(DispatchKey::Conjugate);
+    constexpr auto conjugate_ks = DispatchKeySet(DispatchKey::Conjugate);
+    return key_set_.has_all(conjugate_ks);
   }
 
   /**
@@ -1093,7 +1114,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the tensor is a zerotensor
    */
   inline bool _is_zerotensor() const {
-    return key_set_.has(DispatchKey::ZeroTensor);
+    constexpr auto zerotensor_ks = DispatchKeySet(DispatchKey::ZeroTensor);
+    return key_set_.has_all(zerotensor_ks);
   }
 
   /**
@@ -1113,7 +1135,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the tensor should be negated
    */
   inline bool is_neg() const {
-    return key_set_.has(DispatchKey::Negative);
+    constexpr auto negative_ks = DispatchKeySet(DispatchKey::Negative);
+    return key_set_.has_all(negative_ks);
   }
 
   /**
@@ -1484,16 +1507,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
 
   void set_python_dispatch(bool k) {
     if (k) {
-      key_set_ =
-          key_set_.add(DispatchKey::Python).add(DispatchKey::PythonTLSSnapshot);
+      key_set_ = key_set_.add(c10::python_ks);
     } else {
-      key_set_ = key_set_.remove(DispatchKey::Python)
-                     .remove(DispatchKey::PythonTLSSnapshot);
+      key_set_ = key_set_ - c10::python_ks;
     }
   }
 
   bool is_python_dispatch() const {
-    return key_set_.has(DispatchKey::Python);
+    return key_set_.has_all(c10::python_ks);
   }
 
   /**
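A quick illustration of the keyset algebra this patch standardizes on: instead of chaining per-key has(...) / remove(...) calls, a constexpr mask is built once and each query becomes a single word-wide bit test via has_all / has_any. The sketch below is a minimal standalone model under assumed names (KeySet, Key, and *_ks constants mirroring the ones added to DispatchKeySet.h); it is not the real c10::DispatchKeySet, which additionally splits its representation into backend components and per-functionality keys.

    // Minimal sketch only: KeySet and Key are made-up stand-ins, not the real
    // c10::DispatchKeySet.
    #include <cstdint>
    #include <initializer_list>
    #include <iostream>

    enum class Key : uint8_t {
      Sparse,
      SparseCsrCPU,
      SparseCsrCUDA,
      MkldnnCPU,
      Python,
      PythonTLSSnapshot,
    };

    class KeySet {
     public:
      constexpr KeySet() = default;
      constexpr explicit KeySet(Key k) : repr_(1ULL << static_cast<uint8_t>(k)) {}
      constexpr KeySet(std::initializer_list<Key> ks) {
        for (Key k : ks) {
          repr_ |= 1ULL << static_cast<uint8_t>(k);
        }
      }
      // Set algebra used throughout the patch: union, difference, add.
      constexpr KeySet operator|(KeySet other) const {
        return KeySet(repr_ | other.repr_);
      }
      constexpr KeySet operator-(KeySet other) const {
        return KeySet(repr_ & ~other.repr_);
      }
      constexpr KeySet add(KeySet other) const {
        return *this | other;
      }
      // One AND answers "are all of these keys set?" (replaces has(a) && has(b)).
      constexpr bool has_all(KeySet ks) const {
        return (repr_ & ks.repr_) == ks.repr_;
      }
      // One AND answers "is any of these keys set?" (replaces has(a) || has(b)).
      constexpr bool has_any(KeySet ks) const {
        return (repr_ & ks.repr_) != 0;
      }

     private:
      constexpr explicit KeySet(uint64_t repr) : repr_(repr) {}
      uint64_t repr_ = 0;
    };

    // Stand-ins for the constexpr sets the patch adds to DispatchKeySet.h.
    constexpr KeySet sparse_ks(Key::Sparse);
    constexpr KeySet sparse_csr_ks({Key::SparseCsrCPU, Key::SparseCsrCUDA});
    constexpr KeySet mkldnn_ks(Key::MkldnnCPU);
    constexpr KeySet python_ks({Key::Python, Key::PythonTLSSnapshot});

    int main() {
      KeySet key_set({Key::SparseCsrCUDA, Key::Python, Key::PythonTLSSnapshot});
      std::cout << key_set.has_any(sparse_csr_ks) << "\n"; // 1: a CSR bit is set
      std::cout << key_set.has_all(sparse_ks) << "\n";     // 0: Sparse bit not set
      std::cout << key_set.has_all(python_ks) << "\n";     // 1: both Python bits set
      key_set = key_set - python_ks;                       // clear both bits at once
      std::cout << key_set.has_all(python_ks) << "\n";     // 0
      return 0;
    }

Because the masks are constexpr, each accessor compiles down to an AND against an immediate constant, which is the motivation behind the perf-focused NB comments on these non-virtual methods.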