Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Revert "fix(search_family): Remove the output of extra fields in the FT.AGGREGATE command" #4298

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 14 additions & 33 deletions src/server/search/aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ namespace dfly::aggregate {
namespace {

struct GroupStep {
PipelineResult operator()(PipelineResult result) {
PipelineResult operator()(std::vector<DocValues> values) {
// Separate items into groups
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
for (auto& value : result.values) {
for (auto& value : values) {
groups[Extract(value)].push_back(std::move(value));
}

Expand All @@ -28,18 +28,7 @@ struct GroupStep {
}
out.push_back(std::move(doc));
}

absl::flat_hash_set<std::string> fields_to_print;
fields_to_print.reserve(fields_.size() + reducers_.size());

for (auto& field : fields_) {
fields_to_print.insert(std::move(field));
}
for (auto& reducer : reducers_) {
fields_to_print.insert(std::move(reducer.result_field));
}

return {std::move(out), std::move(fields_to_print)};
return out;
}

absl::FixedArray<Value> Extract(const DocValues& dv) {
Expand Down Expand Up @@ -115,42 +104,34 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
}

PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
auto& values = result.values;

return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult {
std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field);
auto it2 = r.find(field);
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
});

if (descending) {
if (descending)
std::reverse(values.begin(), values.end());
}

result.fields_to_print.insert(field);
return result;
return values;
};
}

PipelineStep MakeLimitStep(size_t offset, size_t num) {
return [offset, num](PipelineResult result) {
auto& values = result.values;
return [offset, num](std::vector<DocValues> values) -> PipelineResult {
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size()));
return result;
return values;
};
}

PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps) {
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) {
for (auto& step : steps) {
PipelineResult step_result = step(std::move(result));
result = std::move(step_result);
auto result = step(std::move(values));
if (!result.has_value())
return result;
values = std::move(result.value());
}
return result;
return values;
}

} // namespace dfly::aggregate
17 changes: 4 additions & 13 deletions src/server/search/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#pragma once

#include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <absl/types/span.h>

#include <string>
Expand All @@ -20,16 +19,10 @@ namespace dfly::aggregate {
using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline

struct PipelineResult {
// Values to be passed to the next step
// TODO: Replace DocValues with compact linear search map instead of hash map
std::vector<DocValues> values;
// TODO: Replace DocValues with compact linear search map instead of hash map

// Fields from values to be printed
absl::flat_hash_set<std::string> fields_to_print;
};

using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.
using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>;
using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>; // Group, Sort, etc.

// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
// Extra clumsy for STL compatibility!
Expand Down Expand Up @@ -89,8 +82,6 @@ PipelineStep MakeSortStep(std::string_view field, bool descending = false);
PipelineStep MakeLimitStep(size_t offset, size_t num);

// Process values with given steps
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps);
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps);

} // namespace dfly::aggregate
52 changes: 28 additions & 24 deletions src/server/search/aggregator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ TEST(AggregatorTest, Sort) {
};
PipelineStep steps[] = {MakeSortStep("a", false)};

auto result = Process(values, {"a"}, steps);
auto result = Process(values, steps);

EXPECT_EQ(result.values[0]["a"], Value(0.5));
EXPECT_EQ(result.values[1]["a"], Value(1.0));
EXPECT_EQ(result.values[2]["a"], Value(1.5));
EXPECT_TRUE(result);
EXPECT_EQ(result->at(0)["a"], Value(0.5));
EXPECT_EQ(result->at(1)["a"], Value(1.0));
EXPECT_EQ(result->at(2)["a"], Value(1.5));
}

TEST(AggregatorTest, Limit) {
Expand All @@ -34,11 +35,12 @@ TEST(AggregatorTest, Limit) {
};
PipelineStep steps[] = {MakeLimitStep(1, 2)};

auto result = Process(values, {"i"}, steps);
auto result = Process(values, steps);

EXPECT_EQ(result.values.size(), 2);
EXPECT_EQ(result.values[0]["i"], Value(2.0));
EXPECT_EQ(result.values[1]["i"], Value(3.0));
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
EXPECT_EQ(result->at(0)["i"], Value(2.0));
EXPECT_EQ(result->at(1)["i"], Value(3.0));
}

TEST(AggregatorTest, SimpleGroup) {
Expand All @@ -52,11 +54,12 @@ TEST(AggregatorTest, SimpleGroup) {
std::string_view fields[] = {"tag"};
PipelineStep steps[] = {MakeGroupStep(fields, {})};

auto result = Process(values, {"i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);
auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);

EXPECT_EQ(result.values[0].size(), 1);
std::set<Value> groups{result.values[0]["tag"], result.values[1]["tag"]};
EXPECT_EQ(result->at(0).size(), 1);
std::set<Value> groups{result->at(0)["tag"], result->at(1)["tag"]};
std::set<Value> expected{"even", "odd"};
EXPECT_EQ(groups, expected);
}
Expand All @@ -80,24 +83,25 @@ TEST(AggregatorTest, GroupWithReduce) {
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};

auto result = Process(values, {"i", "half-i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);
auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);

// Reorder even first
if (result.values[0].at("tag") == Value("odd"))
std::swap(result.values[0], result.values[1]);
if (result->at(0).at("tag") == Value("odd"))
std::swap(result->at(0), result->at(1));

// Even
EXPECT_EQ(result.values[0].at("count"), Value{(double)5});
EXPECT_EQ(result.values[0].at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result.values[0].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[0].at("distinct-null"), Value{(double)1});
EXPECT_EQ(result->at(0).at("count"), Value{(double)5});
EXPECT_EQ(result->at(0).at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result->at(0).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(0).at("distinct-null"), Value{(double)1});

// Odd
EXPECT_EQ(result.values[1].at("count"), Value{(double)5});
EXPECT_EQ(result.values[1].at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result.values[1].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[1].at("distinct-null"), Value{(double)1});
EXPECT_EQ(result->at(1).at("count"), Value{(double)5});
EXPECT_EQ(result->at(1).at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result->at(1).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(1).at("distinct-null"), Value{(double)1});
}

} // namespace dfly::aggregate
30 changes: 9 additions & 21 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -981,34 +981,22 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
make_move_iterator(sub_results.end()));
}

std::vector<std::string_view> load_fields;
if (params->load_fields) {
load_fields.reserve(params->load_fields->size());
for (const auto& field : params->load_fields.value()) {
load_fields.push_back(field.GetShortName());
}
}

auto agg_results = aggregate::Process(std::move(values), load_fields, params->steps);
auto agg_results = aggregate::Process(std::move(values), params->steps);
if (!agg_results.has_value())
return builder->SendError(agg_results.error());

size_t result_size = agg_results->size();
auto* rb = static_cast<RedisReplyBuilder*>(cmd_cntx.rb);
auto sortable_value_sender = SortableValueSender(rb);

const size_t result_size = agg_results.values.size();
rb->StartArray(result_size + 1);
rb->SendLong(result_size);

const size_t field_count = agg_results.fields_to_print.size();
for (const auto& value : agg_results.values) {
rb->StartArray(field_count * 2);
for (const auto& field : agg_results.fields_to_print) {
rb->SendBulkString(field);

if (auto it = value.find(field); it != value.end()) {
std::visit(sortable_value_sender, it->second);
} else {
rb->SendNull();
}
for (const auto& result : agg_results.value()) {
rb->StartArray(result.size() * 2);
for (const auto& [k, v] : result) {
rb->SendBulkString(k);
std::visit(sortable_value_sender, v);
}
}
}
Expand Down
41 changes: 8 additions & 33 deletions src/server/search/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,12 +962,15 @@ TEST_F(SearchFamilyTest, AggregateGroupBy) {
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"),
IsMap("foo_total", "50", "word", "item1")));

/*
Temporarily not supported

resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word",
"@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("foo_total", "40", "word", "item1", "text", "\"third key\""),
IsMap("foo_total", "20", "word", "item2", "text", "\"second key\""),
IsMap("foo_total", "10", "word", "item1", "text", "\"first key\"")));
"@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("foo_total", "20", "word", ArgType(RespExpr::NIL), "text", "\"second
key\""), IsMap("foo_total", "40", "word", ArgType(RespExpr::NIL), "text", "\"third key\""),
IsMap({"foo_total", "10", "word", ArgType(RespExpr::NIL), "text", "\"first key"})));
*/
}

TEST_F(SearchFamilyTest, JsonAggregateGroupBy) {
Expand Down Expand Up @@ -1629,32 +1632,4 @@ TEST_F(SearchFamilyTest, SearchLoadReturnHash) {
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("a", "two"), "h1", IsMap("a", "one")));
}

// Test that FT.AGGREGATE prints only needed fields
TEST_F(SearchFamilyTest, AggregateResultFields) {
Run({"JSON.SET", "j1", ".", R"({"a":"1","b":"2","c":"3"})"});
Run({"JSON.SET", "j2", ".", R"({"a":"4","b":"5","c":"6"})"});
Run({"JSON.SET", "j3", ".", R"({"a":"7","b":"8","c":"9"})"});

auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.a", "AS", "a", "TEXT",
"SORTABLE", "$.b", "AS", "b", "TEXT", "$.c", "AS", "c", "TEXT"});
EXPECT_EQ(resp, "OK");

resp = Run({"FT.AGGREGATE", "index", "*"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap(), IsMap(), IsMap()));

resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("a", "1"), IsMap("a", "4"), IsMap("a", "7")));

resp = Run({"FT.AGGREGATE", "index", "*", "LOAD", "1", "@b", "SORTBY", "1", "a"});
EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("b", "\"2\"", "a", "1"), IsMap("b", "\"5\"", "a", "4"),
IsMap("b", "\"8\"", "a", "7")));

resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a", "GROUPBY", "2", "@b", "@a",
"REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("b", "\"8\"", "a", "7", "count", "1"),
IsMap("b", "\"2\"", "a", "1", "count", "1"),
IsMap("b", "\"5\"", "a", "4", "count", "1")));
}

} // namespace dfly
Loading