From 91b3ce3bf9d94bd8cc2e73cdbb9ff3224c222888 Mon Sep 17 00:00:00 2001 From: WangXi Date: Thu, 23 Dec 2021 20:17:23 +0800 Subject: [PATCH 1/7] add task loop thread pool --- .../distributed/fleet_executor/CMakeLists.txt | 6 +- .../distributed/fleet_executor/task_loop.cc | 82 +++++++++++++++++++ .../distributed/fleet_executor/task_loop.h | 81 ++++++++++++++++++ .../fleet_executor/task_loop_thread.cc | 58 +++++++++++++ .../fleet_executor/task_loop_thread.h | 44 ++++++++++ .../fleet_executor/task_loop_thread_pool.cc | 66 +++++++++++++++ .../fleet_executor/task_loop_thread_pool.h | 47 +++++++++++ paddle/fluid/framework/blocking_queue.h | 10 +++ 8 files changed, 392 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop.cc create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop.h create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop_thread.cc create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop_thread.h create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 82444ae77dc9d..bb6440ed82534 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -10,10 +10,12 @@ else() set(BRPC_DEPS "") endif() +cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc) + cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc - DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry - executor_gc_helper gflags glog ${BRPC_DEPS}) + DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper + op_registry executor_gc_helper gflags glog ${BRPC_DEPS}) if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") diff --git a/paddle/fluid/distributed/fleet_executor/task_loop.cc b/paddle/fluid/distributed/fleet_executor/task_loop.cc new file mode 100644 index 0000000000000..bfe9a939b966c --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/task_loop.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/fleet_executor/task_loop.h" + +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" + +namespace paddle { +namespace distributed { + +thread_local TaskLoop* TaskLoop::thread_local_loop_ = nullptr; + +TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; } + +TaskLoop::TaskLoop() + : looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) { + PADDLE_ENFORCE_EQ( + thread_local_loop_, nullptr, + platform::errors::AlreadyExists("Another TaskLoop is already init.")); + thread_local_loop_ = this; +} + +TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; } + +void TaskLoop::Loop() { + PADDLE_ENFORCE_EQ(looping_, false, + platform::errors::PreconditionNotMet( + "Loop can only execute in one loop thread")); + AssertInLoopThread(); + + looping_ = true; + quit_ = false; + + while (!quit_) { + auto tasks = tasks_.PopAll(); + for (auto& task : tasks) { + task(); + } + } + looping_ = false; +} + +void TaskLoop::Quit() { + quit_ = true; + if (!IsInLoopThread()) WakeUp(); +} + +void TaskLoop::RunInLoop(Functor cb) { + if (IsInLoopThread()) { + cb(); + } else { + QueueInLoop(cb); + } +} + +void TaskLoop::QueueInLoop(Functor cb) { tasks_.Push(cb); } + +void TaskLoop::WakeUp() { + Functor task([] {}); + QueueInLoop(task); +} + +void TaskLoop::AbortNotInLoopThread() { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "This TaskLoop was created in thread %d, but current thread is %d", + thread_id_, std::this_thread::get_id())); +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/task_loop.h b/paddle/fluid/distributed/fleet_executor/task_loop.h new file mode 100644 index 0000000000000..91425304e57d1 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/task_loop.h @@ -0,0 +1,81 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/blocking_queue.h" + +namespace paddle { +namespace distributed { + +class TaskLoop { + public: + static TaskLoop* GetTaskLoopOfCurrentThread(); + + using Functor = std::function; + + TaskLoop(); + ~TaskLoop(); + + void Loop(); + void Quit(); + + void RunInLoop(Functor cb); + void QueueInLoop(Functor cb); + + template + auto Enqueue(F&& f, Args&&... args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); + std::future task_future = task->get_future(); + + tasks_.Push([task]() { (*task)(); }); + return task_future; + } + + void WakeUp(); + + bool IsInLoopThread() const { + return thread_id_ == std::this_thread::get_id(); + } + + void AssertInLoopThread() { + if (!IsInLoopThread()) { + AbortNotInLoopThread(); + } + } + + private: + void AbortNotInLoopThread(); + + static thread_local TaskLoop* thread_local_loop_; + + bool looping_; + std::atomic quit_; + std::thread::id thread_id_; + + framework::BlockingQueue tasks_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/task_loop_thread.cc b/paddle/fluid/distributed/fleet_executor/task_loop_thread.cc new file mode 100644 index 0000000000000..bb313ad37890d --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/task_loop_thread.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h" + +#include "paddle/fluid/distributed/fleet_executor/task_loop.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" + +namespace paddle { +namespace distributed { + +TaskLoopThread::TaskLoopThread() : start_(false), loop_(nullptr) {} + +TaskLoopThread::~TaskLoopThread() { + if (loop_ != nullptr) { + loop_->Quit(); + thread_.join(); + } +} + +TaskLoop* TaskLoopThread::StartLoop() { + PADDLE_ENFORCE_EQ(start_, false, platform::errors::PreconditionNotMet( + "thread is already running.")); + start_ = true; + thread_ = std::thread([this]() { Loop(); }); + + std::unique_lock lock(mutex_); + cv_.wait(lock, [=] { return loop_ != nullptr; }); + return loop_; +} + +void TaskLoopThread::Loop() { + TaskLoop loop; + { + std::unique_lock lock(mutex_); + loop_ = &loop; + cv_.notify_one(); + } + loop.Loop(); + + std::unique_lock lock(mutex_); + loop_ = nullptr; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/task_loop_thread.h b/paddle/fluid/distributed/fleet_executor/task_loop_thread.h new file mode 100644 index 0000000000000..07952abdc247c --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/task_loop_thread.h @@ -0,0 +1,44 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace distributed { + +class TaskLoop; + +class TaskLoopThread { + public: + TaskLoopThread(); + ~TaskLoopThread(); + + TaskLoop* StartLoop(); + + private: + void Loop(); + + bool start_; + TaskLoop* loop_; + std::thread thread_; + std::mutex mutex_; + std::condition_variable cv_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc b/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc new file mode 100644 index 0000000000000..ed34bbb87fc6b --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h" + +#include "paddle/fluid/distributed/fleet_executor/task_loop.h" +#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" + +namespace paddle { +namespace distributed { + +TaskLoopThreadPool::TaskLoopThreadPool() : TaskLoopThreadPool(1) {} + +TaskLoopThreadPool::TaskLoopThreadPool(int thread_num) + : start_(false), thread_num_(thread_num) {} + +TaskLoopThreadPool::~TaskLoopThreadPool() = default; + +void TaskLoopThreadPool::Start() { + PADDLE_ENFORCE_EQ(start_, false, platform::errors::PreconditionNotMet( + "thread pool is already start.")); + PADDLE_ENFORCE_GT( + thread_num_, 0, + platform::errors::InvalidArgument( + "thread num must greater than 0, but now is %d", thread_num_)); + + start_ = true; + for (int i = 0; i < thread_num_; ++i) { + threads_.emplace_back(new TaskLoopThread()); + loops_.push_back(threads_[i]->StartLoop()); + } +} + +TaskLoop* TaskLoopThreadPool::GetLoop(int tid) { + PADDLE_ENFORCE_EQ(start_, true, platform::errors::PreconditionNotMet( + "thread pool must start first.")); + PADDLE_ENFORCE_GE(tid, 0, platform::errors::OutOfRange( + "tid must >= 0, but now is %d", tid)); + PADDLE_ENFORCE_LT(tid, thread_num_, + platform::errors::OutOfRange( + "tid must < thread_num, but now tid=%d thread_num=%d", + tid, thread_num_)); + return loops_[tid]; +} + +std::vector TaskLoopThreadPool::GetAllLoops() { + PADDLE_ENFORCE_EQ(start_, true, platform::errors::PreconditionNotMet( + "thread pool must start first.")); + return loops_; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h b/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h new file mode 100644 index 0000000000000..ffc9588f4e77e --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h @@ -0,0 +1,47 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace paddle { +namespace distributed { + +class TaskLoop; +class TaskLoopThread; + +class TaskLoopThreadPool { + public: + TaskLoopThreadPool(); + explicit TaskLoopThreadPool(int thread_num); + ~TaskLoopThreadPool(); + + void SetThreadNum(int thread_num) { thread_num_ = thread_num; } + + void Start(); + + TaskLoop* GetLoop(int tid); + std::vector GetAllLoops(); + + private: + bool start_; + int thread_num_; + std::vector> threads_; + std::vector loops_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/framework/blocking_queue.h b/paddle/fluid/framework/blocking_queue.h index 5bc38c1398aa5..04937fa6b97b3 100644 --- a/paddle/fluid/framework/blocking_queue.h +++ b/paddle/fluid/framework/blocking_queue.h @@ -81,6 +81,16 @@ class BlockingQueue { std::swap(*empty_queue, q_); } + std::deque PopAll() { + std::deque ret; + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return !q_.empty(); }); + std::swap(ret, q_); + } + return ret; + } + T Pop() { std::unique_lock lock(mutex_); cv_.wait(lock, [=] { return !q_.empty(); }); From acb3c8a57965195cc39ce41c19f9bc26684bacbf Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 24 Dec 2021 11:41:14 +0800 Subject: [PATCH 2/7] use thread_pool in carrier --- .../distributed/fleet_executor/carrier.cc | 13 +++ .../distributed/fleet_executor/carrier.h | 4 + .../distributed/fleet_executor/interceptor.cc | 97 +++++++++---------- .../distributed/fleet_executor/interceptor.h | 25 ++--- 4 files changed, 69 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 3279f954fa5f8..5526355764947 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -42,6 +42,12 @@ void Carrier::Init(int64_t rank, std::shared_ptr runtime_graph, place_ = place; root_scope_ = root_scope; dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); + + // TODO(fleet_exe dev): thread pool + thread_num_ = 1; + thread_pool_.SetThreadNum(thread_num_); + thread_pool_.Start(); + CreateInterceptors(); is_init_ = true; } @@ -183,6 +189,13 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, "The interceptor id should be unique.", interceptor_id)); interceptor->RegisterCarrier(this); + + // TODO(fleet_exe dev): get loop + auto* loop = thread_pool_.GetLoop(interceptor_id % thread_num_); + PADDLE_ENFORCE_NOT_NULL( + loop, platform::errors::Fatal("thread task loop must not null")); + interceptor->RegisterTaskLoop(loop); + auto* ptr = interceptor.get(); interceptor_idx_to_interceptor_.insert( std::make_pair(interceptor_id, std::move(interceptor))); diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 54cf2150030fc..a298b1979a2f3 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -24,6 +24,7 @@ #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" +#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" @@ -118,6 +119,9 @@ class Carrier final { std::shared_ptr msg_bus_; int64_t rank_; std::unordered_map interceptor_id_to_rank_; + + int thread_num_; + TaskLoopThreadPool thread_pool_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index f5501754cd729..282123a33123f 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -14,26 +14,20 @@ #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/task_loop.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" namespace paddle { namespace distributed { Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node) - : interceptor_id_(interceptor_id), node_(node) { - interceptor_thread_ = std::thread([this]() { - VLOG(3) << "Interceptor " << interceptor_id_ - << " starts the thread pooling it's local mailbox."; - PoolTheMailbox(); - }); -} - -Interceptor::~Interceptor() { Join(); } + : interceptor_id_(interceptor_id), node_(node) {} -void Interceptor::Join() { - if (interceptor_thread_.joinable()) { - interceptor_thread_.join(); - } +Interceptor::~Interceptor() { + std::lock_guard lock(mutex_); + PADDLE_ENFORCE_EQ(messages_.empty(), true, + platform::errors::PreconditionNotMet( + "Interceptor must destruct with messages empty")); } void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; } @@ -44,6 +38,32 @@ void Interceptor::Handle(const InterceptorMessage& msg) { handle_(msg); } +void Interceptor::LoopOnce() { + std::deque tmp_messages; + { + std::lock_guard lock(mutex_); + messages_.swap(tmp_messages); + } + PADDLE_ENFORCE_EQ(tmp_messages.empty(), false, + platform::errors::PreconditionNotMet( + "tmp_messages must not empty in task loop")); + + for (auto& msg : tmp_messages) { + const MessageType message_type = msg.message_type(); + VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message" + << " from interceptor " << interceptor_message.src_id() + << " with message: " << message_type << "."; + + Handle(msg); + // TODO(wangxi): unregister + if (stop_) { + // break the pooling thread + VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting."; + break; + } + } +} + void Interceptor::StopCarrier() { PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet( "Carrier is not registered.")); @@ -52,17 +72,21 @@ void Interceptor::StopCarrier() { cond_var.notify_all(); } -int64_t Interceptor::GetInterceptorId() const { - // return the interceptor id - return interceptor_id_; -} - void Interceptor::EnqueueRemoteInterceptorMessage( - const InterceptorMessage& interceptor_message) { + const InterceptorMessage& message) { // Called by Carrier, enqueue an InterceptorMessage to remote mailbox VLOG(3) << "Enqueue message: " << interceptor_message.message_type() << " into " << interceptor_id_ << "'s remote mailbox."; - remote_mailbox_.Push(interceptor_message); + + bool empty = false; + { + std::lock_guard lock(mutex_); + empty = messages_.empty(); + messages_.emplace_back(message); + } + if (empty) { + loop_->QueueInLoop([this]() { LoopOnce(); }); + } } bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { @@ -73,39 +97,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { return carrier_->Send(msg); } -void Interceptor::PoolTheMailbox() { - // pool the local mailbox, parse the Message - for (;;) { - if (local_mailbox_.empty()) { - // local mailbox is empty, fetch the remote mailbox - VLOG(3) << interceptor_id_ << "'s local mailbox is empty. " - << "Fetch the remote mailbox."; - PADDLE_ENFORCE_EQ(FetchRemoteMailbox(), true, - platform::errors::InvalidArgument( - "Error encountered when fetch remote mailbox.")); - } - const InterceptorMessage interceptor_message = local_mailbox_.front(); - local_mailbox_.pop_front(); - const MessageType message_type = interceptor_message.message_type(); - VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message" - << " from interceptor " << interceptor_message.src_id() - << " with message: " << message_type << "."; - - Handle(interceptor_message); - - if (stop_) { - // break the pooling thread - VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting."; - break; - } - } -} - -bool Interceptor::FetchRemoteMailbox() { - remote_mailbox_.PopAll(&local_mailbox_); - return !local_mailbox_.empty(); -} - static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() { static InterceptorFactory::CreateInterceptorMap interceptorMap; return interceptorMap; diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index d9e8d050dd1fc..6309b96f304ac 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -38,6 +38,7 @@ namespace distributed { class TaskNode; class Carrier; +class TaskLoop; class Interceptor { public: @@ -58,7 +59,7 @@ class Interceptor { void Handle(const InterceptorMessage& msg); // return the interceptor id - int64_t GetInterceptorId() const; + int64_t GetInterceptorId() const { return interceptor_id_; } // Called by Carrier, enqueue an InterceptorMessage to remote mailbox void EnqueueRemoteInterceptorMessage( @@ -77,6 +78,7 @@ class Interceptor { gc_ = gc; } void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; } + void RegisterTaskLoop(TaskLoop* loop) { loop_ = loop; } TaskNode* GetTaskNode() const { return node_; } @@ -101,28 +103,17 @@ class Interceptor { std::shared_ptr gc_{nullptr}; Carrier* carrier_; + TaskLoop* loop_; private: - // pool the local mailbox, parse the Message - void PoolTheMailbox(); - - // fetch all Message from remote mailbox to local mailbox - // return true if remote mailbox not empty, otherwise return false - bool FetchRemoteMailbox(); + void LoopOnce(); // interceptor handle which process message MsgHandle handle_{nullptr}; - // interceptor runs PoolTheMailbox() function to poll local mailbox - std::thread interceptor_thread_; - - // remote mailbox, written by EnqueueRemoteMessage() - // read by FetchRemoteMailbox() - framework::BlockingQueue remote_mailbox_; - - // local mailbox, written by FetchRemoteMailbox() - // read by PoolTheMailbox() - std::deque local_mailbox_; + std::mutex mutex_; + std::deque messages_; + // std::deque local_messages_; int64_t already_run_times_{0}; int64_t used_slot_nums_{0}; From f62c676aea297697e80047c7b2a326a674518f7a Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 24 Dec 2021 04:54:03 +0000 Subject: [PATCH 3/7] fix --- .../distributed/fleet_executor/CMakeLists.txt | 2 +- .../fluid/distributed/fleet_executor/carrier.cc | 5 ----- paddle/fluid/distributed/fleet_executor/carrier.h | 6 +++++- .../distributed/fleet_executor/interceptor.cc | 15 ++++++++------- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index bb6440ed82534..95ec6b329964e 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -10,7 +10,7 @@ else() set(BRPC_DEPS "") endif() -cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc) +cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog) cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 5526355764947..0faad3e6a026f 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -66,11 +66,6 @@ void Carrier::Release() { stop_msg.set_message_type(STOP); Send(stop_msg); } - - // TODO(wangxi): Maybe need a better to use thread. - for (auto& interceptor : interceptor_idx_to_interceptor_) { - interceptor.second->Join(); - } } Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; } diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index a298b1979a2f3..48a9ead1d180c 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -48,7 +48,11 @@ class Carrier final { Carrier() = default; Carrier(int64_t rank, const std::unordered_map& interceptor_id_to_rank) - : rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {} + : rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) { + thread_num_ = 1; + thread_pool_.SetThreadNum(thread_num_); + thread_pool_.Start(); + } ~Carrier(); void Init(int64_t rank, std::shared_ptr runtime_graph, framework::Scope* root_scope, framework::Scope* minibatch_scope, diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 282123a33123f..a178d9544ac82 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -24,10 +24,11 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node) : interceptor_id_(interceptor_id), node_(node) {} Interceptor::~Interceptor() { - std::lock_guard lock(mutex_); - PADDLE_ENFORCE_EQ(messages_.empty(), true, - platform::errors::PreconditionNotMet( - "Interceptor must destruct with messages empty")); + // FIXME(wangxi): throw in stop function + // std::lock_guard lock(mutex_); + // PADDLE_ENFORCE_EQ(messages_.empty(), true, + // platform::errors::PreconditionNotMet( + // "Interceptor must destruct with messages empty")); } void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; } @@ -51,7 +52,7 @@ void Interceptor::LoopOnce() { for (auto& msg : tmp_messages) { const MessageType message_type = msg.message_type(); VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message" - << " from interceptor " << interceptor_message.src_id() + << " from interceptor " << msg.src_id() << " with message: " << message_type << "."; Handle(msg); @@ -75,8 +76,8 @@ void Interceptor::StopCarrier() { void Interceptor::EnqueueRemoteInterceptorMessage( const InterceptorMessage& message) { // Called by Carrier, enqueue an InterceptorMessage to remote mailbox - VLOG(3) << "Enqueue message: " << interceptor_message.message_type() - << " into " << interceptor_id_ << "'s remote mailbox."; + VLOG(3) << "Enqueue message: " << message.message_type() << " into " + << interceptor_id_ << "'s remote mailbox."; bool empty = false; { From 8c9cc59b59e437d4652c7b159afb64bb1f53bd9c Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 24 Dec 2021 16:30:40 +0800 Subject: [PATCH 4/7] refine interceptor --- .../distributed/fleet_executor/carrier.cc | 86 ++++--------------- .../distributed/fleet_executor/carrier.h | 16 +--- .../fleet_executor/compute_interceptor.cc | 2 +- .../distributed/fleet_executor/interceptor.cc | 15 ++-- .../distributed/fleet_executor/interceptor.h | 2 - 5 files changed, 27 insertions(+), 94 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 0faad3e6a026f..170e47fec926f 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -56,16 +56,18 @@ void Carrier::Release() { // NOTE(wangxi): must join before `Derived Interceptor` destruct, // otherwise Derived object will be destructed before thread complete. - for (int64_t id : source_interceptor_ids_) { - VLOG(3) << "Carrier Release is sending stop to source interceptor " << id - << "."; - InterceptorMessage stop_msg; - // source node STOP is send by carrier, so set src_id=-1 - stop_msg.set_src_id(-1); - stop_msg.set_dst_id(id); - stop_msg.set_message_type(STOP); - Send(stop_msg); - } + // FIXME(wangxi): should we need stop? + // for (int64_t id : source_interceptor_ids_) { + // VLOG(3) << "Carrier Release is sending stop to source interceptor " << + // id + // << "."; + // InterceptorMessage stop_msg; + // // source node STOP is send by carrier, so set src_id=-1 + // stop_msg.set_src_id(-1); + // stop_msg.set_dst_id(id); + // stop_msg.set_message_type(STOP); + // Send(stop_msg); + // } } Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; } @@ -77,17 +79,6 @@ bool Carrier::EnqueueInterceptorMessage( << interceptor_message.src_id() << " to rank " << interceptor_message.dst_id(); } else { - { - std::unique_lock lock_creating(creating_flag_mutex_); - if (creating_interceptors_) { - std::unique_lock lock_message(tmp_message_mutex_); - // Cannot handle the message to interceptor since interceptors - // are still under creating. Will enqueue into a tmp stack. - VLOG(3) << "Receiving message while creating interceptors."; - message_tmp_.emplace_back(interceptor_message); - return true; - } - } int64_t dst_id = interceptor_message.dst_id(); Interceptor* dst_interceptor = GetInterceptor(dst_id); dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message); @@ -110,6 +101,11 @@ void Carrier::Wait() { cond_var_.wait(lock); } +void Carrier::WakeUp() { + // probably double notify, but ok for ut + cond_var_.notify_all(); +} + void Carrier::Start() { PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true, platform::errors::PreconditionNotMet( @@ -127,12 +123,11 @@ void Carrier::Start() { start_msg.set_message_type(DATA_IS_READY); Send(start_msg); } + // TODO(wangxi): async step Wait(); dev_ctx_->Wait(); } -std::condition_variable& Carrier::GetCondVar() { return cond_var_; } - bool Carrier::IsInit() const { return is_init_; } int64_t Carrier::GetRank(int64_t interceptor_id) const { @@ -197,45 +192,6 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, return ptr; } -void Carrier::SetCreatingFlag(bool flag) { - // set the creating flag - creating_flag_mutex_.lock(); - VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_ - << " to " << flag << "."; - creating_interceptors_ = flag; - creating_flag_mutex_.unlock(); - if (!flag) { - for (auto& pair : interceptor_idx_to_interceptor_) { - // update the source interceptor id - if (std::find(source_interceptor_ids_.begin(), - source_interceptor_ids_.end(), - pair.first) == source_interceptor_ids_.end()) { - auto task = pair.second->GetTaskNode(); - if (task != nullptr && task->upstream().empty()) { - source_interceptor_ids_.emplace_back(pair.first); - } - } - } - // finish create interceptors outside, handle tmp messsages - HandleTmpMessages(); - } -} - -void Carrier::HandleTmpMessages() { - // NOTE: It's ok lock on the tmp_message_mutex_ here, when enter this - // `HandleTmpMessages` method, the creating_interceptors_ flag - // must be false, therefore, there won't have conflict with the - // lock on the tmp_message_mutex_ inside `EnqueueInterceptorMessage` - // on the same thread. - std::unique_lock lock(tmp_message_mutex_); - VLOG(3) << "Carrier has received " << message_tmp_.size() - << " messages during creating interceptors."; - for (const auto& msg : message_tmp_) { - EnqueueInterceptorMessage(msg); - } - message_tmp_.clear(); -} - static std::shared_ptr GetGC( const platform::Place& place) { int64_t max_memory_size = framework::GetEagerDeletionThreshold(); @@ -293,12 +249,6 @@ void Carrier::CreateInterceptors() { source_interceptor_ids_.emplace_back(interceptor_id); } } - // The carrier will be always waiting for outside initializer - // since there is no interceptor has been created during auto init - creating_flag_mutex_.lock(); - creating_interceptors_ = false; - creating_flag_mutex_.unlock(); - HandleTmpMessages(); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 48a9ead1d180c..c9c8aa618d061 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -61,6 +61,7 @@ class Carrier final { void Release(); void Wait(); + void WakeUp(); // Enqueue a message to corresponding interceptor id bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); @@ -72,32 +73,23 @@ class Carrier final { Interceptor* SetInterceptor(int64_t interceptor_id, std::unique_ptr); - void SetCreatingFlag(bool flag); + void SetCreatingFlag(bool flag) {} void SetMsgBus(const std::shared_ptr& msg_bus) { msg_bus_ = msg_bus; } - std::condition_variable& GetCondVar(); - void Start(); bool IsInit() const; bool Send(const InterceptorMessage& msg); - // NOTE: This mutex will be used in interceptor's RunOps function. - // This mutex is used for avoiding forward ops and backward ops run - // simultaneously, which will lead to a random hang for some sync ops. - std::mutex run; - private: DISABLE_COPY_AND_ASSIGN(Carrier); // create each Interceptor void CreateInterceptors(); - void HandleTmpMessages(); - int64_t GetRank(int64_t interceptor_id) const; // interceptor logic id to actually interceptor @@ -106,10 +98,6 @@ class Carrier final { std::vector source_interceptor_ids_; - std::vector message_tmp_{}; - std::mutex tmp_message_mutex_; - bool creating_interceptors_{true}; - std::mutex creating_flag_mutex_; bool is_init_{false}; std::mutex running_mutex_; diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 1f0d3408a3da8..60f39b8a36087 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { } void ComputeInterceptor::RunOps() { - std::unique_lock lock(carrier_->run); VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the " << step_ + 1 << " time."; for (auto op : node_->ops()) { @@ -198,6 +197,7 @@ void ComputeInterceptor::Run() { if (is_last_ && (step_ % node_->max_run_times() == 0)) { VLOG(3) << "Interceptor " << GetInterceptorId() << " is stopping carrier."; + // FIXME(wangxi): with multi sink StopCarrier(); } } diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index a178d9544ac82..e526a739be4a3 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -56,21 +56,18 @@ void Interceptor::LoopOnce() { << " with message: " << message_type << "."; Handle(msg); - // TODO(wangxi): unregister - if (stop_) { - // break the pooling thread - VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting."; - break; - } + // if (stop_) { + // // break the pooling thread + // VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting."; + // break; + // } } } void Interceptor::StopCarrier() { PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet( "Carrier is not registered.")); - std::condition_variable& cond_var = carrier_->GetCondVar(); - // probably double notify, but ok for ut - cond_var.notify_all(); + carrier_->WakeUp(); } void Interceptor::EnqueueRemoteInterceptorMessage( diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 6309b96f304ac..34c78e8d266a2 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -51,8 +51,6 @@ class Interceptor { virtual ~Interceptor(); - void Join(); - // register interceptor handle void RegisterMsgHandle(MsgHandle handle); From b55b99a6bcac900c28da9ac403d0d3a0e08a6550 Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 24 Dec 2021 18:56:04 +0800 Subject: [PATCH 5/7] add barrier --- .../distributed/fleet_executor/carrier.cc | 3 + .../distributed/fleet_executor/carrier.h | 2 + .../fleet_executor/fleet_executor.cc | 3 + .../distributed/fleet_executor/message_bus.cc | 57 ++++++++++++++----- .../distributed/fleet_executor/message_bus.h | 11 +++- .../interceptor_ping_pong_with_brpc_test.cc | 4 ++ 6 files changed, 65 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 170e47fec926f..ca886523dad37 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -78,6 +78,7 @@ bool Carrier::EnqueueInterceptorMessage( VLOG(3) << "Receiving control message from rank " << interceptor_message.src_id() << " to rank " << interceptor_message.dst_id(); + msg_bus_->IncreaseBarrierCount(); } else { int64_t dst_id = interceptor_message.dst_id(); Interceptor* dst_interceptor = GetInterceptor(dst_id); @@ -86,6 +87,8 @@ bool Carrier::EnqueueInterceptorMessage( return true; } +void Carrier::Barrier() { msg_bus_->Barrier(); } + Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(), diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index c9c8aa618d061..81643a74550b0 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -84,6 +84,8 @@ class Carrier final { bool Send(const InterceptorMessage& msg); + void Barrier(); + private: DISABLE_COPY_AND_ASSIGN(Carrier); diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 697c4aaaf3aaa..84add613f2c58 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -89,6 +89,9 @@ void FleetExecutor::Init( CreateCarrier(); InitCarrier(); InitMessageBus(); + + // refine this? wait all carrier ready + GetCarrier()->Barrier(); } void FleetExecutor::InitCarrier() { diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index ac7b08c4b2868..dd95a90ad1ba4 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank, return true; } -void MessageBus::TestConnection() { - InterceptorMessage ctrl_msg; - ctrl_msg.set_ctrl_message(true); - ctrl_msg.set_src_id(rank_); - for (const auto& dst_rank_pair : rank_to_addr_) { - int64_t dst_rank = dst_rank_pair.first; - if (dst_rank != rank_) { - ctrl_msg.set_dst_id(dst_rank); - VLOG(3) << "Send control message bus from rank " << rank_ << " to rank " - << dst_rank; - while (!Send(dst_rank, ctrl_msg)) { +void MessageBus::IncreaseBarrierCount() { + VLOG(3) << "IncreaseBarrierCount"; + { + std::unique_lock lock(mutex_); + ++count_; + cv_.notify_one(); + } + VLOG(3) << "End IncreaseBarrierCount"; +} + +void MessageBus::Barrier() { + // gather to root + if (rank_ != 0) { + InterceptorMessage ctrl_msg; + ctrl_msg.set_ctrl_message(true); + ctrl_msg.set_src_id(rank_); + ctrl_msg.set_dst_id(0); + VLOG(3) << "Barrier Gather ctrl message from " << rank_ << " to 0"; + while (!Send(0, ctrl_msg)) { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + } else { + VLOG(3) << "Barrier 0 wait others rank ready"; + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { + return count_ == static_cast(rank_to_addr_.size() - 1); + }); + count_ = 0; + } + + // scatter from root + if (rank_ == 0) { + for (int i = 1; i < static_cast(rank_to_addr_.size()); ++i) { + InterceptorMessage ctrl_msg; + ctrl_msg.set_ctrl_message(true); + ctrl_msg.set_src_id(0); + ctrl_msg.set_dst_id(i); + VLOG(3) << "Barrier Scatter ctrl message from 0 to " << i; + while (!Send(i, ctrl_msg)) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } - VLOG(3) << "Message bus has connected to rank: " << dst_rank << "."; } + } else { + VLOG(3) << "Barrier " << rank_ << " wait others rank ready"; + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return count_ == 1; }); + count_ = 0; } } @@ -151,7 +183,6 @@ void MessageBus::ListenPort() { interval += 500; } LOG(INFO) << "Message bus's listen port thread starts successful."; - TestConnection(); #else LOG(WARNING) << "Fleet executor's ListenPort() is a fake function when Paddle is " diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index d4a2af54e6cd4..c8685a73900d5 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -51,14 +52,15 @@ class MessageBus final { // called by Interceptor, send InterceptorMessage to dst bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message); + void IncreaseBarrierCount(); + void Barrier(); + private: DISABLE_COPY_AND_ASSIGN(MessageBus); // function keep listen the port and handle the message void ListenPort(); - void TestConnection(); - const std::string& GetAddr(int64_t rank) const; #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ @@ -84,6 +86,11 @@ class MessageBus final { // brpc server brpc::Server server_; #endif + + // for barrier + std::mutex mutex_; + std::condition_variable cv_; + int count_{0}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index a577b30fa8c0b..6515a53ca470e 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -119,6 +119,8 @@ TEST(InterceptorTest, PingPong) { carrier->SetMsgBus(msg_bus); Interceptor* a = carrier->SetInterceptor( 0, InterceptorFactory::Create("PingPong", 0, nullptr)); + carrier->Barrier(); + InterceptorMessage msg; a->Send(1, msg); carrier->Wait(); @@ -131,6 +133,8 @@ TEST(InterceptorTest, PingPong) { carrier->SetMsgBus(msg_bus); carrier->SetInterceptor( 1, InterceptorFactory::Create("PingPong", 1, nullptr)); + carrier->Barrier(); + carrier->Wait(); int status; int ret = waitpid(pid, &status, 0); From 98dffe3e5905b5d46bb5a103c1beeb8377f573ae Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 27 Dec 2021 03:37:44 +0000 Subject: [PATCH 6/7] fix abort --- paddle/fluid/distributed/fleet_executor/fleet_executor.cc | 2 ++ .../test/interceptor_ping_pong_with_brpc_test.cc | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 84add613f2c58..a5badcb36eb3e 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -91,6 +91,8 @@ void FleetExecutor::Init( InitMessageBus(); // refine this? wait all carrier ready + // NOTE(wangxi): must add after Carrier::SetMsgBus, for we use + // MessageBus::IncreaseBarrierCount when receive barrier msg. GetCarrier()->Barrier(); } diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index 6515a53ca470e..262e5caa8c82e 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -115,8 +115,9 @@ TEST(InterceptorTest, PingPong) { FleetExecutor::CreateCarrier(0, interceptor_id_to_rank); carrier->SetCreatingFlag(false); auto msg_bus = std::make_shared(); - msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0); carrier->SetMsgBus(msg_bus); + // NOTE: need Init msg_bus after carrier SetMsgBus + msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0); Interceptor* a = carrier->SetInterceptor( 0, InterceptorFactory::Create("PingPong", 0, nullptr)); carrier->Barrier(); @@ -129,8 +130,8 @@ TEST(InterceptorTest, PingPong) { FleetExecutor::CreateCarrier(1, interceptor_id_to_rank); carrier->SetCreatingFlag(false); auto msg_bus = std::make_shared(); - msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1); carrier->SetMsgBus(msg_bus); + msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1); carrier->SetInterceptor( 1, InterceptorFactory::Create("PingPong", 1, nullptr)); carrier->Barrier(); From febebd891085cd14b027b0ccc641ef930200f0d9 Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 27 Dec 2021 11:50:40 +0800 Subject: [PATCH 7/7] refine --- .../distributed/fleet_executor/carrier.cc | 19 ++----------------- .../fleet_executor/compute_interceptor.cc | 2 +- .../distributed/fleet_executor/interceptor.cc | 5 ----- .../distributed/fleet_executor/interceptor.h | 1 - 4 files changed, 3 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index ca886523dad37..ea35b36aa4a75 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -52,23 +52,7 @@ void Carrier::Init(int64_t rank, std::shared_ptr runtime_graph, is_init_ = true; } -void Carrier::Release() { - // NOTE(wangxi): must join before `Derived Interceptor` destruct, - // otherwise Derived object will be destructed before thread complete. - - // FIXME(wangxi): should we need stop? - // for (int64_t id : source_interceptor_ids_) { - // VLOG(3) << "Carrier Release is sending stop to source interceptor " << - // id - // << "."; - // InterceptorMessage stop_msg; - // // source node STOP is send by carrier, so set src_id=-1 - // stop_msg.set_src_id(-1); - // stop_msg.set_dst_id(id); - // stop_msg.set_message_type(STOP); - // Send(stop_msg); - // } -} +void Carrier::Release() {} Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; } @@ -78,6 +62,7 @@ bool Carrier::EnqueueInterceptorMessage( VLOG(3) << "Receiving control message from rank " << interceptor_message.src_id() << " to rank " << interceptor_message.dst_id(); + // for barrier msg_bus_->IncreaseBarrierCount(); } else { int64_t dst_id = interceptor_message.dst_id(); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 60f39b8a36087..d934ab1948e7e 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -197,7 +197,7 @@ void ComputeInterceptor::Run() { if (is_last_ && (step_ % node_->max_run_times() == 0)) { VLOG(3) << "Interceptor " << GetInterceptorId() << " is stopping carrier."; - // FIXME(wangxi): with multi sink + // FIXME(wangxi): with multi sink interceptor StopCarrier(); } } diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index e526a739be4a3..710ebda41244e 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -56,11 +56,6 @@ void Interceptor::LoopOnce() { << " with message: " << message_type << "."; Handle(msg); - // if (stop_) { - // // break the pooling thread - // VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting."; - // break; - // } } } diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 34c78e8d266a2..cb7ff2da89a9d 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -111,7 +111,6 @@ class Interceptor { std::mutex mutex_; std::deque messages_; - // std::deque local_messages_; int64_t already_run_times_{0}; int64_t used_slot_nums_{0};