Skip to content

Commit

Permalink
fix(search_family): Fix LOAD fields parsing in the FT.AGGREGATE and F…
Browse files Browse the repository at this point in the history
…T.SEARCH commands

fixes dragonflydb#3989

Signed-off-by: Stsiapan Bahrytsevich <[email protected]>
  • Loading branch information
BagritsevichStepan committed Nov 19, 2024
1 parent 794bd1c commit 7742b5a
Show file tree
Hide file tree
Showing 7 changed files with 391 additions and 34 deletions.
14 changes: 12 additions & 2 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,12 @@ string_view Schema::LookupAlias(string_view alias) const {
return alias;
}

string_view Schema::LookupIdentifier(string_view identifier) const {
if (auto it = fields.find(identifier); it != fields.end())
return it->second.short_name;
return identifier;
}

IndicesOptions::IndicesOptions() {
static absl::flat_hash_set<std::string> kDefaultStopwords{
"a", "is", "the", "an", "and", "are", "as", "at", "be", "but", "by",
Expand Down Expand Up @@ -646,10 +652,14 @@ const Schema& FieldIndices::GetSchema() const {
return schema_;
}

vector<pair<string, SortableValue>> FieldIndices::ExtractStoredValues(DocId doc) const {
vector<pair<string, SortableValue>> FieldIndices::ExtractStoredValues(
DocId doc, const absl::flat_hash_map<std::string_view, std::string_view>& aliases) const {
vector<pair<string, SortableValue>> out;
for (const auto& [ident, index] : sort_indices_) {
out.emplace_back(ident, index->Lookup(doc));
const auto& it = aliases.find(ident);
const auto& name = it == aliases.end() ? schema_.LookupIdentifier(ident) : it->second;

out.emplace_back(name, index->Lookup(doc));
}
return out;
}
Expand Down
9 changes: 8 additions & 1 deletion src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ struct Schema {

// Return identifier for alias if found, otherwise return passed value
std::string_view LookupAlias(std::string_view alias) const;

// Return alias for identifier if found, otherwise return passed value
std::string_view LookupIdentifier(std::string_view identifier) const;
};

struct IndicesOptions {
Expand Down Expand Up @@ -89,7 +92,11 @@ class FieldIndices {
const Schema& GetSchema() const;

// Extract values stored in sort indices
std::vector<std::pair<std::string, SortableValue>> ExtractStoredValues(DocId doc) const;
// aliases are specified in addition to the aliases in search::Schema
std::vector<std::pair<std::string, SortableValue>> ExtractStoredValues(
DocId doc,
const absl::flat_hash_map<std::string_view /*identifier*/, std::string_view /*alias*/>&
aliases) const;

absl::flat_hash_set<std::string_view> GetSortIndiciesFields() const;

Expand Down
11 changes: 8 additions & 3 deletions src/server/search/doc_accessors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ FieldValue ExtractSortableValueFromJson(const search::Schema& schema, string_vie
SearchDocData BaseAccessor::Serialize(
const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const {
SearchDocData out{};
for (const auto& [fident, fname] : fields) {
for (const auto& field : fields) {
const auto& fident = field.GetIdentifier(schema, false);
const auto& fname = field.GetShortName(schema);

auto field_value =
ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident).value(), ","));
if (field_value) {
Expand Down Expand Up @@ -348,14 +351,16 @@ JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) c
SearchDocData JsonAccessor::Serialize(const search::Schema& schema) const {
SearchFieldsList fields{};
for (const auto& [fname, fident] : schema.field_names)
fields.emplace_back(fident, fname);
fields.emplace_back(fident, NameType::kIdentifier, fname);
return Serialize(schema, fields);
}

SearchDocData JsonAccessor::Serialize(
const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const {
SearchDocData out{};
for (const auto& [ident, name] : fields) {
for (const auto& field : fields) {
const auto& ident = field.GetIdentifier(schema, true);
const auto& name = field.GetShortName(schema);
if (auto* path = GetPath(ident); path) {
if (auto res = path->Evaluate(json_); !res.empty()) {
auto field_value = ExtractSortableValueFromJson(schema, ident, res[0]);
Expand Down
51 changes: 35 additions & 16 deletions src/server/search/doc_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const {
return this->score >= other.score;
}

bool SearchParams::ShouldReturnField(std::string_view field) const {
auto cb = [field](const auto& entry) { return entry.first == field; };
bool SearchParams::ShouldReturnField(std::string_view alias) const {
auto cb = [alias](const auto& entry) { return entry.GetShortName() == alias; };
return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb);
}

Expand Down Expand Up @@ -224,12 +224,13 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
return base_->Matches(key, obj_code);
}

SearchFieldsList ToSV(const std::optional<OwnedSearchFieldsList>& fields) {
SearchFieldsList ToSV(const search::Schema& schema,
const std::optional<OwnedSearchFieldsList>& fields) {
SearchFieldsList sv_fields;
if (fields) {
sv_fields.reserve(fields->size());
for (const auto& [fident, fname] : fields.value()) {
sv_fields.emplace_back(fident, fname);
for (const auto& field : fields.value()) {
sv_fields.emplace_back(field);
}
}
return sv_fields;
Expand All @@ -243,8 +244,8 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
if (!search_results.error.empty())
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};

SearchFieldsList fields_to_load =
ToSV(params.ShouldReturnAllFields() ? params.load_fields : params.return_fields);
SearchFieldsList fields_to_load = ToSV(
base_->schema, params.ShouldReturnAllFields() ? params.load_fields : params.return_fields);

vector<SerializedSearchDoc> out;
out.reserve(search_results.ids.size());
Expand Down Expand Up @@ -294,8 +295,19 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
if (!search_results.error.empty())
return {};

SearchFieldsList fields_to_load =
GetFieldsToLoad(params.load_fields, indices_->GetSortIndiciesFields());
auto sort_indicies_fields = indices_->GetSortIndiciesFields();
SearchFieldsList fields_to_load = GetFieldsToLoad(params.load_fields, sort_indicies_fields);

// aliases for ExtractStoredValues
absl::flat_hash_map<std::string_view, std::string_view> sort_indicies_aliases;
if (params.load_fields) {
for (const auto& field : params.load_fields.value()) {
auto ident = field.GetIdentifier(base_->schema);
if (sort_indicies_fields.contains(ident)) {
sort_indicies_aliases[ident] = field.GetShortName();
}
}
}

vector<absl::flat_hash_map<string, search::SortableValue>> out;
for (DocId doc : search_results.ids) {
Expand All @@ -306,7 +318,7 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
continue;

auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto extracted = indices_->ExtractStoredValues(doc);
auto extracted = indices_->ExtractStoredValues(doc, sort_indicies_aliases);

SearchDocData loaded = accessor->Serialize(base_->schema, fields_to_load);

Expand All @@ -320,25 +332,32 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
SearchFieldsList ShardDocIndex::GetFieldsToLoad(
const std::optional<OwnedSearchFieldsList>& load_fields,
const absl::flat_hash_set<std::string_view>& skip_fields) const {
// identifier to short name
absl::flat_hash_map<std::string_view, std::string_view> unique_fields;
absl::flat_hash_map<std::string_view, SearchField<std::string_view>> unique_fields;
unique_fields.reserve(base_->schema.field_names.size());

for (const auto& [fname, fident] : base_->schema.field_names) {
if (!skip_fields.contains(fident)) {
unique_fields[fident] = fname;
unique_fields[fident] = {std::string_view{fident}, NameType::kIdentifier,
std::string_view{fname}};
}
}

if (load_fields) {
for (const auto& [fident, fname] : load_fields.value()) {
for (const auto& field : load_fields.value()) {
const auto& fident = field.GetIdentifier(base_->schema);
if (!skip_fields.contains(fident)) {
unique_fields[fident] = fname;
unique_fields[fident] = field;
}
}
}

return {unique_fields.begin(), unique_fields.end()};
SearchFieldsList fields;
fields.reserve(unique_fields.size());
for (auto& [_, field] : unique_fields) {
fields.emplace_back(std::move(field));
}

return fields;
}

DocIndexInfo ShardDocIndex::GetInfo() const {
Expand Down
104 changes: 100 additions & 4 deletions src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,102 @@ struct SearchResult {
std::optional<facade::ErrorReply> error;
};

template <typename T> using SearchField = std::pair<T /*identifier*/, T /*short name*/>;
enum class NameType : uint8_t { kIdentifier, kShortName, kUndefined };

template <typename T> class SearchField {
private:
using SingleName = std::pair<T /*identifier or short name*/, NameType>;

static bool IsJsonPath(const T& name) {
if (name.size() < 2) {
return false;
}
return name.front() == '$' && (name[1] == '.' || name[1] == '[');
}

public:
SearchField() = default;

SearchField(T name, NameType name_type) : name_(std::make_pair(std::move(name), name_type)) {
}

SearchField(T name, NameType name_type, T new_alias)
: name_(std::make_pair(std::move(name), name_type)), new_alias_(std::move(new_alias)) {
}

template <typename U, typename = std::enable_if_t<std::is_constructible_v<T, U>>>
explicit SearchField(const SearchField<U>& other)
: name_(std::make_pair(T{other.name_.first}, other.name_.second)) {
if (other.HasNewAlias()) {
new_alias_ = T{other.new_alias_.value()};
} else {
new_alias_.reset();
}
}

template <typename U, typename = std::enable_if_t<std::is_constructible_v<T, U>>>
SearchField& operator=(const SearchField<U>& other) {
name_ = std::make_pair(T{other.name_.first}, other.name_.second);
if (other.HasNewAlias()) {
new_alias_ = T{other.new_alias_.value()};
} else {
new_alias_.reset();
}
return *this;
}

~SearchField() = default;

std::string_view GetIdentifier(const search::Schema& schema) const {
return GetIdentifier(schema, [&](const SingleName& single_name) {
return single_name.second == NameType::kIdentifier;
});
}

std::string_view GetIdentifier(const search::Schema& schema, bool is_json_field) const {
return GetIdentifier(schema, [&](const SingleName& single_name) {
return single_name.second == NameType::kIdentifier ||
(is_json_field && IsJsonPath(single_name.first));
});
}

std::string_view GetShortName() const {
if (HasNewAlias()) {
return new_alias_.value();
}
return name_.first;
}

std::string_view GetShortName(const search::Schema& schema) const {
if (HasNewAlias()) {
return new_alias_.value();
}

if (name_.second == NameType::kShortName) {
return name_.first;
}
return schema.LookupIdentifier(std::string_view{name_.first});
}

private:
template <typename Callback>
std::string_view GetIdentifier(const search::Schema& schema, Callback is_identifier) const {
if (is_identifier(name_)) {
return name_.first;
}
return schema.LookupAlias(std::string_view{name_.first});
}

bool HasNewAlias() const {
return new_alias_.has_value();
}

template <typename U> friend class SearchField;

private:
SingleName name_;
std::optional<T> new_alias_;
};

using SearchFieldsList = std::vector<SearchField<std::string_view>>;
using OwnedSearchFieldsList = std::vector<SearchField<std::string>>;
Expand Down Expand Up @@ -88,7 +183,7 @@ struct SearchParams {
return return_fields && return_fields->empty();
}

bool ShouldReturnField(std::string_view field) const;
bool ShouldReturnField(std::string_view alias) const;
};

struct AggregateParams {
Expand Down Expand Up @@ -169,8 +264,9 @@ class ShardDocIndex {
io::Result<StringVec, facade::ErrorReply> GetTagVals(std::string_view field) const;

private:
// Returns the fields that are the union of the already indexed fields and load_fields, excluding
// skip_fields Load_fields should not be destroyed while the result of this function is being used
/* Returns the fields that are the union of the already indexed fields and load_fields, excluding
skip_fields.
Load_fields should not be destroyed while the result of this function is being used */
SearchFieldsList GetFieldsToLoad(const std::optional<OwnedSearchFieldsList>& load_fields,
const absl::flat_hash_set<std::string_view>& skip_fields) const;

Expand Down
36 changes: 28 additions & 8 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,21 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
#pragma GCC diagnostic pop
#endif

bool StartsWithAtSign(std::string_view field) {
return !field.empty() && field.front() == '@';
}

std::string_view ParseField(CmdArgParser* parser) {
std::string_view field = parser->Next();
if (!field.empty() && field.front() == '@') {
if (StartsWithAtSign(field)) {
field.remove_prefix(1); // remove leading @ if exists
}
return field;
}

std::string_view ParseFieldWithAtSign(CmdArgParser* parser) {
std::string_view field = parser->Next();
if (!field.empty() && field.front() == '@') {
if (StartsWithAtSign(field)) {
field.remove_prefix(1); // remove leading @
} else {
// Temporary warning until we can throw an error
Expand All @@ -204,15 +208,25 @@ std::string_view ParseFieldWithAtSign(CmdArgParser* parser) {
}

void ParseLoadFields(CmdArgParser* parser, std::optional<OwnedSearchFieldsList>* load_fields) {
// TODO: Change to num_strings. In Redis strings number is expected. For example: LOAD 3 $.a AS a
size_t num_fields = parser->Next<size_t>();
if (!load_fields->has_value()) {
load_fields->emplace();
}

while (num_fields--) {
string_view field = ParseField(parser);
string_view alias = parser->Check("AS") ? parser->Next() : field;
load_fields->value().emplace_back(field, alias);
string_view str = parser->Next();

if (StartsWithAtSign(str)) {
str.remove_prefix(1); // remove leading @
}

if (parser->Check("AS")) {
load_fields->value().emplace_back(std::string{str}, NameType::kShortName,
std::string{parser->Next()});
} else {
load_fields->value().emplace_back(std::string{str}, NameType::kShortName);
}
}
}

Expand Down Expand Up @@ -248,12 +262,18 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB
}

// RETURN {num} [{ident} AS {name}...]
/* TODO: Change to num_strings. In Redis strings number is expected. For example: RETURN 3 $.a AS a */
size_t num_fields = parser->Next<size_t>();
params.return_fields.emplace();
while (params.return_fields->size() < num_fields) {
string_view ident = parser->Next();
string_view alias = parser->Check("AS") ? parser->Next() : ident;
params.return_fields->emplace_back(ident, alias);
std::string_view str = parser->Next();

if (parser->Check("AS")) {
params.return_fields->emplace_back(std::string{str}, NameType::kShortName,
std::string{parser->Next()});
} else {
params.return_fields->emplace_back(std::string{str}, NameType::kShortName);
}
}
} else if (parser->Check("NOCONTENT")) { // NOCONTENT
params.load_fields.emplace();
Expand Down
Loading

0 comments on commit 7742b5a

Please sign in to comment.