Skip to content

Commit

Permalink
fix(rax_tree): Fix double raxStop call in the SeekIterator
Browse files Browse the repository at this point in the history
fixes #4172

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan committed Dec 3, 2024
1 parent c857ff9 commit 103d333
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 31 deletions.
92 changes: 61 additions & 31 deletions src/core/search/rax_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <utility>

#include "base/pmr/memory_resource.h"
#include "glog/logging.h"

extern "C" {
#include "redis/rax.h"
Expand All @@ -24,63 +25,85 @@ template <typename V> struct RaxTreeMap {

// Simple seeking iterator
struct SeekIterator {
friend struct FindIterator;
SeekIterator() = default;

SeekIterator() {
raxStart(&it_, nullptr);
it_.node = nullptr;
SeekIterator(rax* tree, const char* op, std::string_view key) {
it_.emplace();

raxStart(&it_.value(), tree);
if (raxSeek(&it_.value(), op, to_key_ptr(key), key.size())) { // Successufly seeked
operator++();
} else {
InvalidateIterator();
LOG_IF(DFATAL, errno == ENOMEM) << "Out of memory during raxSeek()";
}
}

~SeekIterator() {
raxStop(&it_);
explicit SeekIterator(rax* tree) : SeekIterator(tree, "^", std::string_view{nullptr, 0}) {
}

SeekIterator(SeekIterator&&) = delete; // self-referential
SeekIterator(const SeekIterator&) = delete; // self-referential

SeekIterator(rax* tree, const char* op, std::string_view key) {
raxStart(&it_, tree);
raxSeek(&it_, op, to_key_ptr(key), key.size());
operator++();
}
/* Remove copy/move constructors to avoid double iterator invalidation */
SeekIterator(SeekIterator&&) = delete;
SeekIterator(const SeekIterator&) = delete;
SeekIterator& operator=(SeekIterator&&) = delete;
SeekIterator& operator=(const SeekIterator&) = delete;

explicit SeekIterator(rax* tree) : SeekIterator(tree, "^", std::string_view{nullptr, 0}) {
~SeekIterator() {
if (it_) {
InvalidateIterator();
}
}

bool operator==(const SeekIterator& rhs) const {
return it_.node == rhs.it_.node;
if (!IsValid() || !rhs.IsValid())
return !IsValid() && !rhs.IsValid();
return it_->node == rhs.it_->node;
}

bool operator!=(const SeekIterator& rhs) const {
return !operator==(rhs);
}

SeekIterator& operator++() {
if (!raxNext(&it_)) {
raxStop(&it_);
it_.node = nullptr;
DCHECK(IsValid());

int next_result = raxNext(&it_.value());
if (!next_result) { // OOM or we reached the end of the tree
InvalidateIterator();
LOG_IF(DFATAL, errno == ENOMEM) << "Out of memory during raxNext()";
}

return *this;
}

/* After operator++() the first value (string_view) is invalid. So make sure your copied it to
* string */
std::pair<std::string_view, V&> operator*() const {
return {std::string_view{reinterpret_cast<const char*>(it_.key), it_.key_len},
*reinterpret_cast<V*>(it_.data)};
DCHECK(IsValid() && it_->node && it_->node->iskey && it_->data);
return {std::string_view{reinterpret_cast<const char*>(it_->key), it_->key_len},
*reinterpret_cast<V*>(it_->data)};
}

bool IsValid() const {
return it_.has_value();
}

private:
raxIterator it_;
void InvalidateIterator() {
DCHECK(IsValid());
raxStop(&it_.value());
it_.reset();
}

std::optional<raxIterator> it_;
};

// Result of find() call. Inherits from pair to mimic iterator interface, not incrementable.
struct FindIterator : public std::optional<std::pair<std::string, V&>> {
bool operator==(const SeekIterator& rhs) const {
if (this->has_value() != !bool(rhs.it_.flags & RAX_ITER_EOF))
return false;
if (!this->has_value())
return true;
return (*this)->first ==
std::string_view{reinterpret_cast<const char*>(rhs.it_.key), rhs.it_.key_len};
if (!this->has_value() || !rhs.IsValid())
return !this->has_value() && !rhs.IsValid();
return (*this)->first == (*rhs).first;
}

bool operator!=(const SeekIterator& rhs) const {
Expand Down Expand Up @@ -128,9 +151,13 @@ template <typename V> struct RaxTreeMap {
std::pair<FindIterator, bool> try_emplace(std::string_view key, Args&&... args);

void erase(FindIterator it) {
DCHECK(it);

V* old = nullptr;
raxRemove(tree_, to_key_ptr(it->first.data()), it->first.size(),
reinterpret_cast<void**>(&old));
int was_removed = raxRemove(tree_, to_key_ptr(it->first.data()), it->first.size(),
reinterpret_cast<void**>(&old));
DCHECK(was_removed);

std::allocator_traits<decltype(alloc_)>::destroy(alloc_, old);
alloc_.deallocate(old, 1);
}
Expand Down Expand Up @@ -159,7 +186,10 @@ std::pair<typename RaxTreeMap<V>::FindIterator, bool> RaxTreeMap<V>::try_emplace
std::allocator_traits<decltype(alloc_)>::construct(alloc_, ptr, std::forward<Args>(args)...);

V* old = nullptr;
raxInsert(tree_, to_key_ptr(key), key.size(), ptr, reinterpret_cast<void**>(&old));
int was_inserted =
raxInsert(tree_, to_key_ptr(key), key.size(), ptr, reinterpret_cast<void**>(&old));
DCHECK(was_inserted);

assert(old == nullptr);

auto it = std::make_optional(std::pair<std::string, V&>(std::string(key), *ptr));
Expand Down
24 changes: 24 additions & 0 deletions src/core/search/rax_tree_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,28 @@ TEST_F(RaxTreeTest, Find) {
EXPECT_TRUE(map.find(string_view{}) == map.end());
}

/* Run with mimalloc to make sure there is no double free */
TEST_F(RaxTreeTest, Iterate) {
const char* kKeys[] = {
"aaaaaaaaaaaaaaaaaaaa",
"bbbbbbbbbbbbbbbbbbbbbb"
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
"dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd"
"eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
};

RaxTreeMap<int> map(pmr::get_default_resource());
for (const char* key : kKeys) {
map.try_emplace(key, 2);
}

for (auto it = map.begin(); it != map.end(); ++it) {
EXPECT_EQ((*it).second, 2);
}

for (auto it = map.begin(); it != map.end(); ++it) {
EXPECT_EQ((*it).second, 2);
}
}

} // namespace dfly::search

0 comments on commit 103d333

Please sign in to comment.