Skip to content

Commit

Permalink
fix(search_family): Fix LOAD fields parsing in the FT.AGGREGATE and F…
Browse files Browse the repository at this point in the history
…T.SEARCH commands

fixes dragonflydb#3989

Signed-off-by: Stsiapan Bahrytsevich <[email protected]>
  • Loading branch information
BagritsevichStepan committed Oct 29, 2024
1 parent 4b49518 commit 73c5897
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 35 deletions.
14 changes: 12 additions & 2 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,12 @@ string_view Schema::LookupAlias(string_view alias) const {
return alias;
}

// Returns the short name registered for the given field identifier.
// An identifier that has no entry in `fields` is handed back unchanged.
string_view Schema::LookupIdentifier(string_view identifier) const {
  const auto it = fields.find(identifier);
  if (it == fields.end()) {
    return identifier;
  }
  return it->second.short_name;
}

IndicesOptions::IndicesOptions() {
static absl::flat_hash_set<std::string> kDefaultStopwords{
"a", "is", "the", "an", "and", "are", "as", "at", "be", "but", "by",
Expand Down Expand Up @@ -621,10 +627,14 @@ const Schema& FieldIndices::GetSchema() const {
return schema_;
}

// Collects, for the given document, the value held by every sort index and returns them as
// (name, value) pairs. The name is the entry of `aliases` for the index identifier when there is
// one, otherwise the result of schema_.LookupIdentifier(ident).
// NOTE(review): this span is a diff rendering without +/- markers. The one-argument signature and
// the emplace_back(ident, ...) line appear to be the two lines the commit deleted - confirm
// against the repository before reusing this text as code.
vector<pair<string, SortableValue>> FieldIndices::ExtractStoredValues(DocId doc) const {
vector<pair<string, SortableValue>> FieldIndices::ExtractStoredValues(
DocId doc, const absl::flat_hash_map<std::string_view, std::string_view>& aliases) const {
vector<pair<string, SortableValue>> out;
for (const auto& [ident, index] : sort_indices_) {
out.emplace_back(ident, index->Lookup(doc));
// A caller-supplied alias for this identifier wins over the schema lookup.
const auto& it = aliases.find(ident);
const auto& name = it == aliases.end() ? schema_.LookupIdentifier(ident) : it->second;

out.emplace_back(name, index->Lookup(doc));
}
return out;
}
Expand Down
9 changes: 8 additions & 1 deletion src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ struct Schema {

// Return identifier for alias if found, otherwise return passed value
std::string_view LookupAlias(std::string_view alias) const;

// Return alias for identifier if found, otherwise return passed value
std::string_view LookupIdentifier(std::string_view identifier) const;
};

struct IndicesOptions {
Expand Down Expand Up @@ -88,7 +91,11 @@ class FieldIndices {
const Schema& GetSchema() const;

// Extract values stored in sort indices
std::vector<std::pair<std::string, SortableValue>> ExtractStoredValues(DocId doc) const;
// aliases are specified in addition to the aliases in search::Schema
std::vector<std::pair<std::string, SortableValue>> ExtractStoredValues(
DocId doc,
const absl::flat_hash_map<std::string_view /*identifier*/, std::string_view /*alias*/>&
aliases) const;

absl::flat_hash_set<std::string_view> GetSortIndiciesFields() const;

Expand Down
9 changes: 7 additions & 2 deletions src/server/search/doc_accessors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ search::SortableValue ExtractSortableValueFromJson(const search::Schema& schema,
// Serializes the requested fields of this document into a SearchDocData keyed by short name.
// For every field the identifier is resolved with GetIdentifier(schema, false), the strings
// returned by GetStrings(fident) are joined with "," and the joined text is converted through
// ExtractSortableValue.
// NOTE(review): this span is a diff rendering without +/- markers. The structured-binding loop
// header appears to be the line the commit deleted - confirm against the repository.
SearchDocData BaseAccessor::Serialize(
const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const {
SearchDocData out{};
for (const auto& [fident, fname] : fields) {
for (const auto& field : fields) {
// is_json_field is passed as false here, so only a name tagged as identifier is used verbatim.
const auto& fident = field.GetIdentifier(schema, false);
const auto& fname = field.GetShortName(schema);

out[fname] = ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident), ","));
}
return out;
}
Expand Down Expand Up @@ -257,7 +260,9 @@ SearchDocData JsonAccessor::Serialize(const search::Schema& schema) const {
SearchDocData JsonAccessor::Serialize(
const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const {
SearchDocData out{};
for (const auto& [ident, name] : fields) {
for (const auto& field : fields) {
const auto& ident = field.GetIdentifier(schema, true);
const auto& name = field.GetShortName(schema);
if (auto* path = GetPath(ident); path) {
if (auto res = path->Evaluate(json_); !res.empty())
out[name] = ExtractSortableValueFromJson(schema, ident, res[0]);
Expand Down
50 changes: 34 additions & 16 deletions src/server/search/doc_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const {
return this->score >= other.score;
}

// Returns true when return_fields holds no value, or when at least one of its entries
// reports the given alias from AsShortName().
// NOTE(review): this span is a diff rendering without +/- markers. The signature taking `field`
// and the lambda comparing entry.first appear to be the lines the commit deleted - confirm
// against the repository.
bool SearchParams::ShouldReturnField(std::string_view field) const {
auto cb = [field](const auto& entry) { return entry.first == field; };
bool SearchParams::ShouldReturnField(std::string_view alias) const {
auto cb = [alias](const auto& entry) { return entry.AsShortName() == alias; };
return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb);
}

Expand Down Expand Up @@ -211,12 +211,13 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
return base_->Matches(key, obj_code);
}

SearchFieldsList ToSV(const std::optional<OwnedSearchFieldsList>& fields) {
SearchFieldsList ToSV(const search::Schema& schema,
const std::optional<OwnedSearchFieldsList>& fields) {
SearchFieldsList sv_fields;
if (fields) {
sv_fields.reserve(fields->size());
for (const auto& [fident, fname] : fields.value()) {
sv_fields.emplace_back(fident, fname);
for (const auto& field : fields.value()) {
sv_fields.emplace_back(field.GetIdentifier(schema), field.GetShortName(schema));
}
}
return sv_fields;
Expand All @@ -230,8 +231,8 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
if (!search_results.error.empty())
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};

SearchFieldsList fields_to_load =
ToSV(params.ShouldReturnAllFields() ? params.load_fields : params.return_fields);
SearchFieldsList fields_to_load = ToSV(
base_->schema, params.ShouldReturnAllFields() ? params.load_fields : params.return_fields);

vector<SerializedSearchDoc> out;
out.reserve(search_results.ids.size());
Expand Down Expand Up @@ -281,8 +282,19 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
if (!search_results.error.empty())
return {};

SearchFieldsList fields_to_load =
GetFieldsToLoad(params.load_fields, indices_->GetSortIndiciesFields());
auto sort_indicies_fields = indices_->GetSortIndiciesFields();
SearchFieldsList fields_to_load = GetFieldsToLoad(params.load_fields, sort_indicies_fields);

// aliases for ExtractStoredValues
absl::flat_hash_map<std::string_view, std::string_view> sort_indicies_aliases;
if (params.load_fields) {
for (const auto& field : params.load_fields.value()) {
auto ident = field.GetIdentifier(base_->schema);
if (sort_indicies_fields.contains(ident)) {
sort_indicies_aliases[ident] = field.AsShortName();
}
}
}

vector<absl::flat_hash_map<string, search::SortableValue>> out;
for (DocId doc : search_results.ids) {
Expand All @@ -293,7 +305,7 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
continue;

auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto extracted = indices_->ExtractStoredValues(doc);
auto extracted = indices_->ExtractStoredValues(doc, sort_indicies_aliases);

SearchDocData loaded = accessor->Serialize(base_->schema, fields_to_load);

Expand All @@ -307,25 +319,31 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
// Builds the list of fields to load: one entry per schema field plus one per element of
// load_fields, keyed by identifier. Identifiers contained in skip_fields are left out.
// load_fields is processed after the schema, so its entry replaces the schema entry that has the
// same identifier. The result is emitted in the iteration order of the intermediate hash map.
// The returned entries are built from string views over the schema and load_fields data.
// NOTE(review): this span is a diff rendering without +/- markers. The string_view-valued map,
// the `= fname` assignments, the structured-binding loop over load_fields and the brace-init
// return appear to be the lines the commit deleted - confirm against the repository.
SearchFieldsList ShardDocIndex::GetFieldsToLoad(
const std::optional<OwnedSearchFieldsList>& load_fields,
const absl::flat_hash_set<std::string_view>& skip_fields) const {
// identifier to short name
absl::flat_hash_map<std::string_view, std::string_view> unique_fields;
absl::flat_hash_map<std::string_view, SearchField<std::string_view>> unique_fields;
unique_fields.reserve(base_->schema.field_names.size());

// Seed with every field known to the schema (field_names maps short name -> identifier).
for (const auto& [fname, fident] : base_->schema.field_names) {
if (!skip_fields.contains(fident)) {
unique_fields[fident] = fname;
unique_fields[fident] = {std::string_view{fident}, std::string_view{fname}};
}
}

// Explicitly requested fields override the schema entry stored under the same identifier.
if (load_fields) {
for (const auto& [fident, fname] : load_fields.value()) {
for (const auto& field : load_fields.value()) {
const auto& fident = field.GetIdentifier(base_->schema);
if (!skip_fields.contains(fident)) {
unique_fields[fident] = fname;
unique_fields[fident] = field;
}
}
}

return {unique_fields.begin(), unique_fields.end()};
// Flatten the map values into the result vector; the keys are no longer needed.
SearchFieldsList fields;
fields.reserve(unique_fields.size());
for (auto& [_, field] : unique_fields) {
fields.emplace_back(std::move(field));
}

return fields;
}

DocIndexInfo ShardDocIndex::GetInfo() const {
Expand Down
108 changes: 104 additions & 4 deletions src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,106 @@ struct SearchResult {
std::optional<facade::ErrorReply> error;
};

template <typename T> using SearchField = std::pair<T /*identifier*/, T /*short name*/>;
// Tells how the single name stored in a SearchField has to be interpreted.
enum class NameType : uint8_t {
  kIdentifier,  // the name is a field identifier
  kShortName,   // the name is a short name (alias)
  kUndefined,   // it is not known which of the two the name is
};

// A field named by a LOAD/RETURN clause. It stores either
//  - a single name tagged with a NameType (identifier, short name or undefined), or
//  - an explicit (identifier, short name) pair.
// The half that is not stored is resolved against a search::Schema on demand.
//
// T is the string type used for storage (the file instantiates it with std::string and
// std::string_view). All getters return std::string_view: the views point either into this
// object or into whatever the schema lookup returns, so they must not outlive both.
template <typename T> class SearchField {
 private:
  using SingleName = std::pair<T /*identifier or short name*/, NameType>;
  using IdentifierAndName = std::pair<T /*identifier*/, T /*short name*/>;

 public:
  // Creates an empty single name of undefined kind, the same state SearchField(T{}) produces.
  // A defaulted constructor would value-initialize the first variant alternative and thereby tag
  // the empty name with the enum's zero value, NameType::kIdentifier.
  SearchField() : name_(SingleName{T{}, NameType::kUndefined}) {
  }

  // Stores a single name whose kind is not known yet.
  explicit SearchField(T name) : SearchField(std::move(name), NameType::kUndefined) {
  }

  // Stores a single name of the given kind.
  SearchField(T name, NameType name_type) : name_(SingleName{std::move(name), name_type}) {
  }

  // Stores both the identifier and the short name.
  SearchField(T identifier, T short_name)
      : name_(IdentifierAndName{std::move(identifier), std::move(short_name)}) {
  }

  // Converting assignment from a field with another string type, e.g. an owning
  // SearchField<std::string> assigned to a SearchField<std::string_view>.
  // The stored alternative and, for a single name, its NameType are preserved.
  template <typename U, typename = std::enable_if_t<std::is_constructible_v<T, U>>>
  SearchField& operator=(const SearchField<U>& other) {
    if (other.IsIdentifierAndName()) {
      const auto& [identifier, short_name] = other.AsIdentifierAndName();
      name_ = IdentifierAndName{T{identifier}, T{short_name}};
    } else {
      const auto& [name, name_type] = other.AsSingle();
      name_ = SingleName{T{name}, name_type};
    }
    return *this;
  }

  // Same as GetIdentifier(schema, false).
  std::string_view GetIdentifier(const search::Schema& schema) const {
    return GetIdentifier(schema, false);
  }

  // Returns the identifier of the field:
  //  - the explicit identifier when both names are stored;
  //  - the stored name itself when it is tagged kIdentifier, or when is_json_field is set and
  //    the name looks like a JSON path (see IsJsonPath);
  //  - otherwise the result of schema.LookupAlias for the stored name.
  std::string_view GetIdentifier(const search::Schema& schema, bool is_json_field) const {
    if (IsIdentifierAndName()) {
      return AsIdentifierAndName().first;
    }

    const auto& [name, name_type] = AsSingle();
    const bool is_identifier =
        name_type == NameType::kIdentifier || (is_json_field && IsJsonPath(name));
    if (is_identifier) {
      return name;
    }
    return schema.LookupAlias(std::string_view{name});
  }

  // Returns the short name of the field:
  //  - the explicit short name when both names are stored;
  //  - the stored name itself when it is tagged kShortName;
  //  - otherwise the result of schema.LookupIdentifier for the stored name.
  std::string_view GetShortName(const search::Schema& schema) const {
    if (IsIdentifierAndName()) {
      return AsIdentifierAndName().second;
    }

    const auto& [name, name_type] = AsSingle();
    if (name_type == NameType::kShortName) {
      return name;
    }
    return schema.LookupIdentifier(std::string_view{name});
  }

  // Returns the short name without consulting a schema: the explicit short name when both names
  // are stored, otherwise the single stored name regardless of its NameType.
  std::string_view AsShortName() const {
    if (IsIdentifierAndName()) {
      return AsIdentifierAndName().second;
    }
    return AsSingle().first;
  }

  // True when both the identifier and the short name are stored.
  bool IsIdentifierAndName() const {
    return std::holds_alternative<IdentifierAndName>(name_);
  }

  // Accessors for the stored alternative. std::get throws std::bad_variant_access when the other
  // alternative is active, so check IsIdentifierAndName() first.
  const IdentifierAndName& AsIdentifierAndName() const {
    return std::get<IdentifierAndName>(name_);
  }

  const SingleName& AsSingle() const {
    return std::get<SingleName>(name_);
  }

 private:
  // A name is treated as a JSON path when it starts with "$." or "$[".
  static bool IsJsonPath(const T& name) {
    if (name.size() < 2) {
      return false;
    }
    return name.front() == '$' && (name[1] == '.' || name[1] == '[');
  }

  std::variant<SingleName, IdentifierAndName> name_;
};

using SearchFieldsList = std::vector<SearchField<std::string_view>>;
using OwnedSearchFieldsList = std::vector<SearchField<std::string>>;
Expand Down Expand Up @@ -88,7 +187,7 @@ struct SearchParams {
return return_fields && return_fields->empty();
}

bool ShouldReturnField(std::string_view field) const;
bool ShouldReturnField(std::string_view alias) const;
};

struct AggregateParams {
Expand Down Expand Up @@ -169,8 +268,9 @@ class ShardDocIndex {
io::Result<StringVec, facade::ErrorReply> GetTagVals(std::string_view field) const;

private:
// Returns the fields that are the union of the already indexed fields and load_fields, excluding
// skip_fields Load_fields should not be destroyed while the result of this function is being used
/* Returns the fields that are the union of the already indexed fields and load_fields, excluding
skip_fields.
Load_fields should not be destroyed while the result of this function is being used */
SearchFieldsList GetFieldsToLoad(const std::optional<OwnedSearchFieldsList>& load_fields,
const absl::flat_hash_set<std::string_view>& skip_fields) const;

Expand Down
41 changes: 31 additions & 10 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
#pragma GCC diagnostic pop
#endif

// Returns true when the field name begins with '@'. An empty name yields false.
bool StartsWithAtSign(std::string_view field) {
  // substr(0, 1) is the empty view for an empty input, which never equals "@".
  return field.substr(0, 1) == "@";
}

std::string_view ParseField(CmdArgParser* parser) {
std::string_view field = parser->Next();
if (!field.empty() && field.front() == '@') {
Expand All @@ -204,15 +208,28 @@ std::string_view ParseFieldWithAtSign(CmdArgParser* parser) {
}

// Parses the arguments of a LOAD clause and appends the parsed fields to *load_fields, which is
// emplaced first when it holds no value.
//
// The clause starts with a count of the strings that follow. Each field is one of
//   {ident} AS {name}  - stored as an (identifier, short name) pair, consumes three strings;
//   @{name}            - stored as a short name with the leading '@' removed;
//   {name}             - stored as a name of undefined kind.
void ParseLoadFields(CmdArgParser* parser, std::optional<OwnedSearchFieldsList>* load_fields) {
  using Field = SearchField<std::string>;

  size_t remaining = parser->Next<size_t>();
  if (!load_fields->has_value()) {
    load_fields->emplace();
  }

  // NOTE(review): the loop is bounded only by the client-supplied count - confirm that the parser
  // error state makes the caller discard the result when the arguments run out early.
  while (remaining > 0) {
    string_view str = parser->Next();
    --remaining;

    Field field;
    if (parser->Check("AS")) {  // str is identifier
      // NOTE(review): a leading '@' is not stripped from str on this path - confirm intended.
      field = {std::string{str}, std::string{parser->Next()}};
      // "AS" and the alias account for two more strings. The counter is unsigned, so saturate at
      // zero instead of wrapping around when the declared count is too small
      // (e.g. `LOAD 1 a AS b`), which would otherwise keep the loop running.
      remaining = remaining >= 2 ? remaining - 2 : 0;
    } else if (StartsWithAtSign(str)) {
      str.remove_prefix(1);  // remove leading @
      field = {std::string{str}, NameType::kShortName};
    } else {
      field = Field{std::string{str}};
    }

    load_fields->value().emplace_back(std::move(field));
  }
}

Expand Down Expand Up @@ -248,12 +265,16 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, SinkReplyBu
}

// RETURN {num} [{ident} AS {name}...]
size_t num_fields = parser.Next<size_t>();
size_t num_strings = parser.Next<size_t>();
params.return_fields.emplace();
while (params.return_fields->size() < num_fields) {
string_view ident = parser.Next();
string_view alias = parser.Check("AS") ? parser.Next() : ident;
params.return_fields->emplace_back(ident, alias);
while (num_strings--) {
std::string_view ident = parser.Next();
if (parser.Check("AS")) {
params.return_fields->emplace_back(std::string{ident}, std::string{parser.Next()});
num_strings -= 2;
} else {
params.return_fields->emplace_back(std::string{ident}, NameType::kIdentifier);
}
}
} else if (parser.Check("NOCONTENT")) { // NOCONTENT
params.load_fields.emplace();
Expand Down

0 comments on commit 73c5897

Please sign in to comment.