From 73c5897e2562f640b216ec58bebc1617a1a9a86c Mon Sep 17 00:00:00 2001 From: Stsiapan Bahrytsevich Date: Tue, 29 Oct 2024 10:20:38 +0100 Subject: [PATCH] fix(search_family): Fix LOAD fields parsing in the FT.AGGREGATE and FT.SEARCH commands fixes dragonflydb#3989 Signed-off-by: Stsiapan Bahrytsevich --- src/core/search/search.cc | 14 +++- src/core/search/search.h | 9 ++- src/server/search/doc_accessors.cc | 9 ++- src/server/search/doc_index.cc | 50 ++++++++----- src/server/search/doc_index.h | 108 +++++++++++++++++++++++++++-- src/server/search/search_family.cc | 41 ++++++++--- 6 files changed, 196 insertions(+), 35 deletions(-) diff --git a/src/core/search/search.cc b/src/core/search/search.cc index cd0cc5a8a232..e5c8d8e73fc1 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -501,6 +501,12 @@ string_view Schema::LookupAlias(string_view alias) const { return alias; } +string_view Schema::LookupIdentifier(string_view identifier) const { + if (auto it = fields.find(identifier); it != fields.end()) + return it->second.short_name; + return identifier; +} + IndicesOptions::IndicesOptions() { static absl::flat_hash_set kDefaultStopwords{ "a", "is", "the", "an", "and", "are", "as", "at", "be", "but", "by", @@ -621,10 +627,14 @@ const Schema& FieldIndices::GetSchema() const { return schema_; } -vector> FieldIndices::ExtractStoredValues(DocId doc) const { +vector> FieldIndices::ExtractStoredValues( + DocId doc, const absl::flat_hash_map& aliases) const { vector> out; for (const auto& [ident, index] : sort_indices_) { - out.emplace_back(ident, index->Lookup(doc)); + const auto& it = aliases.find(ident); + const auto& name = it == aliases.end() ? schema_.LookupIdentifier(ident) : it->second; + + out.emplace_back(name, index->Lookup(doc)); } return out; } diff --git a/src/core/search/search.h b/src/core/search/search.h index c37a67fa7d6e..1cdee57b4dc8 100644 --- a/src/core/search/search.h +++ b/src/core/search/search.h @@ -60,6 +60,9 @@ struct Schema { // Return identifier for alias if found, otherwise return passed value std::string_view LookupAlias(std::string_view alias) const; + + // Return alias for identifier if found, otherwise return passed value + std::string_view LookupIdentifier(std::string_view identifier) const; }; struct IndicesOptions { @@ -88,7 +91,11 @@ class FieldIndices { const Schema& GetSchema() const; // Extract values stored in sort indices - std::vector> ExtractStoredValues(DocId doc) const; + // aliases are specified in addition to the aliases in search::Schema + std::vector> ExtractStoredValues( + DocId doc, + const absl::flat_hash_map& + aliases) const; absl::flat_hash_set GetSortIndiciesFields() const; diff --git a/src/server/search/doc_accessors.cc b/src/server/search/doc_accessors.cc index b256647fbf97..556eb599af82 100644 --- a/src/server/search/doc_accessors.cc +++ b/src/server/search/doc_accessors.cc @@ -82,7 +82,10 @@ search::SortableValue ExtractSortableValueFromJson(const search::Schema& schema, SearchDocData BaseAccessor::Serialize( const search::Schema& schema, absl::Span> fields) const { SearchDocData out{}; - for (const auto& [fident, fname] : fields) { + for (const auto& field : fields) { + const auto& fident = field.GetIdentifier(schema, false); + const auto& fname = field.GetShortName(schema); + out[fname] = ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident), ",")); } return out; @@ -257,7 +260,9 @@ SearchDocData JsonAccessor::Serialize(const search::Schema& schema) const { SearchDocData JsonAccessor::Serialize( const search::Schema& schema, absl::Span> fields) const { SearchDocData out{}; - for (const auto& [ident, name] : fields) { + for (const auto& field : fields) { + const auto& ident = field.GetIdentifier(schema, true); + const auto& name = field.GetShortName(schema); if (auto* path = GetPath(ident); path) { if (auto res = path->Evaluate(json_); !res.empty()) out[name] = ExtractSortableValueFromJson(schema, ident, res[0]); diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index 5835971eb76f..27d1d6863107 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -60,8 +60,8 @@ bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const { return this->score >= other.score; } -bool SearchParams::ShouldReturnField(std::string_view field) const { - auto cb = [field](const auto& entry) { return entry.first == field; }; +bool SearchParams::ShouldReturnField(std::string_view alias) const { + auto cb = [alias](const auto& entry) { return entry.AsShortName() == alias; }; return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb); } @@ -211,12 +211,13 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const { return base_->Matches(key, obj_code); } -SearchFieldsList ToSV(const std::optional& fields) { +SearchFieldsList ToSV(const search::Schema& schema, + const std::optional& fields) { SearchFieldsList sv_fields; if (fields) { sv_fields.reserve(fields->size()); - for (const auto& [fident, fname] : fields.value()) { - sv_fields.emplace_back(fident, fname); + for (const auto& field : fields.value()) { + sv_fields.emplace_back(field.GetIdentifier(schema), field.GetShortName(schema)); } } return sv_fields; @@ -230,8 +231,8 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa if (!search_results.error.empty()) return SearchResult{facade::ErrorReply{std::move(search_results.error)}}; - SearchFieldsList fields_to_load = - ToSV(params.ShouldReturnAllFields() ? params.load_fields : params.return_fields); + SearchFieldsList fields_to_load = ToSV( + base_->schema, params.ShouldReturnAllFields() ? params.load_fields : params.return_fields); vector out; out.reserve(search_results.ids.size()); @@ -281,8 +282,19 @@ vector ShardDocIndex::SearchForAggregator( if (!search_results.error.empty()) return {}; - SearchFieldsList fields_to_load = - GetFieldsToLoad(params.load_fields, indices_->GetSortIndiciesFields()); + auto sort_indicies_fields = indices_->GetSortIndiciesFields(); + SearchFieldsList fields_to_load = GetFieldsToLoad(params.load_fields, sort_indicies_fields); + + // aliases for ExtractStoredValues + absl::flat_hash_map sort_indicies_aliases; + if (params.load_fields) { + for (const auto& field : params.load_fields.value()) { + auto ident = field.GetIdentifier(base_->schema); + if (sort_indicies_fields.contains(ident)) { + sort_indicies_aliases[ident] = field.AsShortName(); + } + } + } vector> out; for (DocId doc : search_results.ids) { @@ -293,7 +305,7 @@ vector ShardDocIndex::SearchForAggregator( continue; auto accessor = GetAccessor(op_args.db_cntx, (*it)->second); - auto extracted = indices_->ExtractStoredValues(doc); + auto extracted = indices_->ExtractStoredValues(doc, sort_indicies_aliases); SearchDocData loaded = accessor->Serialize(base_->schema, fields_to_load); @@ -307,25 +319,31 @@ vector ShardDocIndex::SearchForAggregator( SearchFieldsList ShardDocIndex::GetFieldsToLoad( const std::optional& load_fields, const absl::flat_hash_set& skip_fields) const { - // identifier to short name - absl::flat_hash_map unique_fields; + absl::flat_hash_map> unique_fields; unique_fields.reserve(base_->schema.field_names.size()); for (const auto& [fname, fident] : base_->schema.field_names) { if (!skip_fields.contains(fident)) { - unique_fields[fident] = fname; + unique_fields[fident] = {std::string_view{fident}, std::string_view{fname}}; } } if (load_fields) { - for (const auto& [fident, fname] : load_fields.value()) { + for (const auto& field : load_fields.value()) { + const auto& fident = field.GetIdentifier(base_->schema); if (!skip_fields.contains(fident)) { - unique_fields[fident] = fname; + unique_fields[fident] = field; } } } - return {unique_fields.begin(), unique_fields.end()}; + SearchFieldsList fields; + fields.reserve(unique_fields.size()); + for (auto& [_, field] : unique_fields) { + fields.emplace_back(std::move(field)); + } + + return fields; } DocIndexInfo ShardDocIndex::GetInfo() const { diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index 564ca6193540..f6c3db38d02c 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -52,7 +52,106 @@ struct SearchResult { std::optional error; }; -template using SearchField = std::pair; +enum class NameType : uint8_t { kIdentifier, kShortName, kUndefined }; + +template class SearchField { + private: + using SingleName = std::pair; + using IdentifierAndName = std::pair; + + static bool IsJsonPath(const T& name) { + if (name.size() < 2) { + return false; + } + return name.front() == '$' && (name[1] == '.' || name[1] == '['); + } + + public: + SearchField() = default; + + explicit SearchField(T name) : name_(std::make_pair(std::move(name), NameType::kUndefined)) { + } + + SearchField(T name, NameType name_type) : name_(std::make_pair(std::move(name), name_type)) { + } + + SearchField(T identifier, T short_name) + : name_(std::make_pair(std::move(identifier), std::move(short_name))) { + } + + template >> + SearchField& operator=(const SearchField& other) { + if (other.IsIdentifierAndName()) { + const auto& ident_and_name = other.AsIdentifierAndName(); + name_ = std::make_pair(T{ident_and_name.first}, T{ident_and_name.second}); + } else { + const auto& single_name = other.AsSingle(); + name_ = std::make_pair(T{single_name.first}, single_name.second); + } + return *this; + } + + std::string_view GetIdentifier(const search::Schema& schema) const { + return GetIdentifier(schema, [&](const SingleName& single_name) { + return single_name.second == NameType::kIdentifier; + }); + } + + std::string_view GetIdentifier(const search::Schema& schema, bool is_json_field) const { + return GetIdentifier(schema, [&](const SingleName& single_name) { + return single_name.second == NameType::kIdentifier || + (is_json_field && IsJsonPath(single_name.first)); + }); + } + + std::string_view GetShortName(const search::Schema& schema) const { + if (IsIdentifierAndName()) { + return AsIdentifierAndName().second; + } + + const auto& single_name = AsSingle(); + if (single_name.second == NameType::kShortName) { + return single_name.first; + } + return schema.LookupIdentifier(std::string_view{single_name.first}); + } + + std::string_view AsShortName() const { + if (IsIdentifierAndName()) { + return AsIdentifierAndName().second; + } + return AsSingle().first; + } + + bool IsIdentifierAndName() const { + return std::holds_alternative(name_); + } + + const IdentifierAndName& AsIdentifierAndName() const { + return std::get(name_); + } + + const SingleName& AsSingle() const { + return std::get(name_); + } + + private: + template + std::string_view GetIdentifier(const search::Schema& schema, Callback is_identifier) const { + if (IsIdentifierAndName()) { + return AsIdentifierAndName().first; + } + + const auto& single_name = AsSingle(); + if (is_identifier(single_name)) { + return single_name.first; + } + return schema.LookupAlias(std::string_view{single_name.first}); + } + + private: + std::variant name_; +}; using SearchFieldsList = std::vector>; using OwnedSearchFieldsList = std::vector>; @@ -88,7 +187,7 @@ struct SearchParams { return return_fields && return_fields->empty(); } - bool ShouldReturnField(std::string_view field) const; + bool ShouldReturnField(std::string_view alias) const; }; struct AggregateParams { @@ -169,8 +268,9 @@ class ShardDocIndex { io::Result GetTagVals(std::string_view field) const; private: - // Returns the fields that are the union of the already indexed fields and load_fields, excluding - // skip_fields Load_fields should not be destroyed while the result of this function is being used + /* Returns the fields that are the union of the already indexed fields and load_fields, excluding + skip_fields. + Load_fields should not be destroyed while the result of this function is being used */ SearchFieldsList GetFieldsToLoad(const std::optional& load_fields, const absl::flat_hash_set& skip_fields) const; diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index 6ef550715c10..a215d87800ee 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -183,6 +183,10 @@ optional ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse #pragma GCC diagnostic pop #endif +bool StartsWithAtSign(std::string_view field) { + return !field.empty() && field.front() == '@'; +} + std::string_view ParseField(CmdArgParser* parser) { std::string_view field = parser->Next(); if (!field.empty() && field.front() == '@') { @@ -204,15 +208,28 @@ std::string_view ParseFieldWithAtSign(CmdArgParser* parser) { } void ParseLoadFields(CmdArgParser* parser, std::optional* load_fields) { - size_t num_fields = parser->Next(); + using Field = SearchField; + + size_t num_strings = parser->Next(); if (!load_fields->has_value()) { load_fields->emplace(); } - while (num_fields--) { - string_view field = ParseField(parser); - string_view alias = parser->Check("AS") ? parser->Next() : field; - load_fields->value().emplace_back(field, alias); + while (num_strings--) { + string_view str = parser->Next(); + + Field field; + if (parser->Check("AS")) { // str is identifier + field = {std::string{str}, std::string{parser->Next()}}; + num_strings -= 2; + } else if (StartsWithAtSign(str)) { + str.remove_prefix(1); // remove leading @ + field = {std::string{str}, NameType::kShortName}; + } else { + field = Field{std::string{str}}; + } + + load_fields->value().emplace_back(std::move(field)); } } @@ -248,12 +265,16 @@ optional ParseSearchParamsOrReply(CmdArgParser parser, SinkReplyBu } // RETURN {num} [{ident} AS {name}...] - size_t num_fields = parser.Next(); + size_t num_strings = parser.Next(); params.return_fields.emplace(); - while (params.return_fields->size() < num_fields) { - string_view ident = parser.Next(); - string_view alias = parser.Check("AS") ? parser.Next() : ident; - params.return_fields->emplace_back(ident, alias); + while (num_strings--) { + std::string_view ident = parser.Next(); + if (parser.Check("AS")) { + params.return_fields->emplace_back(std::string{ident}, std::string{parser.Next()}); + num_strings -= 2; + } else { + params.return_fields->emplace_back(std::string{ident}, NameType::kIdentifier); + } } } else if (parser.Check("NOCONTENT")) { // NOCONTENT params.load_fields.emplace();