Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(search_family): Fix LOAD fields parsing in the FT.AGGREGATE and FT.SEARCH commands #4012

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,12 @@ string_view Schema::LookupAlias(string_view alias) const {
return alias;
}

// Translate a field identifier into the short name (alias) stored for it in `fields`.
// An identifier that has no entry in `fields` is handed back unchanged.
string_view Schema::LookupIdentifier(string_view identifier) const {
  const auto entry = fields.find(identifier);
  if (entry == fields.end())
    return identifier;
  return entry->second.short_name;
}

IndicesOptions::IndicesOptions() {
static absl::flat_hash_set<std::string> kDefaultStopwords{
"a", "is", "the", "an", "and", "are", "as", "at", "be", "but", "by",
Expand Down Expand Up @@ -646,21 +652,10 @@ const Schema& FieldIndices::GetSchema() const {
return schema_;
}

// Collect, for document `doc`, the value held by every sort index.
// Returns one (identifier, value) pair per entry of sort_indices_, in the iteration
// order of that map.
vector<pair<string, SortableValue>> FieldIndices::ExtractStoredValues(DocId doc) const {
  vector<pair<string, SortableValue>> out;
  // The final size is known up front, so allocate once instead of growing per element.
  out.reserve(sort_indices_.size());
  for (const auto& [ident, index] : sort_indices_) {
    out.emplace_back(ident, index->Lookup(doc));
  }
  return out;
}

absl::flat_hash_set<std::string_view> FieldIndices::GetSortIndiciesFields() const {
absl::flat_hash_set<std::string_view> fields_idents;
fields_idents.reserve(sort_indices_.size());
for (const auto& [ident, _] : sort_indices_) {
fields_idents.insert(ident);
}
return fields_idents;
SortableValue FieldIndices::GetSortIndexValue(DocId doc, std::string_view field_identifier) const {
auto it = sort_indices_.find(field_identifier);
BagritsevichStepan marked this conversation as resolved.
Show resolved Hide resolved
DCHECK(it != sort_indices_.end());
return it->second->Lookup(doc);
}

SearchAlgorithm::SearchAlgorithm() = default;
Expand Down
12 changes: 6 additions & 6 deletions src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ struct Schema {

// Return identifier for alias if found, otherwise return passed value
std::string_view LookupAlias(std::string_view alias) const;

// Return alias for identifier if found, otherwise return passed value
std::string_view LookupIdentifier(std::string_view identifier) const;
};

struct IndicesOptions {
Expand Down Expand Up @@ -88,10 +91,7 @@ class FieldIndices {
const std::vector<DocId>& GetAllDocs() const;
const Schema& GetSchema() const;

// Extract values stored in sort indices
std::vector<std::pair<std::string, SortableValue>> ExtractStoredValues(DocId doc) const;

absl::flat_hash_set<std::string_view> GetSortIndiciesFields() const;
// Return the value stored for `doc` in the sort index of the field `field_identifier`.
// The identifier must belong to a field that has a sort index (checked with DCHECK in the
// definition).
SortableValue GetSortIndexValue(DocId doc, std::string_view field_identifier) const;

private:
void CreateIndices(PMR_NS::memory_resource* mr);
Expand All @@ -100,8 +100,8 @@ class FieldIndices {
const Schema& schema_;
const IndicesOptions& options_;
std::vector<DocId> all_ids_;
absl::flat_hash_map<std::string, std::unique_ptr<BaseIndex>> indices_;
absl::flat_hash_map<std::string, std::unique_ptr<BaseSortIndex>> sort_indices_;
absl::flat_hash_map<std::string_view, std::unique_ptr<BaseIndex>> indices_;
absl::flat_hash_map<std::string_view, std::unique_ptr<BaseSortIndex>> sort_indices_;
};

struct AlgorithmProfile {
Expand Down
4 changes: 4 additions & 0 deletions src/core/string_or_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class StringOrView {
val_ = std::string{std::get<std::string_view>(val_)};
}

// True when the currently held alternative of val_ (string or string_view) reports empty().
bool empty() const {
  const auto is_empty = [](const auto& str) { return str.empty(); };
  return std::visit(is_empty, val_);
}

private:
std::variant<std::string_view, std::string> val_;
};
Expand Down
19 changes: 12 additions & 7 deletions src/server/search/doc_accessors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,13 @@ FieldValue ExtractSortableValueFromJson(const search::Schema& schema, string_vie

} // namespace

SearchDocData BaseAccessor::Serialize(
const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const {
SearchDocData BaseAccessor::Serialize(const search::Schema& schema,
absl::Span<const SearchField> fields) const {
SearchDocData out{};
for (const auto& [fident, fname] : fields) {
for (const auto& field : fields) {
const auto& fident = field.GetIdentifier(schema, false);
const auto& fname = field.GetShortName(schema);

auto field_value =
ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident).value(), ","));
if (field_value) {
Expand Down Expand Up @@ -348,14 +351,16 @@ JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) c
// Serialize every field declared in the schema.
// Builds one SearchField per (short name, identifier) entry of schema.field_names, using
// StringOrView::FromView over the schema's own strings, and delegates to the overload
// that takes an explicit field list.
SearchDocData JsonAccessor::Serialize(const search::Schema& schema) const {
  SearchFieldsList fields{};
  fields.reserve(schema.field_names.size());
  // Only the three-argument SearchField form is kept; the two-argument
  // (identifier, name) form does not match the SearchField interface used by the
  // field-list overload (GetIdentifier / GetShortName).
  for (const auto& [fname, fident] : schema.field_names)
    fields.emplace_back(StringOrView::FromView(fident), false, StringOrView::FromView(fname));
  return Serialize(schema, fields);
}

SearchDocData JsonAccessor::Serialize(
const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const {
SearchDocData JsonAccessor::Serialize(const search::Schema& schema,
absl::Span<const SearchField> fields) const {
SearchDocData out{};
for (const auto& [ident, name] : fields) {
for (const auto& field : fields) {
const auto& ident = field.GetIdentifier(schema, true);
const auto& name = field.GetShortName(schema);
if (auto* path = GetPath(ident); path) {
if (auto res = path->Evaluate(json_); !res.empty()) {
auto field_value = ExtractSortableValueFromJson(schema, ident, res[0]);
Expand Down
4 changes: 2 additions & 2 deletions src/server/search/doc_accessors.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct BaseAccessor : public search::DocumentAccessor {

// Serialize selected fields
virtual SearchDocData Serialize(const search::Schema& schema,
absl::Span<const SearchField<std::string_view>> fields) const;
absl::Span<const SearchField> fields) const;

/*
Serialize the whole type, the default implementation is to serialize all fields.
Expand Down Expand Up @@ -84,7 +84,7 @@ struct JsonAccessor : public BaseAccessor {

// The JsonAccessor works with structured types and not plain strings, so an overload is needed
SearchDocData Serialize(const search::Schema& schema,
absl::Span<const SearchField<std::string_view>> fields) const override;
absl::Span<const SearchField> fields) const override;
SearchDocData Serialize(const search::Schema& schema) const override;
SearchDocData SerializeDocument(const search::Schema& schema) const override;

Expand Down
103 changes: 68 additions & 35 deletions src/server/search/doc_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const {
return this->score >= other.score;
}

bool SearchParams::ShouldReturnField(std::string_view field) const {
auto cb = [field](const auto& entry) { return entry.first == field; };
bool SearchParams::ShouldReturnField(std::string_view alias) const {
auto cb = [alias](const auto& entry) { return entry.GetShortName() == alias; };
return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb);
}

Expand Down Expand Up @@ -224,12 +224,12 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
return base_->Matches(key, obj_code);
}

SearchFieldsList ToSV(const std::optional<OwnedSearchFieldsList>& fields) {
SearchFieldsList ToSV(const search::Schema& schema, const std::optional<SearchFieldsList>& fields) {
SearchFieldsList sv_fields;
if (fields) {
sv_fields.reserve(fields->size());
for (const auto& [fident, fname] : fields.value()) {
sv_fields.emplace_back(fident, fname);
for (const auto& field : fields.value()) {
sv_fields.push_back(field.View());
}
}
return sv_fields;
Expand All @@ -243,8 +243,8 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
if (!search_results.error.empty())
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};

SearchFieldsList fields_to_load =
ToSV(params.ShouldReturnAllFields() ? params.load_fields : params.return_fields);
SearchFieldsList fields_to_load = ToSV(
base_->schema, params.ShouldReturnAllFields() ? params.load_fields : params.return_fields);

vector<SerializedSearchDoc> out;
out.reserve(search_results.ids.size());
Expand Down Expand Up @@ -285,6 +285,57 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
std::move(search_results.profile)};
}

using SortIndiciesFieldsList =
    std::vector<std::pair<string_view /*identifier*/, string_view /*alias*/>>;

// Split the fields needed by an aggregation into two groups:
//   * fields whose schema entry is not flagged SORTABLE, returned as a SearchFieldsList;
//   * fields flagged SORTABLE, returned as (identifier, alias) pairs.
// Every schema field is registered first. Entries from `load_fields` are applied
// afterwards, so for an identifier that appears in both, the `load_fields` entry replaces
// the schema one. Each identifier occurs at most once in the result.
// `params` is not read by this function.
std::pair<SearchFieldsList, SortIndiciesFieldsList> PreprocessAggregateFields(
    const search::Schema& schema, const AggregateParams& params,
    const std::optional<SearchFieldsList>& load_fields) {
  const auto has_sortable_flag = [&schema](std::string_view ident) {
    const auto found = schema.fields.find(ident);
    if (found == schema.fields.end())
      return false;
    const bool sortable = found->second.flags & search::SchemaField::SORTABLE;
    return sortable;
  };

  const auto schema_fields_count = schema.field_names.size();

  // Both maps are keyed by identifier, which is what removes duplicates.
  absl::flat_hash_map<std::string_view, SearchField> loadable_by_ident;
  absl::flat_hash_map<std::string_view, std::string_view> sortable_alias_by_ident;
  loadable_by_ident.reserve(schema_fields_count);
  sortable_alias_by_ident.reserve(schema_fields_count);

  // Pass 1: everything declared in the schema.
  for (const auto& [alias, ident] : schema.field_names) {
    if (has_sortable_flag(ident)) {
      sortable_alias_by_ident[ident] = alias;
      continue;
    }
    loadable_by_ident[ident] = {StringOrView::FromView(ident), true,
                                StringOrView::FromView(alias)};
  }

  // Pass 2: explicitly requested fields overwrite the schema entries.
  if (load_fields.has_value()) {
    for (const auto& requested : *load_fields) {
      const auto& ident = requested.GetIdentifier(schema, false);
      if (has_sortable_flag(ident)) {
        sortable_alias_by_ident[ident] = requested.GetShortName();
      } else {
        loadable_by_ident[ident] = requested.View();
      }
    }
  }

  // Flatten both maps into the list types of the return value.
  SearchFieldsList fields;
  fields.reserve(loadable_by_ident.size());
  for (auto& entry : loadable_by_ident) {
    fields.emplace_back(std::move(entry.second));
  }

  SortIndiciesFieldsList sort_fields;
  sort_fields.reserve(sortable_alias_by_ident.size());
  for (auto& entry : sortable_alias_by_ident) {
    sort_fields.emplace_back(entry.first, entry.second);
  }

  return {std::move(fields), std::move(sort_fields)};
}

vector<SearchDocData> ShardDocIndex::SearchForAggregator(
const OpArgs& op_args, const AggregateParams& params,
search::SearchAlgorithm* search_algo) const {
Expand All @@ -294,8 +345,8 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
if (!search_results.error.empty())
return {};

SearchFieldsList fields_to_load =
GetFieldsToLoad(params.load_fields, indices_->GetSortIndiciesFields());
auto [fields_to_load, sort_indicies] =
PreprocessAggregateFields(base_->schema, params, params.load_fields);

vector<absl::flat_hash_map<string, search::SortableValue>> out;
for (DocId doc : search_results.ids) {
Expand All @@ -306,41 +357,23 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
continue;

auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto extracted = indices_->ExtractStoredValues(doc);

SearchDocData extracted_sort_indicies;
extracted_sort_indicies.reserve(sort_indicies.size());
for (const auto& [fident, fname] : sort_indicies) {
extracted_sort_indicies[fname] = indices_->GetSortIndexValue(doc, fident);
}

SearchDocData loaded = accessor->Serialize(base_->schema, fields_to_load);

out.emplace_back(make_move_iterator(extracted.begin()), make_move_iterator(extracted.end()));
out.emplace_back(make_move_iterator(extracted_sort_indicies.begin()),
make_move_iterator(extracted_sort_indicies.end()));
out.back().insert(make_move_iterator(loaded.begin()), make_move_iterator(loaded.end()));
}

return out;
}

// Build the list of fields to load from a document: all schema fields plus the optional
// `load_fields`, leaving out every identifier contained in `skip_fields`.
// Entries are keyed by identifier, so each identifier appears once; a `load_fields` entry
// replaces the short name taken from the schema for the same identifier.
SearchFieldsList ShardDocIndex::GetFieldsToLoad(
    const std::optional<OwnedSearchFieldsList>& load_fields,
    const absl::flat_hash_set<std::string_view>& skip_fields) const {
  const auto& schema_names = base_->schema.field_names;

  // identifier -> short name
  absl::flat_hash_map<std::string_view, std::string_view> name_by_ident;
  name_by_ident.reserve(schema_names.size());

  const auto remember = [&name_by_ident, &skip_fields](const auto& ident, const auto& name) {
    if (skip_fields.contains(ident))
      return;
    name_by_ident[ident] = name;
  };

  for (const auto& [name, ident] : schema_names) {
    remember(ident, name);
  }

  if (load_fields.has_value()) {
    for (const auto& [ident, name] : *load_fields) {
      remember(ident, name);
    }
  }

  return {name_by_ident.begin(), name_by_ident.end()};
}

// Bundle the base index definition (*base_) with the current size of key_index_.
DocIndexInfo ShardDocIndex::GetInfo() const {
  return DocIndexInfo{*base_, key_index_.Size()};
}
Expand Down
Loading
Loading