Skip to content

Commit

Permalink
refactor(search_family): Address comments 2
Browse files Browse the repository at this point in the history
Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan committed Nov 18, 2024
1 parent 2b2e443 commit a62cf79
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 88 deletions.
22 changes: 3 additions & 19 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,25 +652,9 @@ const Schema& FieldIndices::GetSchema() const {
return schema_;
}

vector<pair<string, SortableValue>> FieldIndices::ExtractStoredValues(
DocId doc, const absl::flat_hash_map<std::string_view, std::string_view>& aliases) const {
vector<pair<string, SortableValue>> out;
for (const auto& [ident, index] : sort_indices_) {
const auto& it = aliases.find(ident);
const auto& name = it == aliases.end() ? schema_.LookupIdentifier(ident) : it->second;

out.emplace_back(name, index->Lookup(doc));
}
return out;
}

absl::flat_hash_set<std::string_view> FieldIndices::GetSortIndiciesFields() const {
absl::flat_hash_set<std::string_view> fields_idents;
fields_idents.reserve(sort_indices_.size());
for (const auto& [ident, _] : sort_indices_) {
fields_idents.insert(ident);
}
return fields_idents;
SortableValue FieldIndices::GetSortIndexValue(DocId doc, std::string_view field_identifier) const {
auto it = sort_indices_.find(field_identifier);
return it->second->Lookup(doc);
}

SearchAlgorithm::SearchAlgorithm() = default;
Expand Down
13 changes: 3 additions & 10 deletions src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,7 @@ class FieldIndices {
const std::vector<DocId>& GetAllDocs() const;
const Schema& GetSchema() const;

// Extract values stored in sort indices
// aliases are specified in addition to the aliases in search::Schema
std::vector<std::pair<std::string, SortableValue>> ExtractStoredValues(
DocId doc,
const absl::flat_hash_map<std::string_view /*identifier*/, std::string_view /*alias*/>&
aliases) const;

absl::flat_hash_set<std::string_view> GetSortIndiciesFields() const;
SortableValue GetSortIndexValue(DocId doc, std::string_view field_identifier) const;

private:
void CreateIndices(PMR_NS::memory_resource* mr);
Expand All @@ -107,8 +100,8 @@ class FieldIndices {
const Schema& schema_;
const IndicesOptions& options_;
std::vector<DocId> all_ids_;
absl::flat_hash_map<std::string, std::unique_ptr<BaseIndex>> indices_;
absl::flat_hash_map<std::string, std::unique_ptr<BaseSortIndex>> sort_indices_;
absl::flat_hash_map<std::string_view, std::unique_ptr<BaseIndex>> indices_;
absl::flat_hash_map<std::string_view, std::unique_ptr<BaseSortIndex>> sort_indices_;
};

struct AlgorithmProfile {
Expand Down
109 changes: 62 additions & 47 deletions src/server/search/doc_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,7 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
return base_->Matches(key, obj_code);
}

SearchFieldsList ToSV(const search::Schema& schema,
const std::optional<OwnedSearchFieldsList>& fields) {
SearchFieldsList ToSV(const search::Schema& schema, const std::optional<SearchFieldsList>& fields) {
SearchFieldsList sv_fields;
if (fields) {
sv_fields.reserve(fields->size());
Expand Down Expand Up @@ -286,6 +285,57 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
std::move(search_results.profile)};
}

using SortIndiciesFieldsList =
std::vector<std::pair<string_view /*identifier*/, string_view /*alias*/>>;

std::pair<SearchFieldsList, SortIndiciesFieldsList> PreprocessAggregateFields(
const search::Schema& schema, const AggregateParams& params,
const std::optional<SearchFieldsList>& load_fields) {
auto is_sortable = [&schema](std::string_view fident) {
auto it = schema.fields.find(fident);
return it != schema.fields.end() && (it->second.flags & search::SchemaField::SORTABLE);
};

absl::flat_hash_map<std::string_view, SearchField> fields_by_identifier;
absl::flat_hash_map<std::string_view, std::string_view> sort_indicies_aliases;
fields_by_identifier.reserve(schema.field_names.size());
sort_indicies_aliases.reserve(schema.field_names.size());

for (const auto& [fname, fident] : schema.field_names) {
if (!is_sortable(fident)) {
fields_by_identifier[fident] = {StringOrView::FromView(fident), true,
StringOrView::FromView(fname)};
} else {
sort_indicies_aliases[fident] = fname;
}
}

if (load_fields) {
for (const auto& field : load_fields.value()) {
const auto& fident = field.GetIdentifier(schema, false);
if (!is_sortable(fident)) {
fields_by_identifier[fident] = field.View();
} else {
sort_indicies_aliases[fident] = field.GetShortName();
}
}
}

SearchFieldsList fields;
fields.reserve(fields_by_identifier.size());
for (auto& [_, field] : fields_by_identifier) {
fields.emplace_back(std::move(field));
}

SortIndiciesFieldsList sort_fields;
sort_fields.reserve(sort_indicies_aliases.size());
for (auto& [fident, fname] : sort_indicies_aliases) {
sort_fields.emplace_back(fident, fname);
}

return {fields, sort_fields};
}

vector<SearchDocData> ShardDocIndex::SearchForAggregator(
const OpArgs& op_args, const AggregateParams& params,
search::SearchAlgorithm* search_algo) const {
Expand All @@ -295,19 +345,8 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
if (!search_results.error.empty())
return {};

auto sort_indicies_fields = indices_->GetSortIndiciesFields();
SearchFieldsList fields_to_load = GetFieldsToLoad(params.load_fields, sort_indicies_fields);

// aliases for ExtractStoredValues
absl::flat_hash_map<std::string_view, std::string_view> sort_indicies_aliases;
if (params.load_fields) {
for (const auto& field : params.load_fields.value()) {
auto ident = field.GetIdentifier(base_->schema, false);
if (sort_indicies_fields.contains(ident)) {
sort_indicies_aliases[ident] = field.GetShortName();
}
}
}
auto [fields_to_load, sort_indicies] =
PreprocessAggregateFields(base_->schema, params, params.load_fields);

vector<absl::flat_hash_map<string, search::SortableValue>> out;
for (DocId doc : search_results.ids) {
Expand All @@ -318,47 +357,23 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
continue;

auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto extracted = indices_->ExtractStoredValues(doc, sort_indicies_aliases);

SearchDocData extracted_sort_indicies;
extracted_sort_indicies.reserve(sort_indicies.size());
for (const auto& [fident, fname] : sort_indicies) {
extracted_sort_indicies[fname] = indices_->GetSortIndexValue(doc, fident);
}

SearchDocData loaded = accessor->Serialize(base_->schema, fields_to_load);

out.emplace_back(make_move_iterator(extracted.begin()), make_move_iterator(extracted.end()));
out.emplace_back(make_move_iterator(extracted_sort_indicies.begin()),
make_move_iterator(extracted_sort_indicies.end()));
out.back().insert(make_move_iterator(loaded.begin()), make_move_iterator(loaded.end()));
}

return out;
}

SearchFieldsList ShardDocIndex::GetFieldsToLoad(
const std::optional<OwnedSearchFieldsList>& load_fields,
const absl::flat_hash_set<std::string_view>& skip_fields) const {
absl::flat_hash_map<std::string_view, SearchField> unique_fields;
unique_fields.reserve(base_->schema.field_names.size());

for (const auto& [fname, fident] : base_->schema.field_names) {
if (!skip_fields.contains(fident)) {
unique_fields[fident] = {StringOrView::FromView(fident), true, StringOrView::FromView(fname)};
}
}

if (load_fields) {
for (const auto& field : load_fields.value()) {
const auto& fident = field.GetIdentifier(base_->schema, false);
if (!skip_fields.contains(fident)) {
unique_fields[fident] = field.View();
}
}
}

SearchFieldsList fields;
fields.reserve(unique_fields.size());
for (auto& [_, field] : unique_fields) {
fields.emplace_back(std::move(field));
}

return fields;
}

DocIndexInfo ShardDocIndex::GetInfo() const {
return {*base_, key_index_.Size()};
}
Expand Down
13 changes: 3 additions & 10 deletions src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ class SearchField {
};

using SearchFieldsList = std::vector<SearchField>;
using OwnedSearchFieldsList = std::vector<SearchField>;

struct SearchParams {
// Parameters for "LIMIT offset total": select total amount documents with a specific offset from
Expand All @@ -141,14 +140,14 @@ struct SearchParams {
2. If set but empty -> no fields should be returned
3. If set and not empty -> return only these fields
*/
std::optional<OwnedSearchFieldsList> return_fields;
std::optional<SearchFieldsList> return_fields;

/*
Fields that should be also loaded from the document.
Only one of load_fields and return_fields should be set.
*/
std::optional<OwnedSearchFieldsList> load_fields;
std::optional<SearchFieldsList> load_fields;

std::optional<search::SortOption> sort_option;
search::QueryParams query_params;
Expand All @@ -168,7 +167,7 @@ struct AggregateParams {
std::string_view index, query;
search::QueryParams params;

std::optional<OwnedSearchFieldsList> load_fields;
std::optional<SearchFieldsList> load_fields;
std::vector<aggregate::PipelineStep> steps;
};

Expand Down Expand Up @@ -242,12 +241,6 @@ class ShardDocIndex {
io::Result<StringVec, facade::ErrorReply> GetTagVals(std::string_view field) const;

private:
/* Returns the fields that are the union of the already indexed fields and load_fields, excluding
skip_fields.
Load_fields should not be destroyed while the result of this function is being used */
SearchFieldsList GetFieldsToLoad(const std::optional<OwnedSearchFieldsList>& load_fields,
const absl::flat_hash_set<std::string_view>& skip_fields) const;

// Clears internal data. Traverses all matching documents and assigns ids.
void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr);

Expand Down
5 changes: 3 additions & 2 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ std::string_view ParseFieldWithAtSign(CmdArgParser* parser) {
return field;
}

void ParseLoadFields(CmdArgParser* parser, std::optional<OwnedSearchFieldsList>* load_fields) {
void ParseLoadFields(CmdArgParser* parser, std::optional<SearchFieldsList>* load_fields) {
// TODO: Change to num_strings. In Redis strings number is expected. For example: LOAD 3 $.a AS a
size_t num_fields = parser->Next<size_t>();
if (!load_fields->has_value()) {
Expand Down Expand Up @@ -263,7 +263,8 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB
}

// RETURN {num} [{ident} AS {name}...]
/* TODO: Change to num_strings. In Redis strings number is expected. For example: RETURN 3 $.a AS a */
/* TODO: Change to num_strings. In Redis strings number is expected. For example: RETURN 3 $.a
* AS a */
size_t num_fields = parser->Next<size_t>();
params.return_fields.emplace();
while (params.return_fields->size() < num_fields) {
Expand Down

0 comments on commit a62cf79

Please sign in to comment.