From 8f1424161355c18bd56d4798af37a2fb49af58db Mon Sep 17 00:00:00 2001 From: "jie.jiang" Date: Fri, 25 Oct 2024 10:56:12 +0800 Subject: [PATCH] fix TokenBucket rate_ & burstSize_ thread safe --- folly/TokenBucket.h | 63 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/folly/TokenBucket.h b/folly/TokenBucket.h index 83587dce461..9a5c3d9faa2 100644 --- a/folly/TokenBucket.h +++ b/folly/TokenBucket.h @@ -486,6 +486,8 @@ class BasicDynamicTokenBucket { template class BasicTokenBucket { private: + template + using Atom = typename Policy::template atom; using Impl = BasicDynamicTokenBucket; public: @@ -501,8 +503,8 @@ class BasicTokenBucket { BasicTokenBucket( double genRate, double burstSize, double zeroTime = 0) noexcept : tokenBucket_(zeroTime), rate_(genRate), burstSize_(burstSize) { - assert(rate_ > 0); - assert(burstSize_ > 0); + assert(rate_.load(std::memory_order_acquire) > 0); + assert(burstSize_.load(std::memory_order_acquire) > 0); } /** @@ -510,14 +512,28 @@ class BasicTokenBucket { * * Warning: not thread safe! */ - BasicTokenBucket(const BasicTokenBucket& other) noexcept = default; + BasicTokenBucket(const BasicTokenBucket& other) noexcept + : tokenBucket_(other.tokenBucket_), + rate_(other.rate_.load(std::memory_order_acquire)), + burstSize_(other.burstSize_.load(std::memory_order_acquire)) {} /** * Copy-assignment operator. * * Warning: not thread safe! */ - BasicTokenBucket& operator=(const BasicTokenBucket& other) noexcept = default; + BasicTokenBucket& operator=(const BasicTokenBucket& other) noexcept { + if (this != &other) { + tokenBucket_ = other.tokenBucket_; + rate_.store( + other.rate_.load(std::memory_order_acquire), + std::memory_order_release); + burstSize_.store( + other.burstSize_.load(std::memory_order_acquire), + std::memory_order_release); + } + return *this; + } /** * Returns the current time in seconds since Epoch. @@ -578,7 +594,11 @@ class BasicTokenBucket { * @return True if the rate limit check passed, false otherwise. */ bool consume(double toConsume, double nowInSeconds = defaultClockNow()) { - return tokenBucket_.consume(toConsume, rate_, burstSize_, nowInSeconds); + return tokenBucket_.consume( + toConsume, + rate_.load(std::memory_order_acquire), + burstSize_.load(std::memory_order_acquire), + nowInSeconds); } /** @@ -597,7 +617,10 @@ class BasicTokenBucket { double consumeOrDrain( double toConsume, double nowInSeconds = defaultClockNow()) { return tokenBucket_.consumeOrDrain( - toConsume, rate_, burstSize_, nowInSeconds); + toConsume, + rate_.load(std::memory_order_acquire), + burstSize_.load(std::memory_order_acquire), + nowInSeconds); } /** @@ -605,7 +628,8 @@ class BasicTokenBucket { * For negative tokens, setCapacity() can be used */ void returnTokens(double tokensToReturn) { - return tokenBucket_.returnTokens(tokensToReturn, rate_); + return tokenBucket_.returnTokens( + tokensToReturn, rate_.load(std::memory_order_acquire)); } /** @@ -615,7 +639,10 @@ class BasicTokenBucket { Optional consumeWithBorrowNonBlocking( double toConsume, double nowInSeconds = defaultClockNow()) { return tokenBucket_.consumeWithBorrowNonBlocking( - toConsume, rate_, burstSize_, nowInSeconds); + toConsume, + rate_.load(std::memory_order_acquire), + burstSize_.load(std::memory_order_acquire), + nowInSeconds); } /** @@ -624,7 +651,10 @@ class BasicTokenBucket { bool consumeWithBorrowAndWait( double toConsume, double nowInSeconds = defaultClockNow()) { return tokenBucket_.consumeWithBorrowAndWait( - toConsume, rate_, burstSize_, nowInSeconds); + toConsume, + rate_.load(std::memory_order_acquire), + burstSize_.load(std::memory_order_acquire), + nowInSeconds); } /** @@ -644,7 +674,10 @@ class BasicTokenBucket { * Thread-safe (but returned value may immediately be outdated). */ double balance(double nowInSeconds = defaultClockNow()) const noexcept { - return tokenBucket_.balance(rate_, burstSize_, nowInSeconds); + return tokenBucket_.balance( + rate_.load(std::memory_order_acquire), + burstSize_.load(std::memory_order_acquire), + nowInSeconds); } /** @@ -652,19 +685,21 @@ class BasicTokenBucket { * * Thread-safe (but returned value may immediately be outdated). */ - double rate() const noexcept { return rate_; } + double rate() const noexcept { return rate_.load(std::memory_order_acquire); } /** * Returns the maximum burst size. * * Thread-safe (but returned value may immediately be outdated). */ - double burst() const noexcept { return burstSize_; } + double burst() const noexcept { + return burstSize_.load(std::memory_order_acquire); + } private: Impl tokenBucket_; - double rate_; - double burstSize_; + Atom rate_; + Atom burstSize_; }; using TokenBucket = BasicTokenBucket<>;