From 0d768e9e49e1fd90aa35ea458eba429197edfa11 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Wed, 18 Sep 2024 19:18:15 +0200 Subject: [PATCH] Optimize file downloading for requests with strong worker filters (#9452) --- .../global_worker_manager/worker_filter.cpp | 26 +++++--- .../dq/global_worker_manager/worker_filter.h | 2 + .../global_worker_manager/workers_storage.cpp | 3 + .../workers_storage_ut.cpp | 60 +++++++++++++++++++ 4 files changed, 84 insertions(+), 7 deletions(-) diff --git a/ydb/library/yql/providers/dq/global_worker_manager/worker_filter.cpp b/ydb/library/yql/providers/dq/global_worker_manager/worker_filter.cpp index 8b5f443c63e8..433644b01433 100644 --- a/ydb/library/yql/providers/dq/global_worker_manager/worker_filter.cpp +++ b/ydb/library/yql/providers/dq/global_worker_manager/worker_filter.cpp @@ -23,20 +23,29 @@ TWorkerFilter::TWorkerFilter(const Yql::DqsProto::TWorkerFilter& filter) } } -TWorkerFilter::EMatchStatus TWorkerFilter::Match(const TWorkerInfo::TPtr& workerInfo, int taskId, TStats* stats) const { - bool allExists = true; - bool partial = false; +bool TWorkerFilter::MatchHost(const NDqs::TWorkerInfo::TPtr& workerInfo) const { if (FullMatch) { if (Filter.GetClusterName() && workerInfo->ClusterName != Filter.GetClusterName()) { - return EFAIL; + return false; } if (!Addresses.empty() && Addresses.find(workerInfo->Address) == Addresses.end()) { - return EFAIL; + return false; } if (!NodeIds.empty() && NodeIds.find(workerInfo->NodeId) == NodeIds.end()) { - return EFAIL; + return false; } } + + return true; +} + +TWorkerFilter::EMatchStatus TWorkerFilter::Match(const TWorkerInfo::TPtr& workerInfo, int taskId, TStats* stats) const { + bool allExists = true; + bool partial = false; + + if (!MatchHost(workerInfo)) { + return EFAIL; + } if (Filter.GetClusterNameHint() && workerInfo->ClusterName != Filter.GetClusterNameHint()) { partial = true; } @@ -52,7 +61,10 @@ TWorkerFilter::EMatchStatus TWorkerFilter::Match(const TWorkerInfo::TPtr& worker (*stats->WaitingResources)[id].insert(taskId); } else { (*stats->WaitingResources)[id].erase(taskId); - stats->Uploaded->find(id)->second.TryCount ++; + auto maybeUploadedStats = stats->Uploaded->find(id); + if (maybeUploadedStats != stats->Uploaded->end()) { + maybeUploadedStats->second.TryCount ++; + } } } } diff --git a/ydb/library/yql/providers/dq/global_worker_manager/worker_filter.h b/ydb/library/yql/providers/dq/global_worker_manager/worker_filter.h index 8e938e72d49b..1593870ab5ba 100644 --- a/ydb/library/yql/providers/dq/global_worker_manager/worker_filter.h +++ b/ydb/library/yql/providers/dq/global_worker_manager/worker_filter.h @@ -40,6 +40,8 @@ class TWorkerFilter { TWorkerFilter(const Yql::DqsProto::TWorkerFilter& filter); EMatchStatus Match(const NDqs::TWorkerInfo::TPtr& workerInfo, int taskId, TStats* stats) const; + // match mandatory host-specific fields like Address, NodeId, ClusterName + bool MatchHost(const NDqs::TWorkerInfo::TPtr& workerInfo) const; void Visit(const std::function& visitor) const; diff --git a/ydb/library/yql/providers/dq/global_worker_manager/workers_storage.cpp b/ydb/library/yql/providers/dq/global_worker_manager/workers_storage.cpp index b6f915d1d992..695c32ea34e9 100644 --- a/ydb/library/yql/providers/dq/global_worker_manager/workers_storage.cpp +++ b/ydb/library/yql/providers/dq/global_worker_manager/workers_storage.cpp @@ -437,6 +437,9 @@ TVector TWorkersStorage::TryAllocate(const NDq::IScheduler::T if (workerInfo->Stopping) { continue; } + if (!filter.MatchHost(workerInfo)) { + continue; + } filter.Visit([&](const auto& file) { if (workerInfo->AddToDownloadList(file.GetObjectId(), file)) { YQL_CLOG(TRACE, ProviderDq) << "Added " << file.GetName() << "|" << file.GetObjectId() << " to worker's " << GetGuidAsString(workerInfo->WorkerId) << " download list" ; diff --git a/ydb/library/yql/providers/dq/global_worker_manager/workers_storage_ut.cpp b/ydb/library/yql/providers/dq/global_worker_manager/workers_storage_ut.cpp index efa23c8a3266..cdf7863210e8 100644 --- a/ydb/library/yql/providers/dq/global_worker_manager/workers_storage_ut.cpp +++ b/ydb/library/yql/providers/dq/global_worker_manager/workers_storage_ut.cpp @@ -79,4 +79,64 @@ Y_UNIT_TEST_SUITE(WorkersBenchmark) { UNIT_ASSERT_VALUES_EQUAL(all.size(), 100); UNIT_ASSERT_VALUES_EQUAL(0, storage.FreeSlots()); } + + Y_UNIT_TEST(ScheduleDownload) { + int workers = 10; + TWorkersStorage storage(1, new TSensorsGroup, new TSensorsGroup); + storage.Clear(); + for (int i = 0; i < workers; i++) { + TGUID guid; + Yql::DqsProto::RegisterNodeRequest request; + request.SetCapacity(100); + request.AddKnownNodes(1); + CreateGuid(&guid); + storage.CreateOrUpdate(100+i, guid, request); + } + + { + auto request = NDqProto::TAllocateWorkersRequest(); + request.SetCount(10); + + auto waitInfo1 = IScheduler::TWaitInfo(request, NActors::TActorId()); + auto result = storage.TryAllocate(waitInfo1); + + UNIT_ASSERT_VALUES_EQUAL(result.size(), 10); + } + + { + auto request = NDqProto::TAllocateWorkersRequest(); + auto workerFilter = Yql::DqsProto::TWorkerFilter(); + workerFilter.AddNodeId(102); + + request.SetCount(10); + for (ui32 i = 0; i < request.GetCount(); i++) { + *request.AddWorkerFilterPerTask() = workerFilter; + } + auto waitInfo2 = IScheduler::TWaitInfo(request, NActors::TActorId()); + auto result = storage.TryAllocate(waitInfo2); + UNIT_ASSERT_VALUES_EQUAL(result.size(), 10); + } + + { + auto request = NDqProto::TAllocateWorkersRequest(); + auto workerFilter = Yql::DqsProto::TWorkerFilter(); + workerFilter.AddNodeId(102); + Yql::DqsProto::TFile file; + file.SetObjectId("fileId"); + file.SetLocalPath("/tmp/test"); + *workerFilter.AddFile() = file; + request.SetCount(10); + for (ui32 i = 0; i < request.GetCount(); i++) { + *request.AddWorkerFilterPerTask() = workerFilter; + } + + auto waitInfo3 = IScheduler::TWaitInfo(request, NActors::TActorId()); + auto result = storage.TryAllocate(waitInfo3); + UNIT_ASSERT_VALUES_EQUAL(result.size(), 0); + + storage.Visit([](const NDqs::TWorkerInfo::TPtr& workerInfo) { + UNIT_ASSERT(workerInfo->GetDownloadList().size() == 0 || workerInfo->NodeId == 102); + }); + } + } }