From bb7ff2694cf663b0f2e69715d92098d47b98ae57 Mon Sep 17 00:00:00 2001
From: Baibaifan
Date: Tue, 1 Mar 2022 11:57:57 +0000
Subject: [PATCH] add_new_comm_primitive

---
 .../distributed/collective/ProcessGroup.h     |  20 ++-
 .../collective/ProcessGroupNCCL.cc            | 156 ++++++++++++++++++
 .../distributed/collective/ProcessGroupNCCL.h |  17 ++
 paddle/fluid/distributed/collective/Types.h   |   4 +
 paddle/fluid/pybind/distributed_py.cc         |  33 ++++
 .../tests/unittests/process_group_nccl.py     |  30 ++++
 6 files changed, 259 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h
index dde8622d9007e1..e4f27205202424 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.h
+++ b/paddle/fluid/distributed/collective/ProcessGroup.h
@@ -96,7 +96,25 @@ class ProcessGroup {
       std::vector<Tensor>& /* tensors */,
       const BroadcastOptions& = BroadcastOptions()) {
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support allreduce", GetBackendName()));
+        "ProcessGroup%s does not support broadcast", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Barrier(
+      const BarrierOptions& = BarrierOptions()) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support barrier", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Send(
+      std::vector<Tensor>& tensors /* tensors */, int dst_rank) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support send", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Recv(
+      std::vector<Tensor>& tensors /* tensors */, int src_rank) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support receive", GetBackendName()));
   }
 
  protected:
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index fe2325423b460d..5d96e730aa4b1a 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -14,6 +14,9 @@
 
 #include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/api/include/api.h"
+#include "paddle/phi/common/place.h"
 
 DECLARE_bool(nccl_blocking_wait);
 DECLARE_bool(use_stream_safe_cuda_allocator);
@@ -139,6 +142,14 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
       std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
     }
   }
+
+  if (!barrierTensors_.empty()) {
+    // If we use the work to do barrier, we should block cpu
+    for (auto& place : places_) {
+      platform::CUDADeviceGuard gpuGuard(place);
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
+    }
+  }
   return true;
 }
 
@@ -193,6 +204,10 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
   nccl_ids.resize(1);
   auto& nccl_id = nccl_ids.front();
 
+  for (auto& place : places) {
+    used_place_ids_.insert(place.GetDeviceId());
+  }
+
   if (rank_ == 0) {
     PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
   }
@@ -274,6 +289,54 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
   return task;
 }
 
+template <typename Fn>
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
+    std::vector<Tensor>& tensors, Fn fn, int dst_rank, CommType op_type) {
+  const auto places = GetPlaceList(tensors);
+  const auto key = GetKeyFromPlaces(places);
+
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
+      CreateNCCLManagerCache(key, places);
places); + } + } + + auto& nccl_comms = places_to_ncclcomm_[key]; + + SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + + auto task = CreateTask(places, rank_, op_type, tensors); + + // construct uninitialize guard for device + platform::CUDADeviceGuard cuda_guard; + + if (FLAGS_use_stream_safe_cuda_allocator) { + for (size_t i = 0; i < tensors.size(); ++i) { + cuda_guard.SetDevice(places[i]); + auto dense_tensor = + std::dynamic_pointer_cast(tensors[i].impl()); + memory::RecordStream(dense_tensor->Holder(), + places_to_ctx_[key][i]->stream()); + } + } + + { + platform::NCCLGroupGuard nccl_guard; + for (size_t i = 0; i < tensors.size(); ++i) { + cuda_guard.SetDevice(places[i]); + const auto& nccl_stream = places_to_ctx_[key][i]->stream(); + fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); + } + } + + for (size_t i = 0; i < tensors.size(); ++i) { + cuda_guard.SetDevice(places[i]); + task->control_events_[i].Record(*places_to_ctx_[key][i]); + } + return task; +} + std::shared_ptr ProcessGroupNCCL::AllReduce( std::vector& tensors, const AllreduceOptions& opts) { PADDLE_ENFORCE_EQ( @@ -317,5 +380,98 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( CommType::BROADCAST); } +std::shared_ptr ProcessGroupNCCL::Barrier( + const BarrierOptions& opts) { + std::vector places; + + if (!opts.place_ids.empty()) { + for (auto place_id : opts.place_ids) { + places.emplace_back(place_id); + } + } else if (!used_place_ids_.empty()) { + for (auto place_id : used_place_ids_) { + places.emplace_back(place_id); + } + } else { + auto numGPUs = GetSize(); + int place_id = static_cast(rank_ % numGPUs); + places.emplace_back(place_id); + } + + std::vector barrierTensors; + barrierTensors.reserve(places.size()); + + platform::CUDADeviceGuard gpuGuard; + for (auto& place : places) { + gpuGuard.SetDeviceIndex(place.GetDeviceId()); + auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::Backend::GPU); + barrierTensors.push_back(dt); + } + auto task = ProcessGroupNCCL::AllReduce(barrierTensors); + auto nccl_task = dynamic_cast(task.get()); + nccl_task->barrierTensors_ = std::move(barrierTensors); + return task; +} + +void CheckTensorsInDifferentDevices(const std::vector& tensors, + const size_t num_devices) { + PADDLE_ENFORCE_EQ( + tensors.size() == 0, false, + platform::errors::InvalidArgument("Tensor list must be nonempty.")); + PADDLE_ENFORCE_LE( + tensors.size(), num_devices, + platform::errors::InvalidArgument( + "Tensor list mustn't be larger than the number of available GPUs.")); + + std::set used_devices; + + for (const auto& t : tensors) { + PADDLE_ENFORCE_EQ(t.is_cuda() && t.is_dense_tensor(), true, + platform::errors::InvalidArgument( + "Tensors must be CUDA and dense tensor.")); + + const auto inserted = used_devices.insert(t.inner_place()).second; + PADDLE_ENFORCE_EQ(inserted, true, + platform::errors::InvalidArgument( + "Tensors must be on distinct GPU devices.")); + } +} + +std::shared_ptr ProcessGroupNCCL::Send( + std::vector& tensors, int dst_rank) { + CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); + + auto task = PointToPoint( + tensors, + [&](Tensor& input, ncclComm_t comm, const gpuStream_t& stream, + int dst_rank) { + auto input_tensor = + std::dynamic_pointer_cast(input.impl()); + return platform::dynload::ncclSend( + input_tensor->data(), input_tensor->numel(), + platform::ToNCCLDataType(input.type()), dst_rank, comm, stream); + }, + dst_rank, CommType::SEND); + return task; +} + +std::shared_ptr ProcessGroupNCCL::Recv( + std::vector& 
+  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
+
+  auto task = PointToPoint(
+      tensors,
+      [&](Tensor& output, ncclComm_t comm, const gpuStream_t& stream,
+          int src_rank) {
+        auto output_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
+        return platform::dynload::ncclRecv(
+            output_tensor->data(), output_tensor->numel(),
+            platform::ToNCCLDataType(output.type()), src_rank, comm, stream);
+      },
+      src_rank, CommType::RECV);
+  return task;
+}
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
index 9f06566d1c8638..cfeb6467f0dbf2 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -65,6 +65,7 @@ class ProcessGroupNCCL : public ProcessGroup {
     virtual ~NCCLTask();
 
     std::vector<EventManager> control_events_;
+    std::vector<Tensor> barrierTensors_;
 
    protected:
     std::vector<Place> places_;
@@ -88,6 +89,15 @@ class ProcessGroupNCCL : public ProcessGroup {
       std::vector<Tensor>& tensors,
       const BroadcastOptions& = BroadcastOptions()) override;
 
+  std::shared_ptr<ProcessGroup::Task> Barrier(
+      const BarrierOptions& = BarrierOptions()) override;
+
+  std::shared_ptr<ProcessGroup::Task> Send(std::vector<Tensor>& tensors,
+                                           int dst_rank) override;
+
+  std::shared_ptr<ProcessGroup::Task> Recv(std::vector<Tensor>& tensors,
+                                           int src_rank) override;
+
  protected:
   virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
       std::vector<Place> places, int rank, CommType opType,
@@ -106,6 +116,8 @@ class ProcessGroupNCCL : public ProcessGroup {
                      std::vector<std::unique_ptr<CUDADeviceContext>>>
       places_to_ctx_;
 
+  std::set<int> used_place_ids_;
+
  private:
   void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root,  // NOLINT
                    int server_fd);
@@ -118,6 +130,11 @@ class ProcessGroupNCCL : public ProcessGroup {
       std::vector<Tensor>& outputs,  // NOLINT
       Fn fn, CommType op_type);
 
+  template <typename Fn>
+  std::shared_ptr<ProcessGroup::Task> PointToPoint(
+      std::vector<Tensor>& tensors,  // NOLINT
+      Fn fn, int dst_rank, CommType op_type);
+
   void CreateNCCLManagerCache(const std::string& places_key,
                               const std::vector<Place>& places);
 };
diff --git a/paddle/fluid/distributed/collective/Types.h b/paddle/fluid/distributed/collective/Types.h
index 654d06686957bd..699222ac452dbc 100644
--- a/paddle/fluid/distributed/collective/Types.h
+++ b/paddle/fluid/distributed/collective/Types.h
@@ -32,5 +32,9 @@ struct BroadcastOptions {
   int source_root = 0;
 };
 
+struct BarrierOptions {
+  std::vector<int> place_ids;
+};
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index e057fb53ccecc7..da045660e55f10 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -59,6 +59,10 @@ void BindDistributed(py::module *m) {
       .def_readwrite("source_root",
                      &distributed::BroadcastOptions::source_root);
 
+  py::class_<distributed::BarrierOptions>(*m, "BarrierOptions")
+      .def(py::init<>())
+      .def_readwrite("place_ids", &distributed::BarrierOptions::place_ids);
+
   auto ProcessGroup =
       py::class_<distributed::ProcessGroup,
                  std::shared_ptr<distributed::ProcessGroup>>(*m, "ProcessGroup")
@@ -87,6 +91,35 @@ void BindDistributed(py::module *m) {
                  return self.Broadcast(tensors, opts);
                },
               py::arg("tensor"), py::arg("source_rank"),
+               py::call_guard<py::gil_scoped_release>())
+
+          .def("barrier",
+               [](distributed::ProcessGroup &self, std::vector<int> place_ids) {
+                 distributed::BarrierOptions opts;
+                 opts.place_ids = place_ids;
+                 return self.Barrier(opts);
+               },
+               py::arg("place_ids") = std::vector<int>{},
+               py::call_guard<py::gil_scoped_release>())
+
+          .def("send",
+               [](distributed::ProcessGroup &self, py::handle py_tensor,
+                  int dst) {
+                 auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
+                 std::vector<Tensor> tensors = {tensor};
+                 return self.Send(tensors, dst);
+               },
+               py::arg("tensor"), py::arg("dst"),
+               py::call_guard<py::gil_scoped_release>())
+
+          .def("recv",
+               [](distributed::ProcessGroup &self, py::handle py_tensor,
+                  int src) {
+                 auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
+                 std::vector<Tensor> tensors = {tensor};
+                 return self.Recv(tensors, src);
+               },
+               py::arg("tensor"), py::arg("src"),
               py::call_guard<py::gil_scoped_release>());
 
 #if defined(PADDLE_WITH_NCCL)
diff --git a/python/paddle/fluid/tests/unittests/process_group_nccl.py b/python/paddle/fluid/tests/unittests/process_group_nccl.py
index d999aad63ecf41..8ec5d13c569fe8 100644
--- a/python/paddle/fluid/tests/unittests/process_group_nccl.py
+++ b/python/paddle/fluid/tests/unittests/process_group_nccl.py
@@ -132,6 +132,36 @@ def test_create_process_group_nccl(self):
 
             print("test broadcast api ok")
 
+            # test barrier
+            # rank 0
+            if pg.rank() == 0:
+                task = pg.barrier()
+                task.wait()
+            # rank 1
+            else:
+                task = pg.barrier()
+                task.wait()
+
+            print("test barrier api ok\n")
+
+            # test send/recv
+            # rank 0
+            x = np.random.random(self.shape).astype(self.dtype)
+            tensor_x = paddle.to_tensor(x)
+            if pg.rank() == 0:
+                task = pg.send(tensor_x, dst=1)
+                task.wait()
+                paddle.device.cuda.synchronize()
+            # rank 1
+            else:
+                y = np.random.random(self.shape).astype(self.dtype)
+                tensor_y = paddle.to_tensor(y)
+                task = pg.recv(tensor_y, src=0)
+                task.wait()
+                paddle.device.cuda.synchronize()
+                assert np.array_equal(tensor_x, tensor_y)
+            print("test send/recv api ok\n")
+
 
 class TestProcessGroupFp16(TestProcessGroupFp32):
     def setUp(self):