Skip to content

Commit

Permalink
fix: unblock transactions only if requirements are correct (#2345)
Browse files Browse the repository at this point in the history
fixes #2294

bug: we unblock XREADGROUP cmd even if we don't have new values

fix: added a check with custom requirements for blocking commands
  • Loading branch information
BorysTheDev authored Jan 2, 2024
1 parent 03f69ff commit 5b90545
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 55 deletions.
54 changes: 28 additions & 26 deletions src/server/blocking_controller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "server/blocking_controller.h"

#include <absl/container/inlined_vector.h>

#include <boost/smart_ptr/intrusive_ptr.hpp>

extern "C" {
Expand All @@ -20,12 +22,13 @@ using namespace std;

struct WatchItem {
Transaction* trans;
KeyReadyChecker key_ready_checker;

Transaction* get() const {
return trans;
}

WatchItem(Transaction* t) : trans(t) {
WatchItem(Transaction* t, KeyReadyChecker krc) : trans(t), key_ready_checker(std::move(krc)) {
}
};

Expand Down Expand Up @@ -212,15 +215,7 @@ void BlockingController::NotifyPending() {
for (auto key : wt.awakened_keys) {
string_view sv_key = static_cast<string_view>(key);
DVLOG(1) << "Processing awakened key " << sv_key;

// Double verify we still got the item.
auto [it, exp_it] = owner_->db_slice().FindReadOnly(context, sv_key);
// Only LIST, ZSET and STREAM are allowed to block.
if (!IsValid(it) || !(it->second.ObjType() == OBJ_LIST || it->second.ObjType() == OBJ_ZSET ||
it->second.ObjType() == OBJ_STREAM))
continue;

NotifyWatchQueue(sv_key, &wt.queue_map);
NotifyWatchQueue(sv_key, &wt.queue_map, context);
}
wt.awakened_keys.clear();

Expand All @@ -231,7 +226,7 @@ void BlockingController::NotifyPending() {
awakened_indices_.clear();
}

void BlockingController::AddWatched(ArgSlice keys, Transaction* trans) {
void BlockingController::AddWatched(ArgSlice keys, KeyReadyChecker krc, Transaction* trans) {
auto [dbit, added] = watched_dbs_.emplace(trans->GetDbIndex(), nullptr);
if (added) {
dbit->second.reset(new DbWatchTable);
Expand All @@ -254,7 +249,7 @@ void BlockingController::AddWatched(ArgSlice keys, Transaction* trans) {
continue;
}
DVLOG(2) << "Emplace " << trans->DebugId() << " to watch " << key;
res->second->items.emplace_back(trans);
res->second->items.emplace_back(trans, krc);
}
}

Expand All @@ -275,33 +270,40 @@ void BlockingController::AwakeWatched(DbIndex db_index, string_view db_key) {
}

// Notifies the first transaction in the queue whose key_ready_checker accepts the key
// and marks the queue as active.
void BlockingController::NotifyWatchQueue(std::string_view key, WatchQueueMap* wqm) {
void BlockingController::NotifyWatchQueue(std::string_view key, WatchQueueMap* wqm,
const DbContext& context) {
auto w_it = wqm->find(key);
CHECK(w_it != wqm->end());
DVLOG(1) << "Notify WQ: [" << owner_->shard_id() << "] " << key;
WatchQueue* wq = w_it->second.get();

DCHECK_EQ(wq->state, WatchQueue::SUSPENDED);
wq->state = WatchQueue::ACTIVE;

auto& queue = wq->items;
ShardId sid = owner_->shard_id();

do {
WatchItem& wi = queue.front();
// In most cases no elements should be skipped at all
absl::InlinedVector<dfly::WatchItem, 4> skipped;
while (!queue.empty()) {
auto& wi = queue.front();
Transaction* head = wi.get();
DVLOG(2) << "WQ-Pop " << head->DebugId() << " from key " << key;

if (head->NotifySuspended(owner_->committed_txid(), sid, key)) {
// We deliberately keep the notified transaction in the queue to know which queue
// must handled when this transaction finished.
wq->notify_txid = owner_->committed_txid();
awakened_transactions_.insert(head);
break;
// Check whether the transaction may be notified; otherwise it is moved to the end of the queue
if (wi.key_ready_checker(owner_, context, head, key)) {
DVLOG(2) << "WQ-Pop " << head->DebugId() << " from key " << key;
if (head->NotifySuspended(owner_->committed_txid(), sid, key)) {
wq->state = WatchQueue::ACTIVE;
// We deliberately keep the notified transaction in the queue to know which queue
// must be handled when this transaction finishes.
wq->notify_txid = owner_->committed_txid();
awakened_transactions_.insert(head);
break;
}
} else {
skipped.push_back(std::move(wi));
}

queue.pop_front();
} while (!queue.empty());
}
std::move(skipped.begin(), skipped.end(), std::back_inserter(queue));

if (wq->items.empty()) {
wqm->erase(w_it);
Expand Down
4 changes: 2 additions & 2 deletions src/server/blocking_controller.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class BlockingController {
// TODO: consider moving all watched functions to
// EngineShard with separate per db map.
//! AddWatched adds a transaction to the blocking queue.
void AddWatched(ArgSlice watch_keys, Transaction* me);
void AddWatched(ArgSlice watch_keys, KeyReadyChecker krc, Transaction* me);

// Called from operations that create keys like lpush, rename etc.
void AwakeWatched(DbIndex db_index, std::string_view db_key);
Expand All @@ -54,7 +54,7 @@ class BlockingController {

using WatchQueueMap = absl::flat_hash_map<std::string, std::unique_ptr<WatchQueue>>;

void NotifyWatchQueue(std::string_view key, WatchQueueMap* wqm);
void NotifyWatchQueue(std::string_view key, WatchQueueMap* wqm, const DbContext& context);

// void NotifyConvergence(Transaction* tx);

Expand Down
5 changes: 3 additions & 2 deletions src/server/blocking_controller_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ TEST_F(BlockingControllerTest, Basic) {
EngineShard* shard = EngineShard::tlocal();
BlockingController bc(shard);
auto keys = trans_->GetShardArgs(shard->shard_id());
bc.AddWatched(keys, trans_.get());
bc.AddWatched(
keys, [](auto...) { return true; }, trans_.get());
EXPECT_EQ(1, bc.NumWatched(0));

bc.FinalizeWatched(keys, trans_.get());
Expand All @@ -89,7 +90,7 @@ TEST_F(BlockingControllerTest, Timeout) {
trans_->Schedule();
auto cb = [&](Transaction* t, EngineShard* shard) { return trans_->GetShardArgs(0); };

facade::OpStatus status = trans_->WaitOnWatch(tp, cb);
facade::OpStatus status = trans_->WaitOnWatch(tp, cb, [](auto...) { return true; });

EXPECT_EQ(status, facade::OpStatus::TIMED_OUT);
unsigned num_watched = shard_set->Await(
Expand Down
4 changes: 4 additions & 0 deletions src/server/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ inline uint32_t MemberTimeSeconds(uint64_t now_ms) {
return (now_ms / 1000) - kMemberExpiryBase;
}

// Checks whether the touched key is valid for a blocking transaction watching it.
// Arguments, in order: the shard that owns the key, the db context of the lookup,
// the watching transaction, and the key that was touched.
// Returns true if the watching transaction may be notified for this key; false means
// the transaction should stay blocked.
using KeyReadyChecker =
std::function<bool(EngineShard*, const DbContext& context, Transaction* tx, std::string_view)>;

struct MemoryBytesFlag {
uint64_t value = 0;
};
Expand Down
6 changes: 5 additions & 1 deletion src/server/container_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,11 @@ OpResult<string> RunCbOnFirstNonEmptyBlocking(Transaction* trans, int req_obj_ty
auto wcb = [](Transaction* t, EngineShard* shard) { return t->GetShardArgs(shard->shard_id()); };

*block_flag = true;
auto status = trans->WaitOnWatch(limit_tp, std::move(wcb));
const auto key_checker = [req_obj_type](EngineShard* owner, const DbContext& context,
Transaction*, std::string_view key) -> bool {
return owner->db_slice().FindReadOnly(context, key, req_obj_type).ok();
};
auto status = trans->WaitOnWatch(limit_tp, std::move(wcb), key_checker);
*block_flag = false;

if (status != OpStatus::OK)
Expand Down
13 changes: 11 additions & 2 deletions src/server/list_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -881,8 +881,12 @@ OpResult<string> BPopPusher::RunSingle(Transaction* t, time_point tp) {

auto wcb = [&](Transaction* t, EngineShard* shard) { return ArgSlice{&this->pop_key_, 1}; };

const auto key_checker = [](EngineShard* owner, const DbContext& context, Transaction*,
std::string_view key) -> bool {
return owner->db_slice().FindReadOnly(context, key, OBJ_LIST).ok();
};
// Block
if (auto status = t->WaitOnWatch(tp, std::move(wcb)); status != OpStatus::OK)
if (auto status = t->WaitOnWatch(tp, std::move(wcb), key_checker); status != OpStatus::OK)
return status;

t->Execute(cb_move, true);
Expand All @@ -906,7 +910,12 @@ OpResult<string> BPopPusher::RunPair(Transaction* t, time_point tp) {
// This allows us to run Transaction::Execute on watched transactions in both shards.
auto wcb = [&](Transaction* t, EngineShard* shard) { return ArgSlice{&this->pop_key_, 1}; };

if (auto status = t->WaitOnWatch(tp, std::move(wcb)); status != OpStatus::OK)
const auto key_checker = [](EngineShard* owner, const DbContext& context, Transaction*,
std::string_view key) -> bool {
return owner->db_slice().FindReadOnly(context, key, OBJ_LIST).ok();
};

if (auto status = t->WaitOnWatch(tp, std::move(wcb), key_checker); status != OpStatus::OK)
return status;

return MoveTwoShards(t, pop_key_, push_key_, popdir_, pushdir_, true);
Expand Down
23 changes: 22 additions & 1 deletion src/server/stream_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2817,7 +2817,28 @@ void XReadBlock(ReadOpts opts, ConnectionContext* cntx) {
auto tp = (opts.timeout) ? chrono::steady_clock::now() + chrono::milliseconds(opts.timeout)
: Transaction::time_point::max();

if (auto status = cntx->transaction->WaitOnWatch(tp, std::move(wcb)); status != OpStatus::OK)
const auto key_checker = [&opts](EngineShard* owner, const DbContext& context, Transaction* tx,
std::string_view key) -> bool {
auto res_it = owner->db_slice().FindReadOnly(context, key, OBJ_STREAM);
if (!res_it.ok())
return false;

auto sitem = opts.stream_ids.at(key);
if (sitem.id.val.ms != UINT64_MAX && sitem.id.val.seq != UINT64_MAX)
return true;

const CompactObj& cobj = (*res_it)->second;
stream* s = GetReadOnlyStream(cobj);
streamID last_id = s->last_id;
if (s->length) {
streamLastValidID(s, &last_id);
}

return streamCompareID(&last_id, &sitem.group->last_id) > 0;
};

if (auto status = cntx->transaction->WaitOnWatch(tp, std::move(wcb), key_checker);
status != OpStatus::OK)
return rb->SendNullArray();

// Resolve the entry in the woken key. Note this must not use OpRead since
Expand Down
22 changes: 7 additions & 15 deletions src/server/stream_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,26 +342,18 @@ TEST_F(StreamFamilyTest, XReadGroupBlock) {
ThisFiber::SleepFor(50us);
pp_->at(1)->Await([&] { return Run("xadd", {"xadd", "bar", "1-*", "k5", "v5"}); });
// The second one should be unblocked
ThisFiber::SleepFor(50us);

fb0.Join();
fb1.Join();
// temporary incorrect results
if (resp0.GetVec()[1].GetVec().size() == 0) {
EXPECT_THAT(resp0.GetVec(), ElementsAre("foo", ArrLen(0)));
EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(1)));
} else {

if (resp0.GetVec()[0].GetString() == "foo") {
EXPECT_THAT(resp0.GetVec(), ElementsAre("foo", ArrLen(1)));
EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(0)));
EXPECT_THAT(resp1.GetVec(), ElementsAre("bar", ArrLen(1)));
} else {
EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(1)));
EXPECT_THAT(resp0.GetVec(), ElementsAre("bar", ArrLen(1)));
}

// correct results
// if (resp0.GetVec()[0].GetString() == "foo") {
// EXPECT_THAT(resp0.GetVec(), ElementsAre("foo", ArrLen(1)));
// EXPECT_THAT(resp1.GetVec(), ElementsAre("bar", ArrLen(1)));
// } else {
// EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(1)));
// EXPECT_THAT(resp0.GetVec(), ElementsAre("bar", ArrLen(1)));
// }
}

TEST_F(StreamFamilyTest, XReadInvalidArgs) {
Expand Down
9 changes: 5 additions & 4 deletions src/server/transaction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1203,13 +1203,14 @@ size_t Transaction::ReverseArgIndex(ShardId shard_id, size_t arg_index) const {
return reverse_index_[sd.arg_start + arg_index];
}

OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_provider) {
OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_provider,
KeyReadyChecker krc) {
DVLOG(2) << "WaitOnWatch " << DebugId();
using namespace chrono;

auto cb = [&](Transaction* t, EngineShard* shard) {
auto keys = wkeys_provider(t, shard);
return t->WatchInShard(keys, shard);
return t->WatchInShard(keys, shard, krc);
};

Execute(std::move(cb), true);
Expand Down Expand Up @@ -1257,14 +1258,14 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p
}

// Runs only in the shard thread.
OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard) {
OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc) {
ShardId idx = SidToId(shard->shard_id());

auto& sd = shard_data_[idx];
CHECK_EQ(0, sd.local_mask & SUSPENDED_Q);

auto* bc = shard->EnsureBlockingController();
bc->AddWatched(keys, this);
bc->AddWatched(keys, std::move(krc), this);

sd.local_mask |= SUSPENDED_Q;
sd.local_mask &= ~OUT_OF_ORDER;
Expand Down
4 changes: 2 additions & 2 deletions src/server/transaction.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class Transaction {
// or b) tp is reached. If tp is time_point::max() then waits indefinitely.
// Expects that the transaction had been scheduled before, and uses Execute(.., true) to register.
// Returns OpStatus::TIMED_OUT if the timeout occurred, OpStatus::OK if it was notified by one of
// the keys. krc decides whether a touched key may wake this transaction.
facade::OpStatus WaitOnWatch(const time_point& tp, WaitKeysProvider cb);
facade::OpStatus WaitOnWatch(const time_point& tp, WaitKeysProvider cb, KeyReadyChecker krc);

// Returns true if transaction is awaked, false if it's timed-out and can be removed from the
// blocking queue.
Expand Down Expand Up @@ -456,7 +456,7 @@ class Transaction {
void ExecuteAsync();

// Adds itself to watched queue in the shard. Must run in that shard thread.
OpStatus WatchInShard(ArgSlice keys, EngineShard* shard);
OpStatus WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc);

// Expire blocking transaction, unlock keys and unregister it from the blocking controller
void ExpireBlocking(WaitKeysProvider wcb);
Expand Down
20 changes: 20 additions & 0 deletions src/server/zset_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,26 @@ TEST_F(ZSetFamilyTest, BlockingIsReleased) {
}
}

// Two blocking pops of different types (BLPOP for lists, BZPOPMIN for sorted sets) wait on
// the same key. Each one must be completed only by a value of the type it can pop.
TEST_F(ZSetFamilyTest, BlockingWithIncorrectType) {
RespExpr resp0;
RespExpr resp1;
// Both commands use timeout "0" and run in their own fibers on different proactors.
auto fb0 = pp_->at(0)->LaunchFiber(Launch::dispatch, [&] {
resp0 = Run({"BLPOP", "list1", "0"});
});
auto fb1 = pp_->at(1)->LaunchFiber(Launch::dispatch, [&] {
resp1 = Run({"BZPOPMIN", "list1", "0"});
});

// Presumably gives both commands time to start blocking before the key is written;
// NOTE(review): a fixed sleep does not guarantee that ordering - confirm.
ThisFiber::SleepFor(50us);
// First add a sorted-set member to the key, then push a list element to the same key.
// NOTE(review): LPUSH on the same key presumably relies on BZPOPMIN having already
// consumed the only sorted-set member - confirm.
pp_->at(2)->Await([&] { return Run({"ZADD", "list1", "1", "a"}); });
pp_->at(2)->Await([&] { return Run({"LPUSH", "list1", "0"}); });
fb0.Join();
fb1.Join();

// BZPOPMIN must return the sorted-set member with its score, and BLPOP the pushed list
// element; neither may be completed by the value of the other type.
EXPECT_THAT(resp1.GetVec(), ElementsAre("list1", "a", "1"));
EXPECT_THAT(resp0.GetVec(), ElementsAre("list1", "0"));
}

TEST_F(ZSetFamilyTest, BlockingTimeout) {
RespExpr resp0;

Expand Down

0 comments on commit 5b90545

Please sign in to comment.