diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 5cd4812f41c4..c8a0858ec1bc 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -91,6 +91,11 @@ struct VulkanBuffer { VkDeviceMemory memory{VK_NULL_HANDLE}; }; +struct UniformBuffer { + VulkanBuffer* vk_buf; + void* host_buf; +}; + struct VulkanPipeline { VulkanContext* vctx_{nullptr}; VkShaderModule shader{VK_NULL_HANDLE}; @@ -100,10 +105,105 @@ struct VulkanPipeline { VkPipelineLayout pipeline_layout{VK_NULL_HANDLE}; VkPipeline pipeline{VK_NULL_HANDLE}; VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE}; + UniformBuffer ubo; }; typedef dmlc::ThreadLocalStore VulkanThreadStore; +uint32_t FindMemoryType(VkDevice logical_device, VkPhysicalDevice phy_device, VkBuffer buffer, + VkMemoryPropertyFlags req_prop) { + VkMemoryRequirements mem_reqs; + vkGetBufferMemoryRequirements(logical_device, buffer, &mem_reqs); + uint32_t type_bits = mem_reqs.memoryTypeBits; + VkPhysicalDeviceMemoryProperties phy_mem_prop; + vkGetPhysicalDeviceMemoryProperties(phy_device, &phy_mem_prop); + for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { + if ((type_bits & 1) == 1 && + (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) { + return i; + } + type_bits >>= 1; + } + LOG(FATAL) << "Requested memory type not found"; + return 0; +} + +VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage) { + VkBufferCreateInfo info; + info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + info.pNext = nullptr; + info.flags = 0; + info.size = nbytes; + info.queueFamilyIndexCount = 1; + info.pQueueFamilyIndices = &(vctx.queue_family_index); + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + info.usage = usage; + // create buffer + VkBuffer buffer; + VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); + + uint32_t mem_type_index = vctx.compute_mtype_index; + + if (usage & VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT) { + // Find a memory type that supports UBO + auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + mem_type_index = FindMemoryType(vctx.device, vctx.phy_device, buffer, prop); + } + + // bind to memory + bool dedicated_allocation = false; + VkMemoryRequirements2KHR req2; + + if (vctx.get_buffer_memory_requirements_2_functions) { + VkBufferMemoryRequirementsInfo2KHR req_info2; + req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; + req_info2.pNext = 0; + req_info2.buffer = buffer; + + req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; + req2.pNext = 0; + + VkMemoryDedicatedRequirementsKHR dedicated_req; + dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; + dedicated_req.pNext = 0; + req2.pNext = &dedicated_req; + + vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( + vctx.device, &req_info2, &req2); + dedicated_allocation = + dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; + } + + VkDeviceMemory memory; + if (!dedicated_allocation) { + VkMemoryAllocateInfo minfo; + minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + minfo.pNext = nullptr; + minfo.allocationSize = nbytes; + minfo.memoryTypeIndex = mem_type_index; + VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); + } else { + VkMemoryAllocateInfo minfo; + minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + minfo.pNext = nullptr; + minfo.allocationSize = req2.memoryRequirements.size; + minfo.memoryTypeIndex = mem_type_index; + + VkMemoryDedicatedAllocateInfoKHR mdinfo; + mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; + mdinfo.pNext = 0; + mdinfo.image = 0; + mdinfo.buffer = buffer; + minfo.pNext = &mdinfo; + VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); + } + VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0)); + VulkanBuffer* pbuf = new VulkanBuffer(); + pbuf->memory = memory; + pbuf->buffer = buffer; + return pbuf; +} + class VulkanDeviceAPI final : public DeviceAPI { public: VulkanDeviceAPI(); @@ -124,70 +224,9 @@ class VulkanDeviceAPI final : public DeviceAPI { nbytes = 1; } const auto& vctx = context(dev.device_id); - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = nbytes; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(vctx.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | + auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - // create buffer - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); - // bind to memory - VkBufferMemoryRequirementsInfo2KHR req_info2; - req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; - req_info2.pNext = 0; - req_info2.buffer = buffer; - - VkMemoryRequirements2KHR req2; - req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; - req2.pNext = 0; - - VkMemoryDedicatedRequirementsKHR dedicated_req; - dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; - dedicated_req.pNext = 0; - req2.pNext = &dedicated_req; - - bool dedicated_allocation = false; - if (vctx.get_buffer_memory_requirements_2_functions) { - vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( - vctx.device, &req_info2, &req2); - dedicated_allocation = - dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; - } - - VkDeviceMemory memory; - if (!dedicated_allocation) { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = nbytes; - minfo.memoryTypeIndex = vctx.compute_mtype_index; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } else { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = req2.memoryRequirements.size; - minfo.memoryTypeIndex = vctx.compute_mtype_index; - - VkMemoryDedicatedAllocateInfoKHR mdinfo; - mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; - mdinfo.pNext = 0; - mdinfo.image = 0; - mdinfo.buffer = buffer; - minfo.pNext = &mdinfo; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } - VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0)); - VulkanBuffer* pbuf = new VulkanBuffer(); - pbuf->memory = memory; - pbuf->buffer = buffer; - return pbuf; + return CreateBuffer(vctx, nbytes, usage); } void FreeDataSpace(Device dev, void* ptr) final { @@ -747,7 +786,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { public: explicit VulkanModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string source) - : smap_(smap), fmap_(fmap), source_(source) {} + : smap_(smap), fmap_(fmap), source_(source), max_push_constants_(GetMaxPushConstantsSize()) {} const char* type_key() const final { return "vulkan"; } @@ -781,6 +820,13 @@ class VulkanModuleNode final : public runtime::ModuleNode { vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr); vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr); vkDestroyShaderModule(vctx.device, pe->shader, nullptr); + // UBO + if (pe->ubo.vk_buf) { + vkUnmapMemory(vctx.device, pe->ubo.vk_buf->memory); + vkDestroyBuffer(vctx.device, pe->ubo.vk_buf->buffer, nullptr); + vkFreeMemory(vctx.device, pe->ubo.vk_buf->memory, nullptr); + delete pe->ubo.vk_buf; + } } } } @@ -812,30 +858,35 @@ class VulkanModuleNode final : public runtime::ModuleNode { std::vector arg_template; uint32_t num_pod = 0, num_buffer = 0; + auto push_arg_info = [&arg_binding, &arg_template](uint32_t binding, + VkDescriptorType desc_type) { + { + VkDescriptorSetLayoutBinding bd; + bd.binding = binding; + bd.descriptorType = desc_type; + bd.descriptorCount = 1; + bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + bd.pImmutableSamplers = nullptr; + arg_binding.push_back(bd); + } + { + VkDescriptorUpdateTemplateEntryKHR tpl; + tpl.dstBinding = binding; + tpl.dstArrayElement = 0; + tpl.descriptorCount = 1; + tpl.descriptorType = desc_type; + tpl.offset = binding * sizeof(VkDescriptorBufferInfo); + tpl.stride = sizeof(VkDescriptorBufferInfo); + arg_template.push_back(tpl); + } + }; + { auto fit = fmap_.find(func_name); ICHECK(fit != fmap_.end()); for (DLDataType arg_type : fit->second.arg_types) { if (arg_type.code == kTVMOpaqueHandle) { - { - VkDescriptorSetLayoutBinding bd; - bd.binding = num_buffer; - bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; - bd.descriptorCount = 1; - bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - bd.pImmutableSamplers = nullptr; - arg_binding.push_back(bd); - } - { - VkDescriptorUpdateTemplateEntryKHR tpl; - tpl.dstBinding = num_buffer; - tpl.dstArrayElement = 0; - tpl.descriptorCount = 1; - tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; - tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo); - tpl.stride = sizeof(VkDescriptorBufferInfo); - arg_template.push_back(tpl); - } + push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); ++num_buffer; } else { ++num_pod; @@ -843,6 +894,11 @@ class VulkanModuleNode final : public runtime::ModuleNode { } } + size_t nbytes_scalars = num_pod * sizeof(ArgUnion64); + if (nbytes_scalars > max_push_constants_) { + push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); + } + { VkDescriptorSetLayoutCreateInfo descrip_cinfo; descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; @@ -894,7 +950,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { playout_cinfo.setLayoutCount = 1; playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout); - if (num_pack_args != 0) { + if (0 < nbytes_scalars && nbytes_scalars <= max_push_constants_) { playout_cinfo.pushConstantRangeCount = 1; playout_cinfo.pPushConstantRanges = &crange; ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize); @@ -923,6 +979,13 @@ class VulkanModuleNode final : public runtime::ModuleNode { VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe->pipeline))); + if (nbytes_scalars > max_push_constants_) { + // Allocate, bind and map UBO + UniformBuffer& ubo = pe->ubo; + ubo.vk_buf = CreateBuffer(vctx, nbytes_scalars, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); + vkMapMemory(vctx.device, ubo.vk_buf->memory, 0, nbytes_scalars, 0, &(ubo.host_buf)); + } + if (vctx.UseImmediate()) { VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo; descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR; @@ -966,6 +1029,8 @@ class VulkanModuleNode final : public runtime::ModuleNode { return source_; } + uint32_t MaxPushConstantsSize() const { return max_push_constants_; } + private: // function information table. std::unordered_map smap_; @@ -975,6 +1040,8 @@ class VulkanModuleNode final : public runtime::ModuleNode { std::string fmt_{"vulkan"}; // The source std::string source_; + // The maximum size of push constants in bytes + const uint32_t max_push_constants_; // Guards accesses to `ecache_` std::mutex mutex_; @@ -1076,6 +1143,17 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, binfo.range = VK_WHOLE_SIZE; descriptor_buffers[i] = binfo; } + const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64); + bool use_ubo = num_pack_args_ != 0 && nbytes_scalars > m_->MaxPushConstantsSize(); + if (use_ubo) { + CHECK(pipeline->ubo.host_buf) << "The UBO host buffer is not allocated"; + memcpy(pipeline->ubo.host_buf, pack_args, nbytes_scalars); + VkDescriptorBufferInfo binfo; + binfo.buffer = pipeline->ubo.vk_buf->buffer; + binfo.offset = 0; + binfo.range = VK_WHOLE_SIZE; + descriptor_buffers.push_back(binfo); + } if (vctx.UseImmediate()) { // Can safely capture by reference as this lambda is immediately executed on the calling thread. VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) { @@ -1084,7 +1162,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, descriptor_buffers.data()); - if (num_pack_args_ != 0) { + if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), pack_args); @@ -1105,7 +1183,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, // Otherwise, the more expensive deferred path. std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); - const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { + const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers, use_ubo]() { std::vector write_descriptor_sets; write_descriptor_sets.resize(descriptor_buffers.size()); for (size_t i = 0; i < write_descriptor_sets.size(); i++) { @@ -1115,20 +1193,26 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, write_descriptor_sets[i].dstBinding = i; write_descriptor_sets[i].dstArrayElement = 0; write_descriptor_sets[i].descriptorCount = 1; - write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; write_descriptor_sets[i].pImageInfo = 0; write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]); write_descriptor_sets[i].pTexelBufferView = 0; + + if (use_ubo && i == write_descriptor_sets.size() - 1) { + // The last binding is for UBO + write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; + } else { + write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + } } vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(), 0, 0); }; - const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) { + const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage](VulkanStreamState* state) { vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0, nullptr); - if (pack_args_storage.size() != 0) { + if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data()); @@ -1183,6 +1267,12 @@ Module VulkanModuleLoadBinary(void* strm) { return VulkanModuleCreate(smap, fmap, ""); } +uint32_t GetMaxPushConstantsSize() { + int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; + const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + return vctx.phy_device_prop.limits.maxPushConstantsSize; +} + TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 3083ba6f9ce4..e94a9fe7fa90 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -142,6 +142,9 @@ struct VulkanContext { bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; } }; +/*! \brief returns maximum push constant sizes in bytes for the target platform */ +uint32_t GetMaxPushConstantsSize(); + } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 24608ebc93f4..4d55f4c49a5f 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -30,6 +30,9 @@ #include +#include "../../runtime/pack_args.h" +#include "../../runtime/vulkan/vulkan_common.h" + namespace tvm { namespace codegen { @@ -66,16 +69,26 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: spirv::Value func_ptr = builder_->NewFunction(); builder_->StartFunction(func_ptr); - // All the POD arguments are passed in through PushConstant if (pod_args.size() != 0) { std::vector value_types; for (size_t i = 0; i < pod_args.size(); ++i) { value_types.push_back(builder_->GetSType(pod_args[i].dtype())); } - spirv::Value ptr = builder_->DeclarePushConstant(value_types); - for (size_t i = 0; i < pod_args.size(); ++i) { - spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); - var_map_[pod_args[i].get()] = value; + const auto max_push_constants = runtime::vulkan::GetMaxPushConstantsSize(); + if (pod_args.size() * sizeof(runtime::ArgUnion64) <= max_push_constants) { + spirv::Value ptr = builder_->DeclarePushConstant(value_types); + for (size_t i = 0; i < pod_args.size(); ++i) { + spirv::Value value = + builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); + var_map_[pod_args[i].get()] = value; + } + } else { + // If we need to pass more arguments than push constants could handle, we use UBO. + spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer); + for (size_t i = 0; i < pod_args.size(); ++i) { + spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast(i)); + var_map_[pod_args[i].get()] = value; + } } } this->VisitStmt(f->body); diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 5a1457387ae5..cd48c93530ec 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -205,8 +205,8 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set return val; } -Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { - ICHECK_EQ(push_const_.id, 0); +Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, + spv::StorageClass storage_class, ValueKind kind) { SType struct_type; struct_type.id = id_counter_++; struct_type.type = DataType::Handle(); @@ -226,22 +226,26 @@ Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { ICHECK_EQ(nbits % 8, 0); uint32_t bytes = (nbits / 8); if (t.bits() == 32) { - // In our Vulkan runtime, each push constant always occupies 64 bit. + // In our Vulkan runtime, each scalar argument always occupies 64 bit. offset += bytes * 2; } else { ICHECK_EQ(t.bits(), 64); offset += bytes; } } - // Decorate push constants as UBO this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); - SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant); - Value val = NewValue(ptr_type, kPushConstantPtr); - ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_); + SType ptr_type = GetPointerType(struct_type, storage_class); + Value val = NewValue(ptr_type, kind); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_); return val; } +Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { + ICHECK_EQ(push_const_.id, 0); + return DeclareStorageVariable(value_types, spv::StorageClassPushConstant, kPushConstantPtr); +} + Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) { SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant); Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const, @@ -249,6 +253,19 @@ Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint return this->MakeValue(spv::OpLoad, v_type, ptr); } +Value IRBuilder::DeclareUniformBuffer(const std::vector& value_types, uint32_t binding) { + Value val = DeclareStorageVariable(value_types, spv::StorageClassUniform, kUniformPtr); + this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); + return val; +} + +Value IRBuilder::GetUniform(Value ptr_push_const, const SType& v_type, uint32_t index) { + SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassUniform); + Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const, + IntImm(t_int32_, static_cast(index))); + return this->MakeValue(spv::OpLoad, v_type, ptr); +} + Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); } void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) { diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 8a08048e1955..05a2bc631743 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -60,7 +60,8 @@ enum ValueKind { kStructArrayPtr, kPushConstantPtr, kFunction, - kExtInst + kExtInst, + kUniformPtr }; /*! \brief Represent the SPIRV Value */ @@ -473,6 +474,7 @@ class IRBuilder { * \param The argument type. */ Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding); + /*! * \brief Declare POD arguments through push constants. * @@ -488,6 +490,23 @@ class IRBuilder { * \return the value of push constant */ Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index); + + /*! + * \brief Declare POD arguments through uniform buffer. + * + * \note Only call this function once! + * \param value_types The values in the uniform buffer + * \param binding The binding locaiton in descriptor set + * \return reference to self. + */ + Value DeclareUniformBuffer(const std::vector& value_types, uint32_t binding); + /*! + * \brief Get i-th uniform constant + * \param v_type The value type + * \param index The uniform index + * \return the value of uniform constant + */ + Value GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index); /*! * \brief Declare a new function * \return The created function ID. @@ -555,6 +574,17 @@ class IRBuilder { val.flag = flag; return val; } + + /*! + * \brief The common function to declare push constants or uniform buffer + * \param value_types The values in the push constants or uniform buffer + * \param storage_class An enum defined by SPIR-V indicating push constant or uniform + * \param kind An enum indicating push constant or uniform + * \return The created new label + */ + Value DeclareStorageVariable(const std::vector& value_types, + spv::StorageClass storage_class, ValueKind kind); + // get constant given value encoded in uint64_t Value GetConst_(const SType& dtype, const uint64_t* pvalue); // declare type