diff --git a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp index 3f1766769f5..30f85867a5d 100644 --- a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp +++ b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include #include namespace DB @@ -30,7 +30,7 @@ namespace tests types_col_name[a], types_col_name[b] \ } -class ExecutorAggTestRunner : public DB::tests::ExecutorTest +class ExecutorAggTestRunner : public DB::tests::AggregationTest { public: using ColStringNullableType = std::optional::FieldType>; @@ -355,6 +355,16 @@ try } CATCH +TEST_F(ExecutorAggTestRunner, TestFramwork) +try +{ + executeGroupByAndAssert({toNullableVec("tinyint_", col_tinyint)}, {toNullableVec({-1, 2, {}, 0, 1, 3, -2})}); + executeGroupByAndAssert({toNullableVec("tinyint_", col_tinyint), toNullableVec("smallint_", col_smallint)}, {toNullableVec({0, 2, 0, -1, 1, -2, 3, {}, {}}), toNullableVec({-1, 3, -2, 4, 2, 0, {}, {}, 0})}); + executeAggFunctionAndAssert({"Max"}, toNullableVec("tinyint_", col_tinyint), {toNullableVec(ColumnWithNullableInt8{3})}); + executeAggFunctionAndAssert({"Max", "Min"}, toNullableVec("tinyint_", col_tinyint), {toNullableVec(ColumnWithNullableInt8{3}), toNullableVec(ColumnWithNullableInt8{-2})}); +} +CATCH + // TODO support more type of min, max, count. // support more aggregation functions: sum, forst_row, group_concat diff --git a/dbms/src/TestUtils/AggregationTestUtils.cpp b/dbms/src/TestUtils/AggregationTestUtils.cpp index 114ad0a11e0..a66149e84f0 100644 --- a/dbms/src/TestUtils/AggregationTestUtils.cpp +++ b/dbms/src/TestUtils/AggregationTestUtils.cpp @@ -30,5 +30,85 @@ void AggregationTest::SetUpTestCase() }; register_func(DB::registerAggregateFunctions); + register_func(DB::registerFunctions); +} + +::testing::AssertionResult AggregationTest::checkAggReturnType(const String & agg_name, const DataTypes & data_types, const DataTypePtr & expect_type) +{ + AggregateFunctionPtr agg_ptr = DB::AggregateFunctionFactory::instance().get(agg_name, data_types, {}); + const DataTypePtr & ret_type = agg_ptr->getReturnType(); + if (ret_type->equals(*expect_type)) + return ::testing::AssertionSuccess(); + return ::testing::AssertionFailure() << "Expect type: " << expect_type->getName() << " Actual type: " << ret_type->getName(); +} + +void AggregationTest::executeAggFunctionAndAssert(const std::vector & func_names, const ColumnWithTypeAndName & column, const ColumnsWithTypeAndName & expected_cols) +{ + String db_name = "test_agg_function"; + String table_name = "test_table_agg"; + std::vector agg_funcs; + for (const auto & func_name : func_names) + agg_funcs.push_back(aggFunctionBuilder(func_name, column.name)); + + MockColumnInfoVec column_infos; + column_infos.push_back({column.name, dataTypeToTP(column.type)}); + context.addMockTable(db_name, table_name, column_infos, {column}); + + auto request = context.scan(db_name, table_name) + .aggregation(agg_funcs, {}) + .build(context); + + checkResult(request, expected_cols); +} + +void AggregationTest::executeGroupByAndAssert(const ColumnsWithTypeAndName & cols, const ColumnsWithTypeAndName & expected_cols) +{ + RUNTIME_CHECK_MSG(cols.size() == expected_cols.size(), "number of group_by columns don't match number of expected columns"); + + String db_name = "test_group"; + String table_name = "test_table_group"; + MockAstVec group_by_cols; + MockColumnNameVec proj_names; + MockColumnInfoVec column_infos; + for (const auto & col : cols) + { + group_by_cols.push_back(col(col.name)); + proj_names.push_back(col.name); + column_infos.push_back({col.name, dataTypeToTP(col.type)}); + } + + context.addMockTable(db_name, table_name, column_infos, cols); + + auto request = context.scan(db_name, table_name) + .aggregation({}, group_by_cols) + .project(proj_names) + .build(context); + + checkResult(request, expected_cols); +} + +void AggregationTest::checkResult(std::shared_ptr request, const ColumnsWithTypeAndName & expected_cols) +{ + for (size_t i = 1; i <= 10; ++i) + ASSERT_COLUMNS_EQ_UR(expected_cols, executeStreams(request, i)) << "expected_cols: " << getColumnsContent(expected_cols) << ", actual_cols: " << getColumnsContent(executeStreams(request, i)); +} + +ASTPtr AggregationTest::aggFunctionBuilder(const String & func_name, const String & col_name) +{ + ASTPtr func; + String func_name_lowercase = Poco::toLower(func_name); + + // TODO support more agg functions. + if (func_name_lowercase == "max") + func = Max(col(col_name)); + else if (func_name_lowercase == "min") + func = Min(col(col_name)); + else if (func_name_lowercase == "count") + func = Count(col(col_name)); + else if (func_name_lowercase == "sum") + func = Sum(col(col_name)); + else + throw Exception(fmt::format("Unsupported agg function {}", func_name), ErrorCodes::LOGICAL_ERROR); + return func; } } // namespace DB::tests diff --git a/dbms/src/TestUtils/AggregationTestUtils.h b/dbms/src/TestUtils/AggregationTestUtils.h index f69395794cb..638b8ed2504 100644 --- a/dbms/src/TestUtils/AggregationTestUtils.h +++ b/dbms/src/TestUtils/AggregationTestUtils.h @@ -15,27 +15,33 @@ #pragma once #include -#include -#include -#include -#include +#include namespace DB::tests { -class AggregationTest : public ::testing::Test +class AggregationTest : public ExecutorTest { public: - ::testing::AssertionResult checkAggReturnType(const String & agg_name, const DataTypes & data_types, const DataTypePtr & expect_type) - { - AggregateFunctionPtr agg_ptr = DB::AggregateFunctionFactory::instance().get(agg_name, data_types, {}); - const DataTypePtr & ret_type = agg_ptr->getReturnType(); - if (ret_type->equals(*expect_type)) - return ::testing::AssertionSuccess(); - return ::testing::AssertionFailure() << "Expect type: " << expect_type->getName() << " Actual type: " << ret_type->getName(); - } + static ::testing::AssertionResult checkAggReturnType(const String & agg_name, const DataTypes & data_types, const DataTypePtr & expect_type); + + // Test aggregation functions without group by. + void executeAggFunctionAndAssert( + const std::vector & func_names, + const ColumnWithTypeAndName & column, + const ColumnsWithTypeAndName & expected_cols); + + // Test group by columns + // Note that we must give columns in cols a name. + void executeGroupByAndAssert( + const ColumnsWithTypeAndName & cols, + const ColumnsWithTypeAndName & expected_cols); static void SetUpTestCase(); + +private: + void checkResult(std::shared_ptr request, const ColumnsWithTypeAndName & expected_cols); + ASTPtr aggFunctionBuilder(const String & func_name, const String & col_name); }; } // namespace DB::tests