From 597523928a820f77f93681c779e2db37a6226771 Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Thu, 24 Jun 2021 12:19:55 -0700 Subject: [PATCH] [Vulkan] Implement sync for SyncThread("warp") (#8320) - Add sync if a SyncThread("warp") node is present. The sync is done at spv::ScopeSubgroup if supported (Vulkan 1.1+), and at spv::ScopeWorkgroup otherwise. Co-authored-by: Eric Lunderberg --- src/target/spirv/build_vulkan.cc | 21 ++++++++++++-- src/target/spirv/codegen_spirv.cc | 29 ++++++++++++-------- src/target/spirv/spirv_support.cc | 4 +++ src/target/spirv/spirv_support.h | 14 ++++++++++ src/tir/transforms/lower_thread_allreduce.cc | 2 +- 5 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index d24bf3c02186..c19b71d1540b 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -36,7 +36,24 @@ namespace codegen { class SPIRVTools { public: - SPIRVTools() { ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); } + explicit SPIRVTools(Target target) { + uint32_t vulkan_version = + target->GetAttr("vulkan_api_version").value_or(VK_API_VERSION_1_0); + uint32_t spirv_version = target->GetAttr("max_spirv_version").value_or(0x10000); + + spv_target_env validation_version; + if (vulkan_version >= VK_API_VERSION_1_2) { + validation_version = SPV_ENV_VULKAN_1_2; + } else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) { + validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4; + } else if (vulkan_version >= VK_API_VERSION_1_1) { + validation_version = SPV_ENV_VULKAN_1_1; + } else { + validation_version = SPV_ENV_VULKAN_1_0; + } + + ctx_ = spvContextCreate(validation_version); + } ~SPIRVTools() { spvContextDestroy(ctx_); } std::string BinaryToText(const std::vector& bin) { spv_text text = nullptr; @@ -80,7 +97,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) using tvm::runtime::VulkanShader; std::ostringstream code_data; - static SPIRVTools spirv_tools; + SPIRVTools spirv_tools(target); std::unordered_map smap; const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index f8412b51edcf..8245597fef8b 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -140,20 +140,27 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& ext spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { const std::string& sync = op->args[0].as()->value; spirv::Value value; - if (sync == "warp") { - return value; - } else if (sync == "shared") { - auto type_int = builder_->GetSType(DataType::Int(32)); - builder_->MakeInst( - spv::OpControlBarrier, - builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), - builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), - builder_->IntImm(type_int, - static_cast(spv::MemorySemanticsSequentiallyConsistentMask | - spv::MemorySemanticsWorkgroupMemoryMask))); + + uint32_t vulkan_api_version = spirv_support_.vulkan_api_version; + + int64_t sync_scope; + int64_t memory_semantics; + if ((sync == "warp") && (vulkan_api_version >= VK_API_VERSION_1_1)) { + sync_scope = spv::ScopeSubgroup; + memory_semantics = + spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsSubgroupMemoryMask; + } else if ((sync == "shared") || (sync == "warp")) { + sync_scope = spv::ScopeWorkgroup; + memory_semantics = + spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsWorkgroupMemoryMask; } else { LOG(FATAL) << "Do not support sync " << sync; } + + auto type_int = builder_->GetSType(DataType::Int(32)); + builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope), + builder_->IntImm(type_int, sync_scope), + builder_->IntImm(type_int, memory_semantics)); return value; } diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index e06bde08895d..4a294d56bd9c 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -35,6 +35,10 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { ICHECK_EQ(target->kind->device_type, kDLVulkan) << "SPIRVSupport can only be checked for vulkan device type"; + if (target->GetAttr("vulkan_api_version")) { + vulkan_api_version = target->GetAttr("vulkan_api_version").value(); + } + if (target->GetAttr("supported_subgroup_operations")) { supported_subgroup_operations = target->GetAttr("supported_subgroup_operations").value(); diff --git a/src/target/spirv/spirv_support.h b/src/target/spirv/spirv_support.h index db15f593dd5a..1497c7c6333a 100644 --- a/src/target/spirv/spirv_support.h +++ b/src/target/spirv/spirv_support.h @@ -27,6 +27,7 @@ #define TVM_TARGET_SPIRV_SPIRV_SUPPORT_H_ #include +#include namespace tvm { namespace codegen { @@ -37,6 +38,19 @@ struct SPIRVSupport { */ explicit SPIRVSupport(Target target); + /*! \brief The Vulkan API version supported by the device. + * + * Vulkan struct: VkPhysicalDeviceProperties + * Device property: apiVersion + * + * If VK_KHR_driver_properties is present, will also check the + * driver conformance version. If the version advertised does not + * pass the Vulkan conformance test, vulkan_api_version will be the + * latest Vulkan version that does pass the conformance test + * instead. + */ + uint32_t vulkan_api_version{VK_MAKE_VERSION(1, 0, 0)}; + /*! * \brief The supported subgroup operations * diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index f6cb096720da..9598f07e365e 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -425,7 +425,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { while (reduce_align > 1) { reduce_align = reduce_align >> 1; in_warp_seq.emplace_back(freduce(reduce_align)); - seq.emplace_back(SyncThread("warp")); + in_warp_seq.emplace_back(SyncThread("warp")); } if (in_warp_seq.size() != 0) { Stmt warp_body = SeqStmt::Flatten(in_warp_seq);