diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h index 33bdfeee3c31..834a84933e44 100644 --- a/src/runtime/pipeline/pipeline_struct.h +++ b/src/runtime/pipeline/pipeline_struct.h @@ -34,6 +34,8 @@ #include #include #include + +#include "spsc_queue.h" namespace tvm { namespace runtime { #define GLOBAL_MODULE_INDEX -1 @@ -63,12 +65,27 @@ enum InterfaceType { INPUT = 0, OUTPUT, }; +/*!\brief The state of the pipeline.*/ +enum PipelineState { + STOPPED = 0, + RUNNING, + STOPPING, +}; /*! *\brief The structure includes the module index and the module output index. */ struct ModuleInterfaceID { - ModuleInterfaceID() : runtime_idx(0), runtime_interface_idx(0), interface_type(OUTPUT) { ; } - ModuleInterfaceID(int runtime_index, int runtime_interface_index, InterfaceType type = OUTPUT) { + ModuleInterfaceID() { SetID(0, 0, INPUT); } + ModuleInterfaceID(int runtime_index, int runtime_interface_index, InterfaceType type = INPUT) { + SetID(runtime_index, runtime_interface_index, type); + } + /*! + * \brief Set the value of ID. + * \param runtime_index The index of runtime. + * \param runtime_interface_index The index of interface. + * \param type The type of the interface. 
+ */ + void SetID(int runtime_index, int runtime_interface_index, InterfaceType type) { runtime_idx = runtime_index; runtime_interface_idx = runtime_interface_index; interface_type = type; @@ -84,6 +101,21 @@ struct ModuleInterfaceID { }; /*!\brief The interface type*/ InterfaceType interface_type; + ModuleInterfaceID& operator=(const struct ModuleInterfaceID& id) { + SetID(id.runtime_idx, id.runtime_interface_idx, id.interface_type); + return *this; + } + bool operator==(const struct ModuleInterfaceID& id) const { + return id.interface_type == interface_type && + id.runtime_interface_idx == runtime_interface_idx && id.runtime_idx == runtime_idx; + } +}; +/*!\brief The hash function for "ModuleInterfaceID": combines the interface type, the interface index and the runtime index, each shifted to its own bit offset, into one std::size_t (unsigned arithmetic, so a negative index such as GLOBAL_MODULE_INDEX is well defined).*/ +struct ModuleIDHash { + std::size_t operator()(const ModuleInterfaceID& id) const { + const std::size_t offset = sizeof(std::size_t) * 8 / 3; + return static_cast<std::size_t>(id.interface_type) | static_cast<std::size_t>(id.runtime_interface_idx) << offset | static_cast<std::size_t>(id.runtime_idx) << offset * 2; + } }; /*!\brief The data notification structure.*/ class DataNotify { @@ -96,24 +128,21 @@ class DataNotify { bool data_ready_ = false; /*!\brief Whether the thread should exit or not.*/ std::atomic exit_state_{false}; - /*! - * \brief The 'ModuleInterfaceID' in which the data was ready and triggered this - * notification. - */ + /*!\brief The 'ModuleInterfaceID' of an interface which sent this notification.*/ ModuleInterfaceID notification_source_; public: /*! * \brief Constructing the DataNotify class. - * \param parent_output_id The id of a runtime interface which is sending out the data + * \param source_interface_id The id of a runtime interface which is sending out the data * notification. */ - explicit DataNotify(ModuleInterfaceID parent_output_id) { - notification_source_ = parent_output_id; + explicit DataNotify(ModuleInterfaceID source_interface_id) { + notification_source_ = source_interface_id; } /*! - * \brief Getting the notification source. 
- * \return The first 'int' is the runtime index, and the second 'int' is the output index. + * \brief Getting the notification source. + * \return The ID of the interface which is sending out the notification. */ ModuleInterfaceID GetNotifySource(void) { return notification_source_; } /*! @@ -146,8 +175,65 @@ class DataNotify { */ bool GetExitState(void) { return exit_state_.load(std::memory_order_acquire); } }; +/*!\brief The container used to store the forwarding data of the pipeline.*/ +class QueueData { + public: + /*!\brief Doing a deep copy for the 'QueueData' structure.*/ + QueueData& operator=(const QueueData& data) { + CreateCopyFrom(data.GetDLData()); + return *this; + } + QueueData& operator=(const NDArray& from) { + CreateCopyFrom(const_cast(from.operator->())); + return *this; + } + QueueData& operator=(const DLTensor* from) { + CreateCopyFrom(from); + return *this; + } + /*!\brief Create a deep copy of the 'DLTensor' data.*/ + DLTensor* CreateCopyFrom(const DLTensor* from) { + if (!from) { + LOG(FATAL) << "the 'from' pointer is a null pointer!"; + return nullptr; + } + size_t fromLen = tvm::runtime::GetDataSize(*from); + size_t toLen = data_ ? 
tvm::runtime::GetDataSize(*data_) : 0; + if (!data_ || !(device_type_ == from->device.device_type && device_id_ == from->device.device_id) || + fromLen != toLen) { + if (data_) { + TVMArrayFree(data_); + data_ = nullptr; + } + TVMArrayAlloc(from->shape, from->ndim, from->dtype.code, from->dtype.bits, from->dtype.lanes, + from->device.device_type, from->device.device_id, &data_); + } + TVMArrayCopyFromTo(const_cast(from), data_, nullptr); + device_type_ = from->device.device_type; + device_id_ = from->device.device_id; + return data_; + } + /*!\brief Return a pointer to the 'DLTensor' data.*/ + DLTensor* GetDLData() const { return data_; } + const int DeviceType() { return device_type_; } + const int DeviceID() { return device_id_; } + ~QueueData() { + if (data_) { + TVMArrayFree(data_); + data_ = nullptr; + } + } + + private: + /*!\brief Pointer to the forwarding data, it is a null pointer until the first copy.*/ + DLTensor* data_ = nullptr; + /*!\brief The type of device which generated the QueueData container, -1 until the first copy.*/ + int device_type_ = -1; + /*!\brief The id of device which generated the data in this container, -1 until the first copy.*/ + int device_id_ = -1; +}; /*! - * \brief All binding information of a output interface. + * \brief All binding information of an output interface. */ class ConfigBindings { public: @@ -274,7 +360,7 @@ class ConfigOutputBindings { return ret; } /*! - * \brief Create a output binding map from JSONReader. + * \brief Create an output binding map from JSONReader. * \param reader Json reader. */ void Load(dmlc::JSONReader* reader) { @@ -427,7 +513,7 @@ struct InputConnectionConfig { return input_connection[key]; } /*! - * \brief Create a input connection config from JSONReader. + * \brief Create an input connection config from JSONReader. * \param reader Json reader. */ void Load(dmlc::JSONReader* reader) { @@ -498,25 +584,44 @@ struct ParamConnectionConfig { } } }; +/*! + * \brief The single consumer single producer queue which is used to forward data between two + * interfaces of backend cores. 
+ */ +using ForwardQueue = SPSCLockFreeQueue; /* - *\brief Backend Runtime. + *!\brief Backend Runtime. */ class BackendRuntime { using ModuleInputPairList = std::vector, int>>; + using ForwardQueueMap = + std::unordered_map, ModuleIDHash>; private: - /*\brief The index of runtime indicates the runtime position in the pipeline.*/ + /*!\brief The index of runtime indicates the runtime position in the pipeline.*/ int runtime_idx_; - /*\brief The Runtime module of a backend graph executor.*/ + /*!\brief The Runtime module of a backend graph executor.*/ Module module_; /*\brief The thread is associated with the current runtime*/ std::thread thread_; - /*\brief A list of runtime which depends on the current runtime.*/ + /*!\brief The state of the pipeline.*/ + std::atomic pipeline_state_{STOPPED}; + /*!\brief A list of runtime which depends on the current runtime.*/ std::unordered_map children_; - /*\brief A map including the runtime input index and the notification data structure.*/ + /*!\brief A map including the runtime input index and the notification data structure.*/ std::unordered_map> parents_notify_; - /*\brief The execution count of the 'RunPipeline' function. */ + /*!\brief The execution count of the 'RunPipeline' function. */ uint32_t pipeline_execution_count_ = 0; + /*! + * \brief A list of SPSC input queues in which the input interface will poll the data sent from + * other backend cores. + */ + std::unordered_map> input_queue_; + /*! + * \brief A list of SPSC output queues in which the output interface will push the data to + * other backend cores. + */ + std::unordered_map output_queue_; /*! *\brief In order to transfer data from one backend runtime to another, we need a local * tensor variable as a medium. 
"input_tensor_local_copy_" is a map including @@ -533,27 +638,41 @@ class BackendRuntime { tvm::runtime::PackedFunc run_; /*!\brief The worker thread is used to execute the runtimes in pipeline.*/ void StartWorkThread() { + SetPipelineState(RUNNING); if (runtime_idx_ == 0) { this->CreateParentsNotify(0, GLOBAL_MODULE_INDEX, 0); } else { // Only launching the worker thread for the runtimes after the first runtime. thread_ = std::thread([&]() { while (!this->WaitAndLoadPipelineData()) { - this->RunPipeline(); + if (!this->RunPipeline()) { + break; + } } VLOG(1) << "Runtime " << this->runtime_idx_ << " exit."; }); } return; } + /*!\brief Checking if the pipeline is stopped or stopping.*/ + const bool PipelineIsStop() const { + auto state = pipeline_state_.load(std::memory_order_acquire); + return state == STOPPING || state == STOPPED; + } + /*!\brief Setting the state of the pipeline.*/ + void SetPipelineState(PipelineState state) { + pipeline_state_.store(state, std::memory_order_release); + } /*!\brief Stopping the threads in pipeline.*/ void StopPipeline() { + SetPipelineState(STOPPING); for (auto notify : parents_notify_) { notify.second->ExitNotify(); } if (thread_.joinable()) { thread_.join(); } + SetPipelineState(STOPPED); } /*! * \brief Waiting for the internal forwarding data. @@ -567,64 +686,98 @@ class BackendRuntime { // Breaking the loop when the notification is in the exit state. if ((exit_notify = notify->second->GetExitState())) break; // Getting the source which sends this notification. - auto notify_source = notify->second->GetNotifySource(); + auto target_input_interface_index = notify->first; + auto source_interface_id = notify->second->GetNotifySource(); // Loading the binding data. - while (!this->LoadBindingData(notify->first, notify_source.runtime_idx, - notify_source.runtime_output_idx)) { + while (!this->LoadBindingData(target_input_interface_index)) { // Waiting for the notification. 
if (!notify->second->Wait()) { VLOG(1) << "runtime index:" << runtime_idx_ << " receive exit notify."; exit_notify = true; break; } - // TODO(huajsj): removing this 'break' after finishing the 'LoadBindingData'. - break; } - VLOG(1) << "runtime_index.input_index:" << runtime_idx_ << "." << notify->first - << "from runtime_index.output_index:" << notify_source.runtime_idx << "." - << notify_source.runtime_output_idx; + VLOG(1) << "Data forwarding from runtime(" << source_interface_id.runtime_idx << ").output(" + << source_interface_id.runtime_interface_idx << ") to runtime(" << runtime_idx_ + << ").input(" << target_input_interface_index << ")"; notifys.erase(notify); } return exit_notify; } /*! * \brief Loading the binding data. - * \param parent_idx The index of runtime which forwards data to current runtime. - * \param parent_output_idx The index of output where the forwarding data is coming from. - * \param input_idx The index of input where the data will be forwarding to. + * \param input_index The index of the interface which will receive the forwarding data. * \return Returning 'true' when data is loaded successfully, otherwise returning 'false'. */ - bool LoadBindingData(int parent_idx, int parent_output_idx, int input_idx) { - // TODO(huajsj): Loading data. - return false; + bool LoadBindingData(int input_index) { + if (input_queue_.find(input_index) == input_queue_.end()) { + LOG(FATAL) << "Not finding the associated input queue of the input " << input_index << " !"; + return false; + } + auto queue = input_queue_[input_index]; + QueueData data; + // TODO(huajsj): Doing the 'SetInput' inside the poll function to avoid one time data copy. + if (!queue->Poll(&data)) { + return false; + } + SetInput(input_index, data.GetDLData()); + return true; } /*! * \brief Forwarding the output data into the child runtimes. + * \return bool Return false when the "PipelineIsStop" function returns true or this function + * reaches some errors. Otherwise, return true. 
+ */ - void ForwardingOutputDataToChildren(void) { + bool ForwardingOutputDataToChildren(void) { for (auto child : children_) { - // TODO(huajsj): Getting the output data from the current runtime in order to forward - // data to the child. - + auto output_idx = child.first; + if (output_queue_.find(output_idx) == output_queue_.end()) { + LOG(FATAL) << "Not find the forwarding queue map for output(" << output_idx << ")!"; + return false; + } + NDArray output = GetOutput(output_idx); + auto& forward_queue_map = output_queue_[output_idx]; // Notifying the 'children runtime' that the forwarding data are ready. for (auto module_pair : child.second) { - module_pair.first->ParentNotify(module_pair.second); + auto child_runtime = module_pair.first; + auto child_runtime_index = child_runtime->GetModuleIndex(); + auto child_input_index = module_pair.second; + auto queue_id = GenerateQueueID(child_runtime_index, child_input_index, INPUT); + if (forward_queue_map.find(queue_id) == forward_queue_map.end()) { + LOG(FATAL) << "Not find the associated queue of the runtime(" << child_runtime_index + << ").input(" << child_input_index << ") which is connected with runtime(" + << runtime_idx_ << ").output(" << output_idx << ")"; + } + auto forward_queue = forward_queue_map[queue_id]; + // If the queue is full, keep trying until the push succeeds or the pipeline enters + // a STOP state. + while (!forward_queue->Push(output)) { + if (PipelineIsStop()) { + LOG(INFO) << "The forwarding process is stopped after the pipeline status is changed" + << " into stop."; + return false; + } + } + child_runtime->ParentNotify(child_input_index); } } + return true; } /*! *\brief Creating a parent notification. *\param input_index The input index of the 'current runtime'. *\param parent_idx The index of 'parent runtime' which will send the notification. *\param parent_output_idx The output index of the 'parent runtime' which will send - * the nofication. + * the notification. 
*/ void CreateParentsNotify(int input_index, int parent_idx, int parent_output_idx) { if (parents_notify_.find(input_index) != parents_notify_.end()) { - LOG(FATAL) << "Not finding the input index " << input_index << " in runtime " << runtime_idx_; + LOG(FATAL) << "The notification associated with the input interface " << input_index + << " in runtime " << runtime_idx_ << " already been created!"; + return; } parents_notify_[input_index] = - std::make_shared(ModuleInterfaceID(parent_idx, parent_output_idx)); + std::make_shared(ModuleInterfaceID(parent_idx, parent_output_idx, OUTPUT)); } /*! * \brief Copying from a given tensor and using 'CPU' as the device. @@ -707,21 +860,24 @@ class BackendRuntime { LOG(FATAL) << "The runtime index " << child_idx << " is out of the range."; } auto child_runtime = runtimes->at(child_idx); + ICHECK(child_runtime->GetModuleIndex() == child_idx); int input_index = child_runtime->GetInputIndex(child_input_name); if (input_index < 0) { LOG(FATAL) << "Can not find the input " << input_index << "in runtime " << child_idx; } children_[output_idx].push_back(std::make_pair(child_runtime, input_index)); child_runtime->CreateParentsNotify(input_index, runtime_idx_, output_idx); - VLOG(1) << " parent_idx.output:" << runtime_idx_ << "." << output_idx << " child.input" - << child_idx << "." << input_index; + VLOG(1) << " parent_idx.output:" << runtime_idx_ << "." << output_idx + << " child.input:" << child_idx << "." << input_index; + // Creating the pipeline forwarding queue. + this->CreateForwardingQueue(output_idx, child_runtime, input_index); }, runtime_idx_); StartWorkThread(); } /*! - * \brief Notifying a input is ready. + * \brief Notifying an input is ready. * \param input_index The index of 'input interface' which is ready for data. */ void ParentNotify(int input_index) { @@ -739,6 +895,45 @@ class BackendRuntime { NDArray data = get_output_(idx); return CreateNDArrayFromDLTensor(const_cast(data.operator->())); } + /*! 
+ * \brief Generate the ID of an input queue. + * \param runtime_index The index of backend runtime. + * \param interface_index The index of the interface. + * \param type The type of the interface. + */ + ModuleInterfaceID GenerateQueueID(int runtime_index, int interface_index, InterfaceType type) { + return ModuleInterfaceID(runtime_index, interface_index, type); + } + /*! + * \brief Creating a forwarding queue for the pair of an output interface and an input interface. + * \param output_idx The index of an output interface which will send the forwarding data. + * \param child_runtime The backend runtime which owns the input interface. + * \param input_index The index of an input interface which will receive the forwarding data. + */ + void CreateForwardingQueue(int output_idx, std::shared_ptr child_runtime, + int input_index) { + auto queue_id = GenerateQueueID(child_runtime->GetModuleIndex(), input_index, INPUT); + // The forwarding queue map of a specified output interface. + auto& queue_map = output_queue_[output_idx]; + if (queue_map.find(queue_id) != queue_map.end()) { + LOG(FATAL) << "The queue " << queue_id.runtime_idx << "." << queue_id.runtime_interface_idx + << " is already created!"; + return; + } + auto queue = std::make_shared(queue_id); + queue_map[queue_id] = queue; + // Use the created queue as the consumer queue for the input interface of this forwarding + // pair. + child_runtime->AppendInputQueue(input_index, queue); + } + /*! + * \brief Setting the consumer queue for the input interface. + * \param input_index The index of the input interface. + * \param queue The consumer queue. 
+ */ + void AppendInputQueue(int input_index, std::shared_ptr queue) { + input_queue_[input_index] = queue; + } /*!\brief Return the index of the current module.*/ int GetModuleIndex() { return runtime_idx_; } /*!\brief Return the number of output*/ @@ -764,11 +959,15 @@ class BackendRuntime { NDArray GetOutput(int index) { return get_output_(index); } /*!\brief Running the runtime.*/ void Run() { run_(); } - /*!\brief Running the runtime in the pipeline mode.*/ - void RunPipeline() { + /*! + * \brief Running the runtime in the pipeline mode. + * \return Returning false if the forwarding function failed. Otherwise, returning true. + */ + bool RunPipeline() { Run(); - ForwardingOutputDataToChildren(); + bool ret = ForwardingOutputDataToChildren(); pipeline_execution_count_++; + return ret; } }; /*! diff --git a/src/runtime/pipeline/spsc_queue.h b/src/runtime/pipeline/spsc_queue.h new file mode 100644 index 000000000000..17313909f204 --- /dev/null +++ b/src/runtime/pipeline/spsc_queue.h @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#ifndef TVM_RUNTIME_PIPELINE_SPSC_QUEUE_H_ +#define TVM_RUNTIME_PIPELINE_SPSC_QUEUE_H_ +#include +#include +/*!\brief A single producer and single consumer lock free queue. + * One slot is always left unused to tell a full queue from an empty one, so at most 'len_ - 1' items are held. */ +template +class SPSCLockFreeQueue { + public: + explicit SPSCLockFreeQueue(IDType id) : id_(id) {} + /*!\brief A read barrier enforcing the CPU to perform the reads before this barrier.*/ + inline void read_barrier() { std::atomic_thread_fence(std::memory_order_acquire); } + /*!\brief A write barrier enforcing the CPU to perform the writes before this barrier.*/ + inline void write_barrier() { std::atomic_thread_fence(std::memory_order_release); } + /*!\brief Checking whether the queue is full.*/ + bool Full() { + read_barrier(); + return ((tail_ + 1) % len_) == head_; + } + /*!\brief Checking whether the queue is empty.*/ + bool Empty() { + read_barrier(); + return head_ == tail_; + } + /*! + * \brief Pushing the data into the queue. Only a single producer will call this function. + * \param data The data which is pushed into the queue. + * \return Return false when the queue is full. Otherwise, return true. + */ + template + bool Push(const data_type& data) { + if (Full()) return false; + queue_[tail_] = data; + write_barrier(); + tail_ = (tail_ + 1) % len_; + return true; + } + /*! + * \brief Poll the data from the front of the queue. Only the single consumer will call this + * function. + * \param data A pointer to the structure which stores the polled data. + * \return Returning false when the queue is empty. Otherwise, return true. 
+ */ + template + bool Poll(data_type* data) { + if (Empty()) return false; + *data = queue_[head_]; + write_barrier(); + head_ = (head_ + 1) % len_; + return true; + } + + private: + /*!\brief Index of the first slot with valid data; atomic because it is written by the consumer (Poll) and read by the producer (Full).*/ + std::atomic<size_t> head_{0}; + /*!\brief Index of the slot at which the next element is added; atomic because it is written by the producer (Push) and read by the consumer (Empty).*/ + std::atomic<size_t> tail_{0}; + /*!\brief The length of the queue.*/ + size_t len_ = QueueLength; + /*!\brief The queue used to store the data.*/ + SlotType queue_[QueueLength]; + /*!\brief The ID of the queue.*/ + IDType id_; +}; +#endif // TVM_RUNTIME_PIPELINE_SPSC_QUEUE_H_ diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index 8ab2265db3d6..ff30c2affe47 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -17,6 +17,7 @@ import pytest import os +import time import numpy as np import tvm import tvm.testing