diff --git a/dbms/src/Common/GRPCQueue.h b/dbms/src/Common/GRPCQueue.h index 333c94a5381..497dab3a4ad 100644 --- a/dbms/src/Common/GRPCQueue.h +++ b/dbms/src/Common/GRPCQueue.h @@ -139,6 +139,9 @@ class GRPCSendQueue bool isWritable() const { return send_queue.isWritable(); } + void registerPipeReadTask(TaskPtr && task) { send_queue.registerPipeReadTask(std::move(task)); } + void registerPipeWriteTask(TaskPtr && task) { send_queue.registerPipeWriteTask(std::move(task)); } + private: friend class tests::TestGRPCSendQueue; @@ -297,6 +300,9 @@ class GRPCRecvQueue bool isWritable() const { return recv_queue.isWritable(); } + void registerPipeReadTask(TaskPtr && task) { recv_queue.registerPipeReadTask(std::move(task)); } + void registerPipeWriteTask(TaskPtr && task) { recv_queue.registerPipeWriteTask(std::move(task)); } + private: friend class tests::TestGRPCRecvQueue; diff --git a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h index e4be0eb59b9..b9a756e365d 100644 --- a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include @@ -30,12 +31,13 @@ class DAGResponseWriter virtual void prepare(const Block &){}; virtual void write(const Block & block) = 0; - // For async writer, `isWritable` need to be called before calling `write`. + // For async writer, `waitForWritable` need to be called before calling `write`. // ``` - // while (!isWritable()) {} + // auto res = waitForWritable(); + // switch (res) case... // write(block); // ``` - virtual bool isWritable() const { throw Exception("Unsupport"); } + virtual WaitResult waitForWritable() const { throw Exception("Unsupport"); } /// flush cached blocks for batch writer virtual void flush() = 0; diff --git a/dbms/src/Flash/Coprocessor/StreamWriter.h b/dbms/src/Flash/Coprocessor/StreamWriter.h index 4eb62b27144..41383ee49c2 100644 --- a/dbms/src/Flash/Coprocessor/StreamWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamWriter.h @@ -16,6 +16,7 @@ #include #include +#include #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wdeprecated-declarations" @@ -57,7 +58,7 @@ struct CopStreamWriter if (!writer->Write(resp)) throw Exception("Failed to write resp"); } - bool isWritable() const { throw Exception("Unsupport async write"); } + static WaitResult waitForWritable() { throw Exception("Unsupport async write"); } }; struct BatchCopStreamWriter @@ -81,7 +82,7 @@ struct BatchCopStreamWriter if (!writer->Write(resp)) throw Exception("Failed to write resp"); } - bool isWritable() const { throw Exception("Unsupport async write"); } + static WaitResult waitForWritable() { throw Exception("Unsupport async write"); } }; using CopStreamWriterPtr = std::shared_ptr; diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp index b1169ca46bc..a6f39cb25dc 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp @@ -68,9 +68,9 @@ void StreamingDAGResponseWriter::flush() } template -bool StreamingDAGResponseWriter::isWritable() const +WaitResult StreamingDAGResponseWriter::waitForWritable() const { - return writer->isWritable(); + return writer->waitForWritable(); } template diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h index e6f1e61b59e..61ca9a71517 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h @@ -37,7 +37,7 @@ class StreamingDAGResponseWriter : public DAGResponseWriter Int64 batch_send_min_limit_, DAGContext & dag_context_); void write(const Block & block) override; - bool isWritable() const override; + WaitResult waitForWritable() const override; void flush() override; private: diff --git a/dbms/src/Flash/Coprocessor/WaitResult.h b/dbms/src/Flash/Coprocessor/WaitResult.h new file mode 100644 index 00000000000..ed45e89111e --- /dev/null +++ b/dbms/src/Flash/Coprocessor/WaitResult.h @@ -0,0 +1,25 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace DB +{ +enum class WaitResult +{ + Ready, + WaitForPolling, + WaitForNotify +}; +} // namespace DB diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp index 6ce2821d659..a3351f294ca 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp @@ -93,7 +93,7 @@ struct MockStreamWriter {} void write(tipb::SelectResponse & response) { checker(response); } - static bool isWritable() { throw Exception("Unsupport async write"); } + static WaitResult waitForWritable() { throw Exception("Unsupport async write"); } private: MockStreamWriterChecker checker; diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp index 2d6ddce2881..dafd6f8ce93 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp @@ -147,7 +147,7 @@ struct MockWriter queue->push(tracked_packet); } static uint16_t getPartitionNum() { return 1; } - static bool isWritable() { throw Exception("Unsupport async write"); } + static WaitResult waitForWritable() { throw Exception("Unsupport async write"); } std::vector result_field_types; diff --git a/dbms/src/Flash/Executor/PipelineExecutor.cpp b/dbms/src/Flash/Executor/PipelineExecutor.cpp index a5e1ac7fa41..f80cb1200e0 100644 --- a/dbms/src/Flash/Executor/PipelineExecutor.cpp +++ b/dbms/src/Flash/Executor/PipelineExecutor.cpp @@ -35,6 +35,7 @@ PipelineExecutor::PipelineExecutor( /*query_id=*/context.getDAGContext()->isMPPTask() ? context.getDAGContext()->getMPPTaskId().toString() : "", req_id, memory_tracker_, + context.getDAGContext(), auto_spill_trigger, register_operator_spill_context, context.getDAGContext()->getResourceGroupName()) diff --git a/dbms/src/Flash/Executor/PipelineExecutorContext.cpp b/dbms/src/Flash/Executor/PipelineExecutorContext.cpp index 74b57015598..250e824d8d4 100644 --- a/dbms/src/Flash/Executor/PipelineExecutorContext.cpp +++ b/dbms/src/Flash/Executor/PipelineExecutorContext.cpp @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include +#include +#include #include #include @@ -52,6 +55,24 @@ String PipelineExecutorContext::getExceptionMsg() } } +String PipelineExecutorContext::getTrimmedErrMsg() +{ + try + { + auto cur_exception_ptr = getExceptionPtr(); + if (!cur_exception_ptr) + return ""; + std::rethrow_exception(cur_exception_ptr); + } + catch (...) + { + auto err_msg = getCurrentExceptionMessage(true, true); + if (likely(!err_msg.empty())) + trimStackTrace(err_msg); + return err_msg; + } +} + void PipelineExecutorContext::onErrorOccurred(const String & err_msg) { DB::Exception e(err_msg); @@ -155,6 +176,12 @@ void PipelineExecutorContext::cancel() if (is_cancelled.compare_exchange_strong(origin_value, true, std::memory_order_release)) { cancelSharedQueues(); + if (likely(dag_context)) + { + // Cancel the tunnel_set here to prevent pipeline tasks waiting in the WAIT_FOR_NOTIFY state from never being notified. + if (dag_context->tunnel_set) + dag_context->tunnel_set->close(getTrimmedErrMsg(), false); + } cancelResultQueueIfNeed(); if likely (TaskScheduler::instance && !query_id.empty()) TaskScheduler::instance->cancel(query_id, resource_group_name); diff --git a/dbms/src/Flash/Executor/PipelineExecutorContext.h b/dbms/src/Flash/Executor/PipelineExecutorContext.h index a562d580518..724c5f7e641 100644 --- a/dbms/src/Flash/Executor/PipelineExecutorContext.h +++ b/dbms/src/Flash/Executor/PipelineExecutorContext.h @@ -34,6 +34,8 @@ using RegisterOperatorSpillContext = std::function; +class DAGContext; + class PipelineExecutorContext : private boost::noncopyable { public: @@ -51,12 +53,14 @@ class PipelineExecutorContext : private boost::noncopyable const String & query_id_, const String & req_id, const MemoryTrackerPtr & mem_tracker_, + DAGContext * dag_context_ = nullptr, AutoSpillTrigger * auto_spill_trigger_ = nullptr, const RegisterOperatorSpillContext & register_operator_spill_context_ = nullptr, const String & resource_group_name_ = "") : query_id(query_id_) , log(Logger::get(req_id)) , mem_tracker(mem_tracker_) + , dag_context(dag_context_) , auto_spill_trigger(auto_spill_trigger_) , register_operator_spill_context(register_operator_spill_context_) , resource_group_name(resource_group_name_) @@ -134,6 +138,8 @@ class PipelineExecutorContext : private boost::noncopyable private: bool setExceptionPtr(const std::exception_ptr & exception_ptr_); + String getTrimmedErrMsg(); + // Need to be called under lock. bool isWaitMode(); @@ -149,6 +155,8 @@ class PipelineExecutorContext : private boost::noncopyable MemoryTrackerPtr mem_tracker; + DAGContext * dag_context{nullptr}; + std::mutex mu; std::condition_variable cv; std::exception_ptr exception_ptr; diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp index 0adbac1d439..1b33e73019a 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp @@ -71,9 +71,9 @@ void BroadcastOrPassThroughWriter::flush() } template -bool BroadcastOrPassThroughWriter::isWritable() const +WaitResult BroadcastOrPassThroughWriter::waitForWritable() const { - return writer->isWritable(); + return writer->waitForWritable(); } template diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h index 9a28d13a461..be615c4c21c 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h @@ -37,7 +37,7 @@ class BroadcastOrPassThroughWriter : public DAGResponseWriter tipb::CompressionMode compression_mode_, tipb::ExchangeType exchange_type_); void write(const Block & block) override; - bool isWritable() const override; + WaitResult waitForWritable() const override; void flush() override; private: diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp index 729505656c7..37387f7e23c 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp @@ -97,9 +97,9 @@ void FineGrainedShuffleWriter::flush() } template -bool FineGrainedShuffleWriter::isWritable() const +WaitResult FineGrainedShuffleWriter::waitForWritable() const { - return writer->isWritable(); + return writer->waitForWritable(); } template diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h index 7a0afd4adfe..5bdb5a52e77 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h @@ -40,7 +40,7 @@ class FineGrainedShuffleWriter : public DAGResponseWriter tipb::CompressionMode compression_mode_); void prepare(const Block & sample_block) override; void write(const Block & block) override; - bool isWritable() const override; + WaitResult waitForWritable() const override; void flush() override; private: diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp index a86ea0f00ff..fd501015663 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp @@ -94,9 +94,9 @@ void HashPartitionWriter::flush() } template -bool HashPartitionWriter::isWritable() const +WaitResult HashPartitionWriter::waitForWritable() const { - return writer->isWritable(); + return writer->waitForWritable(); } template diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.h b/dbms/src/Flash/Mpp/HashPartitionWriter.h index b0d31b7b2e7..8e36d28234d 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.h +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.h @@ -38,7 +38,7 @@ class HashPartitionWriter : public DAGResponseWriter MPPDataPacketVersion data_codec_version_, tipb::CompressionMode compression_mode_); void write(const Block & block) override; - bool isWritable() const override; + WaitResult waitForWritable() const override; void flush() override; private: diff --git a/dbms/src/Flash/Mpp/LocalRequestHandler.h b/dbms/src/Flash/Mpp/LocalRequestHandler.h index 20f52acc5b1..a6422d79880 100644 --- a/dbms/src/Flash/Mpp/LocalRequestHandler.h +++ b/dbms/src/Flash/Mpp/LocalRequestHandler.h @@ -42,6 +42,9 @@ struct LocalRequestHandler bool isWritable() const { return msg_queue->isWritable(); } + void registerPipeReadTask(TaskPtr && task) const { msg_queue->registerPipeReadTask(std::move(task)); } + void registerPipeWriteTask(TaskPtr && task) const { msg_queue->registerPipeWriteTask(std::move(task)); } + void writeDone(bool meet_error, const String & local_err_msg) const { notify_write_done(meet_error, local_err_msg); diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 75c5c221ed4..2bac7bcd115 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -759,13 +759,13 @@ void MPPTask::abort(const String & message, AbortType abort_type) } else if (previous_status == RUNNING && switchStatus(RUNNING, next_task_status)) { - /// abort the components from top to bottom because if bottom components are aborted - /// first, the top components may see an error caused by the abort, which is not + /// abort mpptunnels first because if others components are aborted + /// first, the mpptunnels may see an error caused by the abort, which is not /// the original error setErrString(message); abortTunnels(message, false); - abortQueryExecutor(); abortReceivers(); + abortQueryExecutor(); scheduleThisTask(ScheduleState::FAILED); /// runImpl is running, leave remaining work to runImpl LOG_WARNING(log, "Finish abort task from running"); diff --git a/dbms/src/Flash/Mpp/MPPTunnel.cpp b/dbms/src/Flash/Mpp/MPPTunnel.cpp index c0fe94efb52..9a16322d241 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnel.cpp @@ -380,7 +380,7 @@ void MPPTunnel::waitUntilConnectedOrFinished(std::unique_lock & lk) throw Exception(fmt::format("MPPTunnel {} can not be connected because MPPTask is cancelled", tunnel_id)); } -bool MPPTunnel::isWritable() const +WaitResult MPPTunnel::waitForWritable() const { std::unique_lock lk(mu); switch (status) @@ -396,12 +396,17 @@ bool MPPTunnel::isWritable() const if (unlikely(timeout_stopwatch->elapsed() > timeout_nanoseconds)) throw Exception(fmt::format("{} is timeout", tunnel_id)); } - return false; + return WaitResult::WaitForPolling; } case TunnelStatus::Connected: case TunnelStatus::WaitingForSenderFinish: RUNTIME_CHECK_MSG(tunnel_sender != nullptr, "write to tunnel {} which is already closed.", tunnel_id); - return tunnel_sender->isWritable(); + if (!tunnel_sender->isWritable()) + { + setNotifyFuture(tunnel_sender); + return WaitResult::WaitForNotify; + } + return WaitResult::Ready; case TunnelStatus::Finished: RUNTIME_CHECK_MSG(tunnel_sender != nullptr, "write to tunnel {} which is already closed.", tunnel_id); throw Exception(fmt::format( diff --git a/dbms/src/Flash/Mpp/MPPTunnel.h b/dbms/src/Flash/Mpp/MPPTunnel.h index 58062b7a170..4c2421437e4 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.h +++ b/dbms/src/Flash/Mpp/MPPTunnel.h @@ -22,10 +22,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -92,10 +94,12 @@ enum class TunnelSenderMode /// TunnelSender is responsible for consuming data from Tunnel's internal send_queue and do the actual sending work /// After TunnelSend finished its work, either normally or abnormally, set ConsumerState to inform Tunnel -class TunnelSender : private boost::noncopyable +class TunnelSender + : private boost::noncopyable + , public NotifyFuture { public: - virtual ~TunnelSender() = default; + ~TunnelSender() override = default; TunnelSender( MemoryTrackerPtr & memory_tracker_, const LoggerPtr & log_, @@ -193,6 +197,8 @@ class SyncTunnelSender : public TunnelSender bool isWritable() const override { return send_queue.isWritable(); } + void registerTask(TaskPtr && task) override { send_queue.registerPipeWriteTask(std::move(task)); } + private: friend class tests::TestMPPTunnel; void sendJob(PacketWriter * writer); @@ -254,6 +260,8 @@ class AsyncTunnelSender : public TunnelSender void subDataSizeMetric(size_t size) { ::DB::MPPTunnelMetric::subDataSizeMetric(*data_size_in_queue, size); } + void registerTask(TaskPtr && task) override { queue.registerPipeWriteTask(std::move(task)); } + private: GRPCSendQueue queue; }; @@ -311,13 +319,24 @@ class LocalTunnelSenderV2 : public TunnelSender } } + void registerTask(TaskPtr && task) override + { + if constexpr (local_only) + local_request_handler.registerPipeWriteTask(std::move(task)); + else + { + std::lock_guard lock(mu); + local_request_handler.registerPipeWriteTask(std::move(task)); + } + } + private: friend class tests::TestMPPTunnel; template bool pushImpl(TrackedMppDataPacketPtr && data) { - if (unlikely(checkPacketErr(data))) + if (unlikely(is_done || checkPacketErr(data))) return false; // When ExchangeReceiver receives data from local and remote tiflash, number of local tunnel threads @@ -405,6 +424,8 @@ class LocalTunnelSenderV1 : public TunnelSender bool isWritable() const override { return send_queue.isWritable(); } + void registerTask(TaskPtr && task) override { send_queue.registerPipeWriteTask(std::move(task)); } + private: bool cancel_reason_sent = false; LooseBoundedMPMCQueue send_queue; @@ -472,13 +493,14 @@ class MPPTunnel : private boost::noncopyable void write(TrackedMppDataPacketPtr && data); // forceWrite write a single packet to the tunnel's send queue without blocking, - // and need to call isReadForWrite first. + // and need to call waitForWritable first. // ``` - // while (!isWritable()) {} + // auto res = waitForWritable(); + // switch (res) case... // forceWrite(std::move(data)); // ``` + WaitResult waitForWritable() const; void forceWrite(TrackedMppDataPacketPtr && data); - bool isWritable() const; // finish the writing, and wait until the sender finishes. void writeDone(); diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp index 6078ec8f2a4..47f7fe2299c 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp @@ -68,14 +68,14 @@ void MPPTunnelSetBase::forceWrite(tipb::SelectResponse & response, size_ } template -bool MPPTunnelSetBase::isWritable() const +WaitResult MPPTunnelSetBase::waitForWritable() const { for (const auto & tunnel : tunnels) { - if (!tunnel->isWritable()) - return false; + if (auto res = tunnel->waitForWritable(); res != WaitResult::Ready) + return res; } - return true; + return WaitResult::Ready; } template diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.h b/dbms/src/Flash/Mpp/MPPTunnelSet.h index ee2c65768f7..a57f57b3ac8 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.h @@ -60,7 +60,7 @@ class MPPTunnelSetBase : private boost::noncopyable const std::vector & getTunnels() const { return tunnels; } - bool isWritable() const; + WaitResult waitForWritable() const; bool isLocal(size_t index) const; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h index ce923a88410..1af6730f108 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h @@ -69,7 +69,7 @@ class MPPTunnelSetWriterBase : private boost::noncopyable uint16_t getPartitionNum() const { return mpp_tunnel_set->getPartitionNum(); } - virtual bool isWritable() const = 0; + virtual WaitResult waitForWritable() const = 0; protected: virtual void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) = 0; @@ -91,8 +91,8 @@ class SyncMPPTunnelSetWriter : public MPPTunnelSetWriterBase : MPPTunnelSetWriterBase(mpp_tunnel_set_, result_field_types_, req_id) {} - // For sync writer, `isWritable` will not be called, so an exception is thrown here. - bool isWritable() const override { throw Exception("Unsupport sync writer"); } + // For sync writer, `waitForWritable` will not be called, so an exception is thrown here. + WaitResult waitForWritable() const override { throw Exception("Unsupport sync writer"); } protected: void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) override; @@ -110,7 +110,7 @@ class AsyncMPPTunnelSetWriter : public MPPTunnelSetWriterBase : MPPTunnelSetWriterBase(mpp_tunnel_set_, result_field_types_, req_id) {} - bool isWritable() const override { return mpp_tunnel_set->isWritable(); } + WaitResult waitForWritable() const override { return mpp_tunnel_set->waitForWritable(); } protected: void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) override; diff --git a/dbms/src/Flash/Mpp/ReceivedMessageQueue.h b/dbms/src/Flash/Mpp/ReceivedMessageQueue.h index 6fd41463d12..c975e4d2ab1 100644 --- a/dbms/src/Flash/Mpp/ReceivedMessageQueue.h +++ b/dbms/src/Flash/Mpp/ReceivedMessageQueue.h @@ -99,6 +99,9 @@ class ReceivedMessageQueue bool isWritable() const { return grpc_recv_queue.isWritable(); } + void registerPipeReadTask(TaskPtr && task) { grpc_recv_queue.registerPipeReadTask(std::move(task)); } + void registerPipeWriteTask(TaskPtr && task) { grpc_recv_queue.registerPipeWriteTask(std::move(task)); } + #ifndef DBMS_PUBLIC_GTEST private: #endif diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp index 81de0806b79..84e51cb7151 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp @@ -168,7 +168,7 @@ struct MockExchangeWriter // make only part 0 use local tunnel return index == 0; } - bool isWritable() const { throw Exception("Unsupport async write"); } + static WaitResult waitForWritable() { throw Exception("Unsupport async write"); } private: MockExchangeWriterChecker checker; diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp index 3d71f00cb0e..ff11aef5ec1 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp @@ -31,7 +31,6 @@ #include #include - namespace DB { namespace tests @@ -751,8 +750,7 @@ try tunnel->write(std::move(packet)); } catch (...) - { - } + {} }; std::thread thd(tunnelRun, std::move(run_tunnel)); thd.join(); @@ -794,7 +792,7 @@ TEST_F(TestMPPTunnel, SyncTunnelForceWrite) mpp_tunnel_ptr->connectSync(writer_ptr.get()); GTEST_ASSERT_EQ(getTunnelConnectedFlag(mpp_tunnel_ptr), true); - ASSERT_TRUE(mpp_tunnel_ptr->isWritable()); + GTEST_ASSERT_EQ(mpp_tunnel_ptr->waitForWritable(), WaitResult::Ready); mpp_tunnel_ptr->forceWrite(newDataPacket("First")); mpp_tunnel_ptr->writeDone(); GTEST_ASSERT_EQ(getTunnelFinishedFlag(mpp_tunnel_ptr), true); @@ -811,7 +809,7 @@ TEST_F(TestMPPTunnel, AsyncTunnelForceWrite) GTEST_ASSERT_EQ(getTunnelConnectedFlag(mpp_tunnel_ptr), true); std::thread t(&MockAsyncCallData::run, call_data.get()); - ASSERT_TRUE(mpp_tunnel_ptr->isWritable()); + GTEST_ASSERT_EQ(mpp_tunnel_ptr->waitForWritable(), WaitResult::Ready); mpp_tunnel_ptr->forceWrite(newDataPacket("First")); mpp_tunnel_ptr->writeDone(); GTEST_ASSERT_EQ(getTunnelFinishedFlag(mpp_tunnel_ptr), true); @@ -828,7 +826,7 @@ TEST_F(TestMPPTunnel, LocalTunnelForceWrite) GTEST_ASSERT_EQ(getTunnelConnectedFlag(mpp_tunnel_ptr), true); std::thread t(&MockExchangeReceiver::receiveAll, receiver.get()); - ASSERT_TRUE(mpp_tunnel_ptr->isWritable()); + GTEST_ASSERT_EQ(mpp_tunnel_ptr->waitForWritable(), WaitResult::Ready); mpp_tunnel_ptr->forceWrite(newDataPacket("First")); mpp_tunnel_ptr->writeDone(); GTEST_ASSERT_EQ(getTunnelFinishedFlag(mpp_tunnel_ptr), true); @@ -846,7 +844,7 @@ try Stopwatch stop_watch{CLOCK_MONOTONIC_COARSE}; while (stop_watch.elapsedSeconds() < 3 * timeout.count()) { - ASSERT_FALSE(mpp_tunnel_ptr->isWritable()); + GTEST_ASSERT_EQ(mpp_tunnel_ptr->waitForWritable(), WaitResult::WaitForPolling); } GTEST_FAIL(); } diff --git a/dbms/src/Flash/Pipeline/Schedule/TaskQueues/tests/gtest_resource_control_queue.cpp b/dbms/src/Flash/Pipeline/Schedule/TaskQueues/tests/gtest_resource_control_queue.cpp index 5003117c91c..76565dac6ea 100644 --- a/dbms/src/Flash/Pipeline/Schedule/TaskQueues/tests/gtest_resource_control_queue.cpp +++ b/dbms/src/Flash/Pipeline/Schedule/TaskQueues/tests/gtest_resource_control_queue.cpp @@ -176,6 +176,7 @@ class TestResourceControlQueue : public ::testing::Test mem_tracker, nullptr, nullptr, + nullptr, resource_group_name); } return all_contexts; @@ -378,6 +379,7 @@ TEST_F(TestResourceControlQueue, BasicTest) mem_tracker, nullptr, nullptr, + nullptr, group_name); for (int j = 0; j < task_num_per_resource_group; ++j) @@ -413,6 +415,7 @@ TEST_F(TestResourceControlQueue, BasicTimeoutTest) mem_tracker, nullptr, nullptr, + nullptr, group_name); auto task = std::make_unique(*exec_context); @@ -441,7 +444,8 @@ TEST_F(TestResourceControlQueue, RunOutOfRU) TaskSchedulerConfig config{thread_num, thread_num}; TaskScheduler task_scheduler(config); - PipelineExecutorContext exec_context("mock-query-id", "mock-req-id", mem_tracker, nullptr, nullptr, rg_name); + PipelineExecutorContext + exec_context("mock-query-id", "mock-req-id", mem_tracker, nullptr, nullptr, nullptr, rg_name); auto task = std::make_unique(exec_context); // This task should use 5*100ms cpu_time. diff --git a/dbms/src/Operators/ExchangeSenderSinkOp.cpp b/dbms/src/Operators/ExchangeSenderSinkOp.cpp index 664e22e3cc5..1f073c67d20 100644 --- a/dbms/src/Operators/ExchangeSenderSinkOp.cpp +++ b/dbms/src/Operators/ExchangeSenderSinkOp.cpp @@ -39,14 +39,28 @@ OperatorStatus ExchangeSenderSinkOp::writeImpl(Block && block) return OperatorStatus::NEED_INPUT; } +OperatorStatus ExchangeSenderSinkOp::waitForWriter() const +{ + auto res = writer->waitForWritable(); + switch (res) + { + case WaitResult::Ready: + return OperatorStatus::NEED_INPUT; + case WaitResult::WaitForPolling: + return OperatorStatus::WAITING; + case WaitResult::WaitForNotify: + return OperatorStatus::WAIT_FOR_NOTIFY; + } +} + OperatorStatus ExchangeSenderSinkOp::prepareImpl() { - return writer->isWritable() ? OperatorStatus::NEED_INPUT : OperatorStatus::WAITING; + return waitForWriter(); } OperatorStatus ExchangeSenderSinkOp::awaitImpl() { - return writer->isWritable() ? OperatorStatus::NEED_INPUT : OperatorStatus::WAITING; + return waitForWriter(); } } // namespace DB diff --git a/dbms/src/Operators/ExchangeSenderSinkOp.h b/dbms/src/Operators/ExchangeSenderSinkOp.h index 0e549a42ae8..77ed8532f98 100644 --- a/dbms/src/Operators/ExchangeSenderSinkOp.h +++ b/dbms/src/Operators/ExchangeSenderSinkOp.h @@ -45,6 +45,9 @@ class ExchangeSenderSinkOp : public SinkOp OperatorStatus awaitImpl() override; +private: + OperatorStatus waitForWriter() const; + private: std::unique_ptr writer; size_t total_rows = 0;