diff --git a/include/envoy/thread_local/thread_local.h b/include/envoy/thread_local/thread_local.h
index 683617634a20..5c4374aff8ea 100644
--- a/include/envoy/thread_local/thread_local.h
+++ b/include/envoy/thread_local/thread_local.h
@@ -16,6 +16,14 @@ namespace ThreadLocal {
 class ThreadLocalObject {
 public:
   virtual ~ThreadLocalObject() = default;
+
+  /**
+   * Return the object casted to a concrete type. See getTyped() below for comments on the casts.
+   */
+  template <class T> T& asType() {
+    ASSERT(dynamic_cast<T*>(this) != nullptr);
+    return *static_cast<T*>(this);
+  }
 };

 using ThreadLocalObjectSharedPtr = std::shared_ptr<ThreadLocalObject>;

@@ -54,27 +62,15 @@ class Slot {
     return *static_cast<T*>(get().get());
   }

-  /**
-   * Run a callback on all registered threads.
-   * @param cb supplies the callback to run.
-   */
-  virtual void runOnAllThreads(Event::PostCb cb) PURE;
-
-  /**
-   * Run a callback on all registered threads with a barrier. A shutdown initiated during the
-   * running of the PostCbs may prevent all_threads_complete_cb from being called.
-   * @param cb supplies the callback to run on each thread.
-   * @param all_threads_complete_cb supplies the callback to run on the main thread after cb has
-   *        been run on all registered threads.
-   */
-  virtual void runOnAllThreads(Event::PostCb cb, Event::PostCb all_threads_complete_cb) PURE;
-
   /**
    * Set thread local data on all threads previously registered via registerThread().
    * @param initializeCb supplies the functor that will be called *on each thread*. The functor
    *                     returns the thread local object which is then stored. The storage is via
    *                     a shared_ptr. Thus, this is a flexible mechanism that can be used to share
    *                     the same data across all threads or to share different data on each thread.
+   *
+   * NOTE: The initialize callback must not capture the Slot or its owner, as the owner may be
+   *       destructed on the main thread before the callback runs on a worker thread.
    */
   using InitializeCb = std::function<ThreadLocalObjectSharedPtr(Event::Dispatcher& dispatcher)>;
   virtual void set(InitializeCb cb) PURE;
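
To make the revised contract concrete, here is a minimal consumer sketch (MyTlsState and MyFeature are hypothetical names, not part of this change). The update callback now receives the thread's object and returns it, and concrete access goes through asType<T>() rather than back through the slot or its owner:

    // Illustrative sketch only; MyTlsState and MyFeature are hypothetical.
    struct MyTlsState : public ThreadLocal::ThreadLocalObject {
      uint64_t hits_{0};
    };

    class MyFeature {
    public:
      MyFeature(ThreadLocal::SlotAllocator& tls) : slot_(tls.allocateSlot()) {
        // Per the NOTE above: capture neither the slot nor its owner.
        slot_->set([](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr {
          return std::make_shared<MyTlsState>();
        });
      }

      void recordHitOnAllThreads() {
        slot_->runOnAllThreads([](ThreadLocal::ThreadLocalObjectSharedPtr object)
                                   -> ThreadLocal::ThreadLocalObjectSharedPtr {
          object->asType<MyTlsState>().hits_++;
          return object;
        });
      }

    private:
      ThreadLocal::SlotPtr slot_;
    };
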
diff --git a/source/common/stats/thread_local_store.cc b/source/common/stats/thread_local_store.cc
index 54d0c78eba9b..4bd4ec6a9d6a 100644
--- a/source/common/stats/thread_local_store.cc
+++ b/source/common/stats/thread_local_store.cc
@@ -206,11 +206,13 @@ void ThreadLocalStoreImpl::mergeHistograms(PostMergeCb merge_complete_cb) {
     ASSERT(!merge_in_progress_);
     merge_in_progress_ = true;
     tls_->runOnAllThreads(
-        [this]() -> void {
-          for (const auto& id_hist : tls_->getTyped<TlsCache>().tls_histogram_cache_) {
+        [](ThreadLocal::ThreadLocalObjectSharedPtr object)
+            -> ThreadLocal::ThreadLocalObjectSharedPtr {
+          for (const auto& id_hist : object->asType<TlsCache>().tls_histogram_cache_) {
             const TlsHistogramSharedPtr& tls_hist = id_hist.second;
             tls_hist->beginMerge();
           }
+          return object;
         },
         [this, merge_complete_cb]() -> void { mergeInternal(merge_complete_cb); });
   } else {
@@ -304,7 +306,11 @@ void ThreadLocalStoreImpl::clearScopeFromCaches(uint64_t scope_id,
   if (!shutting_down_) {
     // Perform a cache flush on all threads.
     tls_->runOnAllThreads(
-        [this, scope_id]() { tls_->getTyped<TlsCache>().eraseScope(scope_id); },
+        [scope_id](ThreadLocal::ThreadLocalObjectSharedPtr object)
+            -> ThreadLocal::ThreadLocalObjectSharedPtr {
+          object->asType<TlsCache>().eraseScope(scope_id);
+          return object;
+        },
         [central_cache]() { /* Holds onto central_cache until all tls caches are clear */ });
   }
 }
@@ -320,8 +326,11 @@ void ThreadLocalStoreImpl::clearHistogramFromCaches(uint64_t histogram_id) {
     // https://gist.github.com/jmarantz/838cb6de7e74c0970ea6b63eded0139a
     // contains a patch that will implement batching together to clear multiple
     // histograms.
-    tls_->runOnAllThreads(
-        [this, histogram_id]() { tls_->getTyped<TlsCache>().eraseHistogram(histogram_id); });
+    tls_->runOnAllThreads([histogram_id](ThreadLocal::ThreadLocalObjectSharedPtr object)
+                              -> ThreadLocal::ThreadLocalObjectSharedPtr {
+      object->asType<TlsCache>().eraseHistogram(histogram_id);
+      return object;
+    });
   }
 }
diff --git a/source/common/thread_local/thread_local_impl.cc b/source/common/thread_local/thread_local_impl.cc
index d4d02f8b2f5f..7ed9eeca7942 100644
--- a/source/common/thread_local/thread_local_impl.cc
+++ b/source/common/thread_local/thread_local_impl.cc
@@ -26,79 +26,71 @@ SlotPtr InstanceImpl::allocateSlot() {
   ASSERT(!shutdown_);

   if (free_slot_indexes_.empty()) {
-    SlotImplPtr slot(new SlotImpl(*this, slots_.size()));
-    auto wrapper = std::make_unique<Bookkeeper>(*this, std::move(slot));
-    slots_.push_back(wrapper->slot_.get());
-    return wrapper;
+    SlotPtr slot = std::make_unique<SlotImpl>(*this, slots_.size());
+    slots_.push_back(slot.get());
+    return slot;
   }
   const uint32_t idx = free_slot_indexes_.front();
   free_slot_indexes_.pop_front();
   ASSERT(idx < slots_.size());
-  SlotImplPtr slot(new SlotImpl(*this, idx));
+  SlotPtr slot = std::make_unique<SlotImpl>(*this, idx);
   slots_[idx] = slot.get();
-  return std::make_unique<Bookkeeper>(*this, std::move(slot));
+  return slot;
 }

-bool InstanceImpl::SlotImpl::currentThreadRegistered() {
-  return thread_local_data_.data_.size() > index_;
-}
+InstanceImpl::SlotImpl::SlotImpl(InstanceImpl& parent, uint32_t index)
+    : parent_(parent), index_(index), still_alive_guard_(std::make_shared<bool>(true)) {}

-void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) {
-  parent_.runOnAllThreads([this, cb]() { setThreadLocal(index_, cb(get())); });
+Event::PostCb InstanceImpl::SlotImpl::wrapCallback(Event::PostCb&& cb) {
+  // See the header file comments for still_alive_guard_ for the purpose of this capture and the
+  // expired check below.
+  return [still_alive_guard = std::weak_ptr<bool>(still_alive_guard_), cb] {
+    if (!still_alive_guard.expired()) {
+      cb();
+    }
+  };
 }

-void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
-  parent_.runOnAllThreads([this, cb]() { setThreadLocal(index_, cb(get())); }, complete_cb);
+bool InstanceImpl::SlotImpl::currentThreadRegisteredWorker(uint32_t index) {
+  return thread_local_data_.data_.size() > index;
 }

-ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::get() {
-  ASSERT(currentThreadRegistered());
-  return thread_local_data_.data_[index_];
+bool InstanceImpl::SlotImpl::currentThreadRegistered() {
+  return currentThreadRegisteredWorker(index_);
 }

-InstanceImpl::Bookkeeper::Bookkeeper(InstanceImpl& parent, SlotImplPtr&& slot)
-    : parent_(parent), slot_(std::move(slot)),
-      ref_count_(/*not used.*/ nullptr,
-                 [slot = slot_.get(), &parent = this->parent_](uint32_t* /* not used */) {
-                   // On destruction, post a cleanup callback on the main thread. This could
-                   // happen on any thread.
-                   parent.scheduleCleanup(slot);
-                 }) {}
-
-ThreadLocalObjectSharedPtr InstanceImpl::Bookkeeper::get() { return slot_->get(); }
-
-void InstanceImpl::Bookkeeper::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
-  slot_->runOnAllThreads(
-      [cb, ref_count = this->ref_count_](ThreadLocalObjectSharedPtr previous) {
-        return cb(std::move(previous));
-      },
-      complete_cb);
+ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::getWorker(uint32_t index) {
+  ASSERT(currentThreadRegisteredWorker(index));
+  return thread_local_data_.data_[index];
 }

-void InstanceImpl::Bookkeeper::runOnAllThreads(const UpdateCb& cb) {
-  slot_->runOnAllThreads([cb, ref_count = this->ref_count_](ThreadLocalObjectSharedPtr previous) {
-    return cb(std::move(previous));
-  });
-}
+ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::get() { return getWorker(index_); }

-bool InstanceImpl::Bookkeeper::currentThreadRegistered() {
-  return slot_->currentThreadRegistered();
+void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
+  // See the header file comments for still_alive_guard_ for why we capture index_.
+  parent_.runOnAllThreads(
+      wrapCallback([cb, index = index_]() { setThreadLocal(index, cb(getWorker(index))); }),
+      complete_cb);
 }

-void InstanceImpl::Bookkeeper::runOnAllThreads(Event::PostCb cb) {
-  // Use ref_count_ to bookkeep how many on-the-fly callbacks are out there.
-  slot_->runOnAllThreads([cb, ref_count = this->ref_count_]() { cb(); });
+void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) {
+  // See the header file comments for still_alive_guard_ for why we capture index_.
+  parent_.runOnAllThreads(
+      wrapCallback([cb, index = index_]() { setThreadLocal(index, cb(getWorker(index))); }));
 }

-void InstanceImpl::Bookkeeper::runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) {
-  // Use ref_count_ to bookkeep how many on-the-fly callbacks are out there.
-  slot_->runOnAllThreads([cb, main_callback, ref_count = this->ref_count_]() { cb(); },
-                         main_callback);
-}
+void InstanceImpl::SlotImpl::set(InitializeCb cb) {
+  ASSERT(std::this_thread::get_id() == parent_.main_thread_id_);
+  ASSERT(!parent_.shutdown_);

-void InstanceImpl::Bookkeeper::set(InitializeCb cb) {
-  slot_->set([cb, ref_count = this->ref_count_](Event::Dispatcher& dispatcher)
-                 -> ThreadLocalObjectSharedPtr { return cb(dispatcher); });
+  for (Event::Dispatcher& dispatcher : parent_.registered_threads_) {
+    // See the header file comments for still_alive_guard_ for why we capture index_.
+    dispatcher.post(wrapCallback(
+        [index = index_, cb, &dispatcher]() -> void { setThreadLocal(index, cb(dispatcher)); }));
+  }
+
+  // Handle main thread.
+  setThreadLocal(index_, cb(*parent_.main_thread_dispatcher_));
 }

 void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_thread) {
@@ -115,39 +107,7 @@ void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_threa
   }
 }

-// Puts the slot into a deferred delete container; the slot will be destructed when its out-going
-// callback reference count goes to 0.
-void InstanceImpl::recycle(SlotImplPtr&& slot) {
-  ASSERT(std::this_thread::get_id() == main_thread_id_);
-  ASSERT(slot != nullptr);
-  auto* slot_addr = slot.get();
-  deferred_deletes_.insert({slot_addr, std::move(slot)});
-}
-
-// Called by the Bookkeeper ref_count destructor; the SlotImpl in the deferred deletes map can be
-// destructed now.
-void InstanceImpl::scheduleCleanup(SlotImpl* slot) {
-  if (shutdown_) {
-    // If the server is shutting down, do nothing here.
-    // The destruction of the Bookkeeper has already transferred the SlotImpl to the
-    // deferred_deletes_ queue. No matter if this method is called from a worker thread, the
-    // SlotImpl will be destructed on the main thread when InstanceImpl destructs.
-    return;
-  }
-  if (std::this_thread::get_id() == main_thread_id_) {
-    // If called from the main thread, save a callback.
-    ASSERT(deferred_deletes_.contains(slot));
-    deferred_deletes_.erase(slot);
-    return;
-  }
-  main_thread_dispatcher_->post([slot, this]() {
-    ASSERT(deferred_deletes_.contains(slot));
-    // The slot is guaranteed to be put into the deferred_deletes_ map by the Bookkeeper
-    // destructor.
-    deferred_deletes_.erase(slot);
-  });
-}
-
-void InstanceImpl::removeSlot(SlotImpl& slot) {
+void InstanceImpl::removeSlot(uint32_t slot) {
   ASSERT(std::this_thread::get_id() == main_thread_id_);

   // When shutting down, we do not post slot removals to other threads. This is because the other
@@ -158,18 +118,18 @@ void InstanceImpl::removeSlot(SlotImpl& slot) {
     return;
   }

-  const uint64_t index = slot.index_;
-  slots_[index] = nullptr;
-  ASSERT(std::find(free_slot_indexes_.begin(), free_slot_indexes_.end(), index) ==
+  slots_[slot] = nullptr;
+  ASSERT(std::find(free_slot_indexes_.begin(), free_slot_indexes_.end(), slot) ==
             free_slot_indexes_.end(),
-         fmt::format("slot index {} already in free slot set!", index));
-  free_slot_indexes_.push_back(index);
-  runOnAllThreads([index]() -> void {
+         fmt::format("slot index {} already in free slot set!", slot));
+  free_slot_indexes_.push_back(slot);
+  runOnAllThreads([slot]() -> void {
     // This runs on each thread and clears the slot, making it available for a new allocation.
     // This is safe even if a new allocation comes in, because everything happens with post() and
-    // will be sequenced after this removal.
-    if (index < thread_local_data_.data_.size()) {
-      thread_local_data_.data_[index] = nullptr;
+    // will be sequenced after this removal. It is also safe if there are callbacks pending on
+    // other threads because they will run first.
+    if (slot < thread_local_data_.data_.size()) {
+      thread_local_data_.data_[slot] = nullptr;
     }
   });
 }

@@ -205,19 +165,6 @@ void InstanceImpl::runOnAllThreads(Event::PostCb cb, Event::PostCb all_threads_c
   }
 }

-void InstanceImpl::SlotImpl::set(InitializeCb cb) {
-  ASSERT(std::this_thread::get_id() == parent_.main_thread_id_);
-  ASSERT(!parent_.shutdown_);
-
-  for (Event::Dispatcher& dispatcher : parent_.registered_threads_) {
-    const uint32_t index = index_;
-    dispatcher.post([index, cb, &dispatcher]() -> void { setThreadLocal(index, cb(dispatcher)); });
-  }
-
-  // Handle main thread.
-  setThreadLocal(index_, cb(*parent_.main_thread_dispatcher_));
-}
-
 void InstanceImpl::setThreadLocal(uint32_t index, ThreadLocalObjectSharedPtr object) {
   if (thread_local_data_.data_.size() <= index) {
     thread_local_data_.data_.resize(index + 1);
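
The wrapCallback() change above carries the core of the fix. Reduced to its essentials, the idiom is the following self-contained sketch (plain C++, not Envoy code): the owner holds a shared_ptr<bool>, each posted callback holds only a weak_ptr to it, and any callback that runs after the owner is destroyed degrades to a no-op:

    #include <functional>
    #include <memory>

    // Self-contained sketch of the still-alive-guard idiom used by SlotImpl::wrapCallback().
    class CallbackOwner {
    public:
      // Wraps a callback so it silently does nothing once this owner is destroyed.
      std::function<void()> wrap(std::function<void()> cb) {
        return [guard = std::weak_ptr<bool>(still_alive_guard_), cb = std::move(cb)]() {
          if (!guard.expired()) {
            cb();
          }
        };
      }

    private:
      std::shared_ptr<bool> still_alive_guard_{std::make_shared<bool>(true)};
    };

Note that expired() is thread-safe to call, but the check is best-effort: it cannot stop a callback that has already started when the owner dies, which is why captures must still avoid the slot object itself (see the header comments below).
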
diff --git a/source/common/thread_local/thread_local_impl.h b/source/common/thread_local/thread_local_impl.h
index 71153107fb3d..2b83a2aebf47 100644
--- a/source/common/thread_local/thread_local_impl.h
+++ b/source/common/thread_local/thread_local_impl.h
@@ -11,8 +11,6 @@
 #include "common/common/logger.h"
 #include "common/common/non_copyable.h"

-#include "absl/container/flat_hash_map.h"
-
 namespace Envoy {
 namespace ThreadLocal {

@@ -32,45 +30,38 @@ class InstanceImpl : Logger::Loggable<Logger::Id::main>, public NonCopyable, pub
   Event::Dispatcher& dispatcher() override;

 private:
+  // A slot is destroyed immediately on the main thread; on destruction it returns its index to
+  // the free slot list, while still_alive_guard_ (below) controls the fate of callbacks that
+  // are still draining from workers.
   struct SlotImpl : public Slot {
-    SlotImpl(InstanceImpl& parent, uint64_t index) : parent_(parent), index_(index) {}
-    ~SlotImpl() override { parent_.removeSlot(*this); }
-
-    // ThreadLocal::Slot
-    ThreadLocalObjectSharedPtr get() override;
-    bool currentThreadRegistered() override;
-    void runOnAllThreads(const UpdateCb& cb) override;
-    void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) override;
-    void runOnAllThreads(Event::PostCb cb) override { parent_.runOnAllThreads(cb); }
-    void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override {
-      parent_.runOnAllThreads(cb, main_callback);
-    }
-    void set(InitializeCb cb) override;
-
-    InstanceImpl& parent_;
-    const uint64_t index_;
-  };
-
-  using SlotImplPtr = std::unique_ptr<SlotImpl>;
-
-  // A wrapper around SlotImpl which on destruction returns the SlotImpl to the deferred delete
-  // queue (detaches it).
-  struct Bookkeeper : public Slot {
-    Bookkeeper(InstanceImpl& parent, SlotImplPtr&& slot);
-    ~Bookkeeper() override { parent_.recycle(std::move(slot_)); }
+    SlotImpl(InstanceImpl& parent, uint32_t index);
+    ~SlotImpl() override { parent_.removeSlot(index_); }
+    Event::PostCb wrapCallback(Event::PostCb&& cb);
+    static bool currentThreadRegisteredWorker(uint32_t index);
+    static ThreadLocalObjectSharedPtr getWorker(uint32_t index);

     // ThreadLocal::Slot
     ThreadLocalObjectSharedPtr get() override;
     void runOnAllThreads(const UpdateCb& cb) override;
     void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) override;
     bool currentThreadRegistered() override;
-    void runOnAllThreads(Event::PostCb cb) override;
-    void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override;
     void set(InitializeCb cb) override;

     InstanceImpl& parent_;
-    SlotImplPtr slot_;
-    std::shared_ptr<uint32_t> ref_count_;
+    const uint32_t index_;
+    // The following is used to safely verify via weak_ptr that this slot is still alive. This
+    // does not prevent all races if a callback does not capture appropriately, but it does fix
+    // the common case of a slot destroyed immediately before anything is posted to a worker.
+    // NOTE: The general safety model of a slot is that it is destroyed immediately on the main
+    //       thread. This means that *all* captures must not reference the slot object directly.
+    //       This is why index_ is captured manually in callbacks that require it.
+    // NOTE: When the slot is destroyed, the index is immediately recycled. This is safe because
+    //       any new posts for a recycled index must come after any previous callbacks for the
+    //       previous owner of the index.
+    // TODO(mattklein123): Add a clang-tidy analysis rule to check that "this" is not captured by
+    //       a TLS function call. This check will not prevent all bad captures, but it will at
+    //       least make the programmer more aware of potential issues.
+    std::shared_ptr<bool> still_alive_guard_;
   };

   struct ThreadLocalData {
@@ -78,26 +69,16 @@ class InstanceImpl : Logger::Loggable<Logger::Id::main>, public NonCopyable, pub
     std::vector<ThreadLocalObjectSharedPtr> data_;
   };

-  void recycle(SlotImplPtr&& slot);
-  // Cleans up the deferred deletes queue.
-  void scheduleCleanup(SlotImpl* slot);
-
-  void removeSlot(SlotImpl& slot);
+  void removeSlot(uint32_t slot);
   void runOnAllThreads(Event::PostCb cb);
   void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback);
   static void setThreadLocal(uint32_t index, ThreadLocalObjectSharedPtr object);

   static thread_local ThreadLocalData thread_local_data_;

-  // An indexed container for Slots whose deletion has to be deferred due to out-going callbacks
-  // pointing to the Slot. To let the ref_count_ deleter find the SlotImpl by address, the
-  // container is defined as a map of SlotImpl address to the unique_ptr.
-  absl::flat_hash_map<SlotImpl*, SlotImplPtr> deferred_deletes_;
-
-  std::vector<SlotImpl*> slots_;
+  std::vector<Slot*> slots_;
   // A list of indexes of freed slots.
   std::list<uint32_t> free_slot_indexes_;
-
   std::list<std::reference_wrapper<Event::Dispatcher>> registered_threads_;
   std::thread::id main_thread_id_;
   Event::Dispatcher* main_thread_dispatcher_{};
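
The capture rules in the comment block above reduce to a simple before/after, sketched here with hypothetical Widget/WidgetTlsState names:

    // Hypothetical example; Widget and WidgetTlsState are illustrative.
    void Widget::setLimitOnAllThreads(uint64_t limit) {
      // BAD (the removed pattern): capturing `this` races with Widget's destruction
      // on the main thread before the lambda runs on a worker:
      //   slot_->runOnAllThreads([this] { slot_->getTyped<WidgetTlsState>().limit_ = limit_; });

      // GOOD: capture plain data by value; the worker's own object is passed in.
      slot_->runOnAllThreads([limit](ThreadLocal::ThreadLocalObjectSharedPtr object)
                                 -> ThreadLocal::ThreadLocalObjectSharedPtr {
        object->asType<WidgetTlsState>().limit_ = limit;
        return object;
      });
    }

This `this` capture is also what the TODO above wants clang-tidy to flag mechanically.
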
diff --git a/source/common/upstream/cluster_manager_impl.cc b/source/common/upstream/cluster_manager_impl.cc
index c5df7eb10f98..5f814b498028 100644
--- a/source/common/upstream/cluster_manager_impl.cc
+++ b/source/common/upstream/cluster_manager_impl.cc
@@ -648,10 +648,12 @@ bool ClusterManagerImpl::addOrUpdateCluster(const envoy::config::cluster::v3::Cl
 }

 void ClusterManagerImpl::createOrUpdateThreadLocalCluster(ClusterData& cluster) {
-  tls_->runOnAllThreads([this, new_cluster = cluster.cluster_->info(),
-                         thread_aware_lb_factory = cluster.loadBalancerFactory()]() -> void {
+  tls_->runOnAllThreads([new_cluster = cluster.cluster_->info(),
+                         thread_aware_lb_factory = cluster.loadBalancerFactory()](
+                            ThreadLocal::ThreadLocalObjectSharedPtr object)
+                            -> ThreadLocal::ThreadLocalObjectSharedPtr {
     ThreadLocalClusterManagerImpl& cluster_manager =
-        tls_->getTyped<ThreadLocalClusterManagerImpl>();
+        object->asType<ThreadLocalClusterManagerImpl>();

     if (cluster_manager.thread_local_clusters_.count(new_cluster->name()) > 0) {
       ENVOY_LOG(debug, "updating TLS cluster {}", new_cluster->name());
@@ -665,6 +667,8 @@ void ClusterManagerImpl::createOrUpdateThreadLocalCluster(ClusterData& cluster)
     for (auto& cb : cluster_manager.update_callbacks_) {
       cb->onClusterAddOrUpdate(*thread_local_cluster);
     }
+
+    return object;
   });
 }

@@ -678,9 +682,10 @@ bool ClusterManagerImpl::removeCluster(const std::string& cluster_name) {
     active_clusters_.erase(existing_active_cluster);

     ENVOY_LOG(info, "removing cluster {}", cluster_name);
-    tls_->runOnAllThreads([this, cluster_name]() -> void {
+    tls_->runOnAllThreads([cluster_name](ThreadLocal::ThreadLocalObjectSharedPtr object)
+                              -> ThreadLocal::ThreadLocalObjectSharedPtr {
       ThreadLocalClusterManagerImpl& cluster_manager =
-          tls_->getTyped<ThreadLocalClusterManagerImpl>();
+          object->asType<ThreadLocalClusterManagerImpl>();

       ASSERT(cluster_manager.thread_local_clusters_.count(cluster_name) == 1);
       ENVOY_LOG(debug, "removing TLS cluster {}", cluster_name);
@@ -688,6 +693,7 @@ bool ClusterManagerImpl::removeCluster(const std::string& cluster_name) {
         cb->onClusterRemoval(cluster_name);
       }
       cluster_manager.thread_local_clusters_.erase(cluster_name);
+      return object;
     });
   }

@@ -902,9 +908,12 @@ ClusterManagerImpl::tcpConnPoolForCluster(const std::string& cluster, ResourcePr

 void ClusterManagerImpl::postThreadLocalDrainConnections(const Cluster& cluster,
                                                          const HostVector& hosts_removed) {
-  tls_->runOnAllThreads([this, name = cluster.info()->name(), hosts_removed]() {
-    ThreadLocalClusterManagerImpl::removeHosts(name, hosts_removed, *tls_);
-  });
+  tls_->runOnAllThreads(
+      [name = cluster.info()->name(), hosts_removed](ThreadLocal::ThreadLocalObjectSharedPtr object)
+          -> ThreadLocal::ThreadLocalObjectSharedPtr {
+        object->asType<ThreadLocalClusterManagerImpl>().removeHosts(name, hosts_removed);
+        return object;
+      });
 }

 void ClusterManagerImpl::postThreadLocalClusterUpdate(const Cluster& cluster, uint32_t priority,
@@ -912,19 +921,25 @@ void ClusterManagerImpl::postThreadLocalClusterUpdate(const Cluster& cluster, ui
                                                       const HostVector& hosts_removed) {
   const auto& host_set = cluster.prioritySet().hostSetsPerPriority()[priority];

-  tls_->runOnAllThreads([this, name = cluster.info()->name(), priority,
+  tls_->runOnAllThreads([name = cluster.info()->name(), priority,
                          update_params = HostSetImpl::updateHostsParams(*host_set),
                          locality_weights = host_set->localityWeights(), hosts_added, hosts_removed,
-                         overprovisioning_factor = host_set->overprovisioningFactor()]() {
-    ThreadLocalClusterManagerImpl::updateClusterMembership(
-        name, priority, update_params, locality_weights, hosts_added, hosts_removed, *tls_,
+                         overprovisioning_factor = host_set->overprovisioningFactor()](
+                            ThreadLocal::ThreadLocalObjectSharedPtr object)
+                            -> ThreadLocal::ThreadLocalObjectSharedPtr {
+    object->asType<ThreadLocalClusterManagerImpl>().updateClusterMembership(
+        name, priority, update_params, locality_weights, hosts_added, hosts_removed,
         overprovisioning_factor);
+    return object;
   });
 }

 void ClusterManagerImpl::postThreadLocalHealthFailure(const HostSharedPtr& host) {
-  tls_->runOnAllThreads(
-      [this, host] { ThreadLocalClusterManagerImpl::onHostHealthFailure(host, *tls_); });
+  tls_->runOnAllThreads([host](ThreadLocal::ThreadLocalObjectSharedPtr object)
+                            -> ThreadLocal::ThreadLocalObjectSharedPtr {
+    object->asType<ThreadLocalClusterManagerImpl>().onHostHealthFailure(host);
+    return object;
+  });
 }

 Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster(const std::string& cluster,
@@ -1160,13 +1175,10 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::removeTcpConn(
   }
 }

-void ClusterManagerImpl::ThreadLocalClusterManagerImpl::removeHosts(const std::string& name,
-                                                                    const HostVector& hosts_removed,
-                                                                    ThreadLocal::Slot& tls) {
-  ThreadLocalClusterManagerImpl& config = tls.getTyped<ThreadLocalClusterManagerImpl>();
-
-  ASSERT(config.thread_local_clusters_.find(name) != config.thread_local_clusters_.end());
-  const auto& cluster_entry = config.thread_local_clusters_[name];
+void ClusterManagerImpl::ThreadLocalClusterManagerImpl::removeHosts(
+    const std::string& name, const HostVector& hosts_removed) {
+  ASSERT(thread_local_clusters_.find(name) != thread_local_clusters_.end());
+  const auto& cluster_entry = thread_local_clusters_[name];
   ENVOY_LOG(debug, "removing hosts for TLS cluster {} removed {}", name, hosts_removed.size());

   // We need to go through and purge any connection pools for hosts that got deleted.
@@ -1178,11 +1190,9 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::removeHosts(const std::s
 void ClusterManagerImpl::ThreadLocalClusterManagerImpl::updateClusterMembership(
     const std::string& name, uint32_t priority, PrioritySet::UpdateHostsParams update_hosts_params,
     LocalityWeightsConstSharedPtr locality_weights, const HostVector& hosts_added,
-    const HostVector& hosts_removed, ThreadLocal::Slot& tls, uint64_t overprovisioning_factor) {
-  ThreadLocalClusterManagerImpl& config = tls.getTyped<ThreadLocalClusterManagerImpl>();
-
-  ASSERT(config.thread_local_clusters_.find(name) != config.thread_local_clusters_.end());
-  const auto& cluster_entry = config.thread_local_clusters_[name];
+    const HostVector& hosts_removed, uint64_t overprovisioning_factor) {
+  ASSERT(thread_local_clusters_.find(name) != thread_local_clusters_.end());
+  const auto& cluster_entry = thread_local_clusters_[name];
   ENVOY_LOG(debug, "membership update for TLS cluster {} added {} removed {}", name,
             hosts_added.size(), hosts_removed.size());
   cluster_entry->priority_set_.updateHosts(priority, std::move(update_hosts_params),
@@ -1197,7 +1207,7 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::updateClusterMembership(
 }

 void ClusterManagerImpl::ThreadLocalClusterManagerImpl::onHostHealthFailure(
-    const HostSharedPtr& host, ThreadLocal::Slot& tls) {
+    const HostSharedPtr& host) {

   // Drain all HTTP connection pool connections in the case of a host health failure. If outlier/
   // health is due to ECMP flow hashing issues for example, a new set of connections might do
@@ -1205,9 +1215,8 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::onHostHealthFailure(
   // TODO(mattklein123): This function is currently very specific, but in the future when we do
   // more granular host set changes, we should be able to capture single host changes and make them
   // more targeted.
-  ThreadLocalClusterManagerImpl& config = tls.getTyped<ThreadLocalClusterManagerImpl>();
   {
-    const auto container = config.getHttpConnPoolsContainer(host);
+    const auto container = getHttpConnPoolsContainer(host);
     if (container != nullptr) {
       container->pools_->drainConnections();
     }
@@ -1217,8 +1226,8 @@
     // connections being closed, it only prevents new connections through the pool. The
     // CLOSE_CONNECTIONS_ON_HOST_HEALTH_FAILURE can be used to make the pool close any
     // active connections.
-    const auto& container = config.host_tcp_conn_pool_map_.find(host);
-    if (container != config.host_tcp_conn_pool_map_.end()) {
+    const auto& container = host_tcp_conn_pool_map_.find(host);
+    if (container != host_tcp_conn_pool_map_.end()) {
       for (const auto& pair : container->second.pools_) {
         const Tcp::ConnectionPool::InstancePtr& pool = pair.second;
         if (host->cluster().features() &
@@ -1247,8 +1256,8 @@
     // in the configuration documentation in cluster setting
     // "close_connections_on_host_health_failure". Update the docs if this changes.
     while (true) {
-      const auto& it = config.host_tcp_conn_map_.find(host);
-      if (it == config.host_tcp_conn_map_.end()) {
+      const auto& it = host_tcp_conn_map_.find(host);
+      if (it == host_tcp_conn_map_.end()) {
         break;
       }
       TcpConnectionsMap& container = it->second;
diff --git a/source/common/upstream/cluster_manager_impl.h b/source/common/upstream/cluster_manager_impl.h
index 1196bd13db72..b235be21d99c 100644
--- a/source/common/upstream/cluster_manager_impl.h
+++ b/source/common/upstream/cluster_manager_impl.h
@@ -368,15 +368,13 @@ class ClusterManagerImpl : public ClusterManager, Logger::Loggable<Logger::Id::upstream>
-    static void removeHosts(const std::string& name, const HostVector& hosts_removed,
-                            ThreadLocal::Slot& tls);
-    static void updateClusterMembership(const std::string& name, uint32_t priority,
-                                        PrioritySet::UpdateHostsParams update_hosts_params,
-                                        LocalityWeightsConstSharedPtr locality_weights,
-                                        const HostVector& hosts_added,
-                                        const HostVector& hosts_removed, ThreadLocal::Slot& tls,
-                                        uint64_t overprovisioning_factor);
-    static void onHostHealthFailure(const HostSharedPtr& host, ThreadLocal::Slot& tls);
+    void removeHosts(const std::string& name, const HostVector& hosts_removed);
+    void updateClusterMembership(const std::string& name, uint32_t priority,
+                                 PrioritySet::UpdateHostsParams update_hosts_params,
+                                 LocalityWeightsConstSharedPtr locality_weights,
+                                 const HostVector& hosts_added, const HostVector& hosts_removed,
+                                 uint64_t overprovisioning_factor);
+    void onHostHealthFailure(const HostSharedPtr& host);
diff --git a/source/extensions/clusters/aggregate/cluster.cc b/source/extensions/clusters/aggregate/cluster.cc
@@ ... @@ void Cluster::refresh(const std::function<bool(const std::string&)>& skip_predicate) {
   // Post the priority set to worker threads.
-  tls_->runOnAllThreads([this, skip_predicate, cluster_name = this->info()->name()]() {
+  // TODO(mattklein123): Remove "this" capture.
+  tls_->runOnAllThreads([this, skip_predicate, cluster_name = this->info()->name()](
+                            ThreadLocal::ThreadLocalObjectSharedPtr object)
+                            -> ThreadLocal::ThreadLocalObjectSharedPtr {
     PriorityContextPtr priority_context = linearizePrioritySet(skip_predicate);
     Upstream::ThreadLocalCluster* cluster = cluster_manager_.get(cluster_name);
     ASSERT(cluster != nullptr);
     dynamic_cast<AggregateClusterLoadBalancer&>(cluster->loadBalancer())
         .refresh(std::move(priority_context));
+    return object;
   });
 }
diff --git a/source/extensions/common/dynamic_forward_proxy/dns_cache_impl.cc b/source/extensions/common/dynamic_forward_proxy/dns_cache_impl.cc
index b2e2d5defce1..20c8a62e7bcf 100644
--- a/source/extensions/common/dynamic_forward_proxy/dns_cache_impl.cc
+++ b/source/extensions/common/dynamic_forward_proxy/dns_cache_impl.cc
@@ -257,8 +257,10 @@ void DnsCacheImpl::updateTlsHostsMap() {
     }
   }

-  tls_slot_->runOnAllThreads([this, new_host_map]() {
-    tls_slot_->getTyped<ThreadLocalHostInfo>().updateHostMap(new_host_map);
+  tls_slot_->runOnAllThreads([new_host_map](ThreadLocal::ThreadLocalObjectSharedPtr object)
+                                 -> ThreadLocal::ThreadLocalObjectSharedPtr {
+    object->asType<ThreadLocalHostInfo>().updateHostMap(new_host_map);
+    return object;
   });
 }
diff --git a/source/extensions/common/tap/extension_config_base.cc b/source/extensions/common/tap/extension_config_base.cc
index fda84e29fbeb..6578c02fc37b 100644
--- a/source/extensions/common/tap/extension_config_base.cc
+++ b/source/extensions/common/tap/extension_config_base.cc
@@ -63,15 +63,22 @@ const absl::string_view ExtensionConfigBase::adminId() {
 }

 void ExtensionConfigBase::clearTapConfig() {
-  tls_slot_->runOnAllThreads([this] { tls_slot_->getTyped<TlsFilterConfig>().config_ = nullptr; });
+  tls_slot_->runOnAllThreads([](ThreadLocal::ThreadLocalObjectSharedPtr object)
+                                 -> ThreadLocal::ThreadLocalObjectSharedPtr {
+    object->asType<TlsFilterConfig>().config_ = nullptr;
+    return object;
+  });
 }

 void ExtensionConfigBase::installNewTap(envoy::config::tap::v3::TapConfig&& proto_config,
                                         Sink* admin_streamer) {
   TapConfigSharedPtr new_config =
       config_factory_->createConfigFromProto(std::move(proto_config), admin_streamer);
-  tls_slot_->runOnAllThreads(
-      [this, new_config] { tls_slot_->getTyped<TlsFilterConfig>().config_ = new_config; });
+  tls_slot_->runOnAllThreads([new_config](ThreadLocal::ThreadLocalObjectSharedPtr object)
+                                 -> ThreadLocal::ThreadLocalObjectSharedPtr {
+    object->asType<TlsFilterConfig>().config_ = new_config;
+    return object;
+  });
 }

 void ExtensionConfigBase::newTapConfig(envoy::config::tap::v3::TapConfig&& proto_config,
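
One pattern in the dns_cache and tap call sites above is worth calling out: since the callback can no longer reach back through `this`, shared state is published by capturing a single immutable shared_ptr snapshot that every worker then stores. A hedged sketch with hypothetical names:

    // Hypothetical publisher following the dns_cache/tap pattern above.
    void ConfigHolder::publish(ConfigConstSharedPtr new_config) {
      // One snapshot captured by value; all workers share the same immutable object.
      tls_slot_->runOnAllThreads([new_config](ThreadLocal::ThreadLocalObjectSharedPtr object)
                                     -> ThreadLocal::ThreadLocalObjectSharedPtr {
        object->asType<ThreadLocalConfig>().config_ = new_config;
        return object;
      });
    }

The by-value shared_ptr capture also keeps the snapshot alive until the last worker has run, the same property clearScopeFromCaches() relies on for central_cache.
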
diff --git a/source/server/overload_manager_impl.cc b/source/server/overload_manager_impl.cc
index c0010f819be1..3249e3808b8d 100644
--- a/source/server/overload_manager_impl.cc
+++ b/source/server/overload_manager_impl.cc
@@ -323,11 +323,14 @@ void OverloadManagerImpl::flushResourceUpdates() {
     auto shared_updates =
        std::make_shared<absl::flat_hash_map<OverloadAction*, OverloadActionState>>();
     std::swap(*shared_updates, state_updates_to_flush_);

-    tls_->runOnAllThreads([this, updates = std::move(shared_updates)] {
-      for (const auto& [action, state] : *updates) {
-        tls_->getTyped<ThreadLocalOverloadState>().setState(action, state);
-      }
-    });
+    tls_->runOnAllThreads(
+        [updates = std::move(shared_updates)](ThreadLocal::ThreadLocalObjectSharedPtr object)
+            -> ThreadLocal::ThreadLocalObjectSharedPtr {
+          for (const auto& [action, state] : *updates) {
+            object->asType<ThreadLocalOverloadState>().setState(action, state);
+          }
+          return object;
+        });
   }

   for (const auto& [cb, state] : callbacks_to_flush_) {
diff --git a/test/common/stats/thread_local_store_test.cc b/test/common/stats/thread_local_store_test.cc
index 135c6b424097..5944f3bc5c5a 100644
--- a/test/common/stats/thread_local_store_test.cc
+++ b/test/common/stats/thread_local_store_test.cc
@@ -55,10 +55,11 @@ class ThreadLocalStoreTestingPeer {
                               const std::function<void(uint32_t)>& num_tls_hist_cb) {
     auto num_tls_histograms = std::make_shared<std::atomic<uint32_t>>(0);
     thread_local_store_impl.tls_->runOnAllThreads(
-        [&thread_local_store_impl, num_tls_histograms]() {
-          auto& tls_cache =
-              thread_local_store_impl.tls_->getTyped<ThreadLocalStoreImpl::TlsCache>();
+        [num_tls_histograms](ThreadLocal::ThreadLocalObjectSharedPtr object)
+            -> ThreadLocal::ThreadLocalObjectSharedPtr {
+          auto& tls_cache = object->asType<ThreadLocalStoreImpl::TlsCache>();
           *num_tls_histograms += tls_cache.tls_histogram_cache_.size();
+          return object;
         },
         [num_tls_hist_cb, num_tls_histograms]() { num_tls_hist_cb(*num_tls_histograms); });
   }
diff --git a/test/common/thread_local/thread_local_impl_test.cc b/test/common/thread_local/thread_local_impl_test.cc
index 6747fa2eae99..a625d57002a7 100644
--- a/test/common/thread_local/thread_local_impl_test.cc
+++ b/test/common/thread_local/thread_local_impl_test.cc
@@ -45,7 +45,6 @@ class ThreadLocalInstanceImplTest : public testing::Test {
     object.reset();
     return object_ref;
   }
-  int deferredDeletesMapSize() { return tls_.deferred_deletes_.size(); }
   int freeSlotIndexesListSize() { return tls_.free_slot_indexes_.size(); }

   InstanceImpl tls_;
@@ -60,7 +59,6 @@ TEST_F(ThreadLocalInstanceImplTest, All) {
   EXPECT_CALL(thread_dispatcher_, post(_));
   SlotPtr slot1 = tls_.allocateSlot();
   slot1.reset();
-  EXPECT_EQ(deferredDeletesMapSize(), 0);
   EXPECT_EQ(freeSlotIndexesListSize(), 1);

   // Create a new slot which should take the place of the old slot. ReturnPointee() is used to
@@ -86,48 +84,67 @@ TEST_F(ThreadLocalInstanceImplTest, All) {
   slot3.reset();
   slot4.reset();
   EXPECT_EQ(freeSlotIndexesListSize(), 0);
-  EXPECT_EQ(deferredDeletesMapSize(), 2);

   EXPECT_CALL(object_ref4, onDestroy());
   EXPECT_CALL(object_ref3, onDestroy());
   tls_.shutdownThread();
 }

-TEST_F(ThreadLocalInstanceImplTest, DeferredRecycle) {
+struct ThreadStatus {
+  uint64_t thread_local_calls_{0};
+  bool all_threads_complete_ = false;
+};
+
+TEST_F(ThreadLocalInstanceImplTest, CallbackNotInvokedAfterDeletion) {
   InSequence s;

-  // Free a slot without ever calling set.
-  EXPECT_CALL(thread_dispatcher_, post(_));
-  SlotPtr slot1 = tls_.allocateSlot();
-  slot1.reset();
-  // Slot destructed directly, as there are no out-going callbacks.
-  EXPECT_EQ(deferredDeletesMapSize(), 0);
+  // Allocate a slot and invoke all callback variants. Hold all callbacks and destroy the slot.
+  // Make sure that recycling happens appropriately.
+  SlotPtr slot = tls_.allocateSlot();
+
+  std::list<Event::PostCb> holder;
+  EXPECT_CALL(thread_dispatcher_, post(_)).Times(4).WillRepeatedly(Invoke([&](Event::PostCb cb) {
+    // Holds the posted callback.
+    holder.push_back(cb);
+  }));
+
+  uint32_t total_callbacks = 0;
+  slot->set([&total_callbacks](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr {
+    // Callbacks run on the main thread but not the workers, so track the total.
+    total_callbacks++;
+    return nullptr;
+  });
+  slot->runOnAllThreads([&total_callbacks](ThreadLocal::ThreadLocalObjectSharedPtr)
+                            -> ThreadLocal::ThreadLocalObjectSharedPtr {
+    // Callbacks run on the main thread but not the workers, so track the total.
+    total_callbacks++;
+    return nullptr;
+  });
+  ThreadStatus thread_status;
+  slot->runOnAllThreads(
+      [&thread_status](
+          ThreadLocal::ThreadLocalObjectSharedPtr) -> ThreadLocal::ThreadLocalObjectSharedPtr {
+        ++thread_status.thread_local_calls_;
+        return nullptr;
+      },
+      [&thread_status]() -> void {
+        // Callbacks run on the main thread but not the workers.
+        EXPECT_EQ(thread_status.thread_local_calls_, 1);
+        thread_status.all_threads_complete_ = true;
+      });
+  EXPECT_FALSE(thread_status.all_threads_complete_);
+
+  EXPECT_EQ(2, total_callbacks);
+  slot.reset();
   EXPECT_EQ(freeSlotIndexesListSize(), 1);

-  // Allocate a slot and set value, hold the posted callback, and the slot will only be returned
-  // after the held callback is destructed.
-  {
-    SlotPtr slot2 = tls_.allocateSlot();
-    EXPECT_EQ(freeSlotIndexesListSize(), 0);
-    {
-      Event::PostCb holder;
-      EXPECT_CALL(thread_dispatcher_, post(_)).WillOnce(Invoke([&](Event::PostCb cb) {
-        // Holds the posted callback.
-        holder = cb;
-      }));
-      slot2->set(
-          [](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr { return nullptr; });
-      slot2.reset();
-      // Not released yet, as the holder has a copy of the ref_count_.
-      EXPECT_EQ(freeSlotIndexesListSize(), 0);
-      EXPECT_EQ(deferredDeletesMapSize(), 1);
-      // This post is called when the holder dies.
-      EXPECT_CALL(thread_dispatcher_, post(_));
-    }
-    // Slot is deleted now that the holder destructs.
-    EXPECT_EQ(deferredDeletesMapSize(), 0);
-    EXPECT_EQ(freeSlotIndexesListSize(), 1);
+  EXPECT_CALL(main_dispatcher_, post(_));
+  while (!holder.empty()) {
+    holder.front()();
+    holder.pop_front();
   }
+  EXPECT_EQ(2, total_callbacks);
+  EXPECT_TRUE(thread_status.all_threads_complete_);

   tls_.shutdownGlobalThreading();
 }
@@ -172,25 +189,29 @@ TEST_F(ThreadLocalInstanceImplTest, UpdateCallback) {

 // Validate ThreadLocal::runOnAllThreads behavior with the all_thread_complete callback.
 TEST_F(ThreadLocalInstanceImplTest, RunOnAllThreads) {
   SlotPtr tlsptr = tls_.allocateSlot();
+  TestThreadLocalObject& object_ref = setObject(*tlsptr);

   EXPECT_CALL(thread_dispatcher_, post(_));
   EXPECT_CALL(main_dispatcher_, post(_));

   // Ensure that the thread local callback and the all_thread_complete callback are called.
-  struct {
-    uint64_t thread_local_calls_{0};
-    bool all_threads_complete_ = false;
-  } thread_status;
-
-  tlsptr->runOnAllThreads([&thread_status]() -> void { ++thread_status.thread_local_calls_; },
-                          [&thread_status]() -> void {
-                            EXPECT_EQ(thread_status.thread_local_calls_, 2);
-                            thread_status.all_threads_complete_ = true;
-                          });
-
+  ThreadStatus thread_status;
+  tlsptr->runOnAllThreads(
+      [&thread_status](ThreadLocal::ThreadLocalObjectSharedPtr object)
+          -> ThreadLocal::ThreadLocalObjectSharedPtr {
+        ++thread_status.thread_local_calls_;
+        return object;
+      },
+      [&thread_status]() -> void {
+        EXPECT_EQ(thread_status.thread_local_calls_, 2);
+        thread_status.all_threads_complete_ = true;
+      });
   EXPECT_TRUE(thread_status.all_threads_complete_);

   tls_.shutdownGlobalThreading();
+  tlsptr.reset();
+  EXPECT_EQ(freeSlotIndexesListSize(), 0);
+  EXPECT_CALL(object_ref, onDestroy());
   tls_.shutdownThread();
 }
diff --git a/test/mocks/thread_local/mocks.h b/test/mocks/thread_local/mocks.h
index 9bbd26a64465..dc6518c5068a 100644
--- a/test/mocks/thread_local/mocks.h
+++ b/test/mocks/thread_local/mocks.h
@@ -60,10 +60,6 @@ class MockInstance : public Instance {
     // ThreadLocal::Slot
     ThreadLocalObjectSharedPtr get() override { return parent_.data_[index_]; }
     bool currentThreadRegistered() override { return parent_.registered_; }
-    void runOnAllThreads(Event::PostCb cb) override { parent_.runOnAllThreads(cb); }
-    void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override {
-      parent_.runOnAllThreads(cb, main_callback);
-    }
     void runOnAllThreads(const UpdateCb& cb) override {
       parent_.runOnAllThreads([cb, this]() { parent_.data_[index_] = cb(parent_.data_[index_]); });
     }
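
For out-of-tree code still on the removed overloads, the migration is mechanical (MyState is illustrative): drop the `this`/slot capture, take and return the object, and switch getTyped<T>() on the slot to asType<T>() on the object:

    // Hypothetical migration example.
    // Before (removed overload):
    //   tls_->runOnAllThreads([this] { tls_->getTyped<MyState>().doWork(); });
    // After (UpdateCb overload):
    void migrated(ThreadLocal::Slot& slot) {
      slot.runOnAllThreads([](ThreadLocal::ThreadLocalObjectSharedPtr object)
                               -> ThreadLocal::ThreadLocalObjectSharedPtr {
        object->asType<MyState>().doWork();
        return object;
      });
    }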