Skip to content

Commit

Permalink
fix(search): Use double numeric values (dragonflydb#2015)
Browse files Browse the repository at this point in the history
* fix(search): Use double numeric values

Signed-off-by: Vladislav Oleshko <[email protected]>

---------

Signed-off-by: Vladislav Oleshko <[email protected]>
  • Loading branch information
dranikpg authored and azuredream committed Oct 17, 2023
1 parent d8b6591 commit 03e2771
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 40 deletions.
7 changes: 3 additions & 4 deletions src/core/search/ast_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ using namespace std;

namespace dfly::search {

AstTermNode::AstTermNode(string term)
: term{term}, pattern{"\\b" + term + "\\b", std::regex::icase} {
AstTermNode::AstTermNode(string term) : term{term} {
}

AstRangeNode::AstRangeNode(int64_t lo, int64_t hi) : lo{lo}, hi{hi} {
AstRangeNode::AstRangeNode(double lo, double hi) : lo{lo}, hi{hi} {
}

AstNegateNode::AstNegateNode(AstNode&& node) : node{make_unique<AstNode>(move(node))} {
Expand Down Expand Up @@ -56,7 +55,7 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
tags.push_back(move(tag));
}

AstKnnNode::AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec,
AstKnnNode::AstKnnNode(uint32_t limit, std::string_view field, OwnedFtVector vec,
std::string_view score_alias)
: filter{nullptr},
limit{limit},
Expand Down
9 changes: 4 additions & 5 deletions src/core/search/ast_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <iostream>
#include <memory>
#include <ostream>
#include <regex>
#include <variant>
#include <vector>

Expand All @@ -28,14 +27,13 @@ struct AstTermNode {
AstTermNode(std::string term);

std::string term;
std::regex pattern;
};

// Matches numeric range
struct AstRangeNode {
AstRangeNode(int64_t lo, int64_t hi);
AstRangeNode(double lo, double hi);

int64_t lo, hi;
double lo, hi;
};

// Negates subtree
Expand Down Expand Up @@ -75,7 +73,8 @@ struct AstTagsNode {
// Applies nearest neighbor search to the final result set
struct AstKnnNode {
AstKnnNode() = default;
AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec, std::string_view score_alias);
AstKnnNode(uint32_t limit, std::string_view field, OwnedFtVector vec,
std::string_view score_alias);
AstKnnNode(AstNode&& sub, AstKnnNode&& self);

friend std::ostream& operator<<(std::ostream& stream, const AstKnnNode& matrix) {
Expand Down
2 changes: 1 addition & 1 deletion src/core/search/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct WrappedStrPtr {
std::unique_ptr<char[]> ptr;
};

using ResultScore = std::variant<std::monostate, float, int64_t, WrappedStrPtr>;
using ResultScore = std::variant<std::monostate, float, double, WrappedStrPtr>;

// Interface for accessing document values with different data structures underneath.
struct DocumentAccessor {
Expand Down
8 changes: 4 additions & 4 deletions src/core/search/indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {
}

void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
int64_t num;
if (absl::SimpleAtoi(doc->GetString(field), &num))
double num;
if (absl::SimpleAtod(doc->GetString(field), &num))
entries_.emplace(num, id);
}

Expand All @@ -70,9 +70,9 @@ void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
entries_.erase({num, id});
}

vector<DocId> NumericIndex::Range(int64_t l, int64_t r) const {
vector<DocId> NumericIndex::Range(double l, double r) const {
auto it_l = entries_.lower_bound({l, 0});
auto it_r = entries_.lower_bound({r + 1, 0});
auto it_r = entries_.lower_bound({r, numeric_limits<DocId>::max()});

vector<DocId> out;
for (auto it = it_l; it != it_r; ++it)
Expand Down
4 changes: 2 additions & 2 deletions src/core/search/indices.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ struct NumericIndex : public BaseIndex {
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;

std::vector<DocId> Range(int64_t l, int64_t r) const;
std::vector<DocId> Range(double l, double r) const;

private:
using Entry = std::pair<int64_t, DocId>;
using Entry = std::pair<double, DocId>;
absl::btree_set<Entry, std::less<Entry>, PMR_NS::polymorphic_allocator<Entry>> entries_;
};

Expand Down
22 changes: 16 additions & 6 deletions src/core/search/lexer.lex
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
using dfly::search::Parser;
using namespace std;

Parser::symbol_type make_INT64 (string_view, const Parser::location_type& loc);
Parser::symbol_type make_DOUBLE(string_view, const Parser::location_type& loc);
Parser::symbol_type make_UINT32(string_view, const Parser::location_type& loc);
Parser::symbol_type make_StringLit(string_view src, const Parser::location_type& loc);
%}

Expand Down Expand Up @@ -64,7 +65,8 @@ term_char [_]|\w
"KNN" return Parser::make_KNN (loc());
"AS" return Parser::make_AS (loc());

-?[0-9]+ return make_INT64(matched_view(), loc());
[0-9]+ return make_UINT32(matched_view(), loc());
[+-]?([0-9]*[.])?[0-9]+ return make_DOUBLE(matched_view(), loc());

{dq}{str_char}*{dq} return make_StringLit(matched_view(1, 1), loc());

Expand All @@ -76,12 +78,20 @@ term_char [_]|\w
<<EOF>> return Parser::make_YYEOF(loc());
%%

Parser::symbol_type make_INT64 (string_view str, const Parser::location_type& loc) {
int64_t val = 0;
Parser::symbol_type make_UINT32 (string_view str, const Parser::location_type& loc) {
uint32_t val = 0;
if (!absl::SimpleAtoi(str, &val))
throw Parser::syntax_error (loc, "not an integer or out of range: " + string(str));
throw Parser::syntax_error (loc, "not an unsigned integer or out of range: " + string(str));

return Parser::make_INT64(val, loc);
return Parser::make_UINT32(val, loc);
}

Parser::symbol_type make_DOUBLE (string_view str, const Parser::location_type& loc) {
double val = 0;
if (!absl::SimpleAtod(str, &val))
throw Parser::syntax_error (loc, "not a double or out of range: " + string(str));

return Parser::make_DOUBLE(val, loc);
}

Parser::symbol_type make_StringLit(string_view src, const Parser::location_type& loc) {
Expand Down
18 changes: 10 additions & 8 deletions src/core/search/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ using namespace std;
%right NOT_OP
%precedence LPAREN RPAREN

%token <int64_t> INT64 "int64"
%token <double> DOUBLE "double"
%token <uint32_t> UINT32 "uint32"
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr
%nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list

Expand All @@ -90,7 +91,7 @@ final_query:
{ driver->Set(AstKnnNode(move($1), move($3))); }

knn_query:
LBRACKET KNN INT64 FIELD TERM opt_knn_alias RBRACKET
LBRACKET KNN UINT32 FIELD TERM opt_knn_alias RBRACKET
{ $$ = AstKnnNode($3, $4, BytesToFtVector($5), $6); }

opt_knn_alias:
Expand Down Expand Up @@ -118,15 +119,16 @@ search_unary_expr:
LPAREN search_expr RPAREN { $$ = move($2); }
| NOT_OP search_unary_expr { $$ = AstNegateNode(move($2)); }
| TERM { $$ = AstTermNode(move($1)); }
| INT64 { $$ = AstTermNode(to_string($1)); }
| UINT32 { $$ = AstTermNode(to_string($1)); }
| FIELD COLON field_cond { $$ = AstFieldNode(move($1), move($3)); }

field_cond:
TERM { $$ = AstTermNode(move($1)); }
| INT64 { $$ = AstTermNode(to_string($1)); }
| UINT32 { $$ = AstTermNode(to_string($1)); }
| NOT_OP field_cond { $$ = AstNegateNode(move($2)); }
| LPAREN field_cond_expr RPAREN { $$ = move($2); }
| LBRACKET INT64 INT64 RBRACKET { $$ = AstRangeNode(move($2), move($3)); }
| LBRACKET DOUBLE DOUBLE RBRACKET { $$ = AstRangeNode(move($2), move($3)); }
| LBRACKET UINT32 UINT32 RBRACKET { $$ = AstRangeNode(move($2), move($3)); }
| LCURLBR tag_list RCURLBR { $$ = move($2); }

field_cond_expr:
Expand All @@ -146,13 +148,13 @@ field_unary_expr:
LPAREN field_cond_expr RPAREN { $$ = move($2); }
| NOT_OP field_unary_expr { $$ = AstNegateNode(move($2)); };
| TERM { $$ = AstTermNode(move($1)); }
| INT64 { $$ = AstTermNode(to_string($1)); }
| UINT32 { $$ = AstTermNode(to_string($1)); }

tag_list:
TERM { $$ = AstTagsNode(move($1)); }
| INT64 { $$ = AstTagsNode(to_string($1)); }
| UINT32 { $$ = AstTagsNode(to_string($1)); }
| tag_list OR_OP TERM { $$ = AstTagsNode(move($1), move($3)); }
| tag_list OR_OP INT64 { $$ = AstTagsNode(move($1), to_string($3)); }
| tag_list OR_OP DOUBLE { $$ = AstTagsNode(move($1), to_string($3)); }

%%

Expand Down
4 changes: 2 additions & 2 deletions src/core/search/scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ class Scanner : public Lexer {
if (str.empty())
throw std::runtime_error(absl::StrCat("Query parameter ", name, " not found"));

int64_t val = 0;
uint32_t val = 0;
if (!absl::SimpleAtoi(str, &val))
return Parser::make_TERM(std::string{str}, loc);

return Parser::make_INT64(val, loc);
return Parser::make_UINT32(val, loc);
}

private:
Expand Down
9 changes: 7 additions & 2 deletions src/core/search/search_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ TEST_F(SearchParserTest, Scanner) {

NEXT_TOK(TOK_LPAREN);
NEXT_EQ(TOK_TERM, string, "5a");
NEXT_EQ(TOK_INT64, int64_t, 6);
NEXT_EQ(TOK_UINT32, uint32_t, 6);
NEXT_TOK(TOK_RPAREN);

SetInput(R"( "hello\"world" )");
Expand All @@ -99,6 +99,11 @@ TEST_F(SearchParserTest, Scanner) {
NEXT_EQ(TOK_TERM, string, "почтальон");
NEXT_EQ(TOK_TERM, string, "Печкин");

double d;
absl::SimpleAtod("33.3", &d);
SetInput("33.3");
NEXT_EQ(TOK_DOUBLE, double, d);

SetInput("18446744073709551616");
NEXT_ERROR();
}
Expand All @@ -122,7 +127,7 @@ TEST_F(SearchParserTest, ParseParams) {

SetInput("$name $k");
NEXT_EQ(TOK_TERM, string, "alex");
NEXT_EQ(TOK_INT64, int64_t, 10);
NEXT_EQ(TOK_UINT32, uint32_t, 10);
}

} // namespace dfly::search
12 changes: 12 additions & 0 deletions src/core/search/search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,18 @@ TEST_F(SearchTest, MatchRange) {
EXPECT_TRUE(Check()) << GetError();
}

TEST_F(SearchTest, MatchDoubleRange) {
PrepareSchema({{"f1", SchemaField::NUMERIC}});
PrepareQuery("@f1: [100.03 199.97]");

ExpectAll(Map{{"f1", "130"}}, Map{{"f1", "170"}}, Map{{"f1", "100.03"}}, Map{{"f1", "199.97"}});

ExpectNone(Map{{"f1", "0"}}, Map{{"f1", "200"}}, Map{{"f1", "100.02999"}},
Map{{"f1", "199.9700001"}});

EXPECT_TRUE(Check()) << GetError();
}

TEST_F(SearchTest, MatchStar) {
PrepareQuery("*");
ExpectAll("one", "two", "three", "and", "all", "documents");
Expand Down
8 changes: 4 additions & 4 deletions src/core/search/sort_indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ template <typename T> PMR_NS::memory_resource* SimpleValueSortIndex<T>::GetMemRe
return values_.get_allocator().resource();
}

template struct SimpleValueSortIndex<int64_t>;
template struct SimpleValueSortIndex<double>;
template struct SimpleValueSortIndex<PMR_NS::string>;

int64_t NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
int64_t v;
if (!absl::SimpleAtoi(doc->GetString(field), &v))
double NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
double v;
if (!absl::SimpleAtod(doc->GetString(field), &v))
return 0;
return v;
}
Expand Down
4 changes: 2 additions & 2 deletions src/core/search/sort_indices.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
PMR_NS::vector<T> values_;
};

struct NumericSortIndex : public SimpleValueSortIndex<int64_t> {
struct NumericSortIndex : public SimpleValueSortIndex<double> {
NumericSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};

int64_t Get(DocId id, DocumentAccessor* doc, std::string_view field) override;
double Get(DocId id, DocumentAccessor* doc, std::string_view field) override;
};

// TODO: Map tags to integers for fast sort
Expand Down

0 comments on commit 03e2771

Please sign in to comment.