Test: Mock TiFlash compute service and dispatch MPPTask to single service. (#5450)

ref #4609
ywqzzy authored Aug 3, 2022
1 parent 1022499 commit 3153a3b
Showing 26 changed files with 517 additions and 178 deletions.
275 changes: 148 additions & 127 deletions dbms/src/Debug/dbgFuncCoprocessor.cpp

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions dbms/src/Debug/dbgFuncCoprocessor.h
@@ -82,6 +82,10 @@ QueryTasks queryPlanToQueryTasks(
ExecutorPtr root_executor,
size_t & executor_index,
const Context & context);

BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DAGProperties & properties, QueryTasks & query_tasks, MakeResOutputStream & func_wrap_output_stream);

BlockInputStreamPtr executeMPPQuery(Context & context, const DAGProperties & properties, QueryTasks & query_tasks);
namespace Debug
{
void setServiceAddr(const std::string & addr);
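The header now exposes two debug entry points: executeQuery for coprocessor-style runs and executeMPPQuery for MPP runs. As a rough sketch of how a caller might dispatch between them (not the committed implementation, which lives in dbgFuncCoprocessor.cpp and is not rendered above), using the declared signatures; the is_mpp_query flag on DAGProperties is an assumption, not something confirmed by this diff:

// Hypothetical dispatch helper built on the two declarations above.
BlockInputStreamPtr executeQueryForDebug(
    Context & context,
    RegionID region_id,
    const DAGProperties & properties,
    QueryTasks & query_tasks,
    MakeResOutputStream & func_wrap_output_stream)
{
    if (properties.is_mpp_query) // assumed switch between the MPP and coprocessor paths
        return executeMPPQuery(context, properties, query_tasks);
    return executeQuery(context, region_id, properties, query_tasks, func_wrap_output_stream);
}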
10 changes: 0 additions & 10 deletions dbms/src/Flash/Coprocessor/DAGContext.cpp
@@ -236,14 +236,4 @@ const SingleTableRegions & DAGContext::getTableRegionsInfoByTableID(Int64 table_
{
return tables_regions_info.getTableRegionInfoByTableID(table_id);
}

ColumnsWithTypeAndName DAGContext::columnsForTest(String executor_id)
{
auto it = columns_for_test_map.find(executor_id);
if (unlikely(it == columns_for_test_map.end()))
{
throw DB::Exception("Don't have columns for mock source executors");
}
return it->second;
}
} // namespace DB
12 changes: 0 additions & 12 deletions dbms/src/Flash/Coprocessor/DAGContext.h
@@ -162,7 +162,6 @@ class DAGContext
, warning_count(0)
{
assert(dag_request->has_root_executor() && dag_request->root_executor().has_executor_id());

// only mpp task has join executor.
initExecutorIdToJoinIdMap();
initOutputInfo();
@@ -179,7 +178,6 @@
, max_recorded_error_count(max_error_count_)
, warnings(max_recorded_error_count)
, warning_count(0)
, is_test(true)
{}

// for tests need to run query tasks.
@@ -194,7 +192,6 @@
, max_recorded_error_count(getMaxErrorCount(*dag_request))
, warnings(max_recorded_error_count)
, warning_count(0)
, is_test(true)
{
assert(dag_request->has_root_executor() || dag_request->executors_size() > 0);
return_executor_id = dag_request->root_executor().has_executor_id() || dag_request->executors(0).has_executor_id();
@@ -309,12 +306,6 @@ class DAGContext

void updateFinalConcurrency(size_t cur_streams_size, size_t streams_upper_limit);

bool isTest() const { return is_test; }
void setColumnsForTest(std::unordered_map<String, ColumnsWithTypeAndName> & columns_for_test_map_) { columns_for_test_map = columns_for_test_map_; }
ColumnsWithTypeAndName columnsForTest(String executor_id);

bool columnsForTestEmpty() { return columns_for_test_map.empty(); }

ExchangeReceiverPtr getMPPExchangeReceiver(const String & executor_id) const;
void setMPPReceiverSet(const MPPReceiverSetPtr & receiver_set)
{
@@ -391,9 +382,6 @@ class DAGContext
/// vector of SubqueriesForSets(such as join build subquery).
/// The order of the vector is also the order of the subquery.
std::vector<SubqueriesForSets> subqueries;

bool is_test = false; /// switch for test, do not use it in production.
std::unordered_map<String, ColumnsWithTypeAndName> columns_for_test_map; /// <executor_id, columns>, for multiple sources
};

} // namespace DB
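The is_test switch and the per-executor mock-column registry are removed from DAGContext; the call sites elsewhere in this diff now ask the Context instead (context.isTest(), context.isExecutorTest(), context.isMPPTest(), context.columnsForTest(...)). The Context-side counterpart is not shown in this excerpt, so the following is a minimal self-contained sketch of what such a registry could look like, with every name inferred from those call sites rather than taken from the real DB::Context:

#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Minimal stand-in for the test-related state this commit moves out of DAGContext.
class TestContextSketch
{
public:
    enum class TestMode
    {
        None,     // production: no mocking
        Executor, // executor unit tests: sources are mocked in-process
        MPP       // MPP tests: real MPPTasks dispatched to a mocked compute service
    };

    using MockColumn = std::pair<std::string /*name*/, std::vector<std::string> /*values*/>;
    using MockColumns = std::vector<MockColumn>;

    bool isTest() const { return test_mode != TestMode::None; }
    bool isExecutorTest() const { return test_mode == TestMode::Executor; }
    bool isMPPTest() const { return test_mode == TestMode::MPP; }

    void setColumnsForTest(std::unordered_map<std::string, MockColumns> columns)
    {
        columns_for_test_map = std::move(columns);
    }
    bool columnsForTestEmpty() const { return columns_for_test_map.empty(); }
    MockColumns columnsForTest(const std::string & executor_id) const
    {
        auto it = columns_for_test_map.find(executor_id);
        if (it == columns_for_test_map.end())
            throw std::runtime_error("no mock columns registered for " + executor_id);
        return it->second;
    }

    TestMode test_mode = TestMode::None;

private:
    // <executor_id, columns>, one entry per mocked source executor.
    std::unordered_map<std::string, MockColumns> columns_for_test_map;
};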
22 changes: 15 additions & 7 deletions dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp
@@ -99,7 +99,7 @@ AnalysisResult analyzeExpressions(
ExpressionActionsChain chain;
// selection on table scan had been executed in handleTableScan
// In test mode, filter is not pushed down to table scan
if (query_block.selection && (!query_block.isTableScanSource() || context.getDAGContext()->isTest()))
if (query_block.selection && (!query_block.isTableScanSource() || context.isTest()))
{
std::vector<const tipb::Expr *> where_conditions;
for (const auto & c : query_block.selection->selection().conditions())
@@ -159,7 +159,7 @@
// for tests, we need to mock tableScan blockInputStream as the source stream.
void DAGQueryBlockInterpreter::handleMockTableScan(const TiDBTableScan & table_scan, DAGPipeline & pipeline)
{
if (context.getDAGContext()->columnsForTestEmpty() || context.getDAGContext()->columnsForTest(table_scan.getTableScanExecutorID()).empty())
if (context.columnsForTestEmpty() || context.columnsForTest(table_scan.getTableScanExecutorID()).empty())
{
auto names_and_types = genNamesAndTypes(table_scan);
auto columns_with_type_and_name = getColumnWithTypeAndName(names_and_types);
@@ -279,7 +279,7 @@ void DAGQueryBlockInterpreter::handleJoin(const tipb::Join & join, DAGPipeline &
join_execute_info.join_build_streams.push_back(stream);
});
// for test, join executor need the return blocks to output.
executeUnion(build_pipeline, max_streams, log, /*ignore_block=*/!dagContext().isTest(), "for join");
executeUnion(build_pipeline, max_streams, log, /*ignore_block=*/!context.isTest(), "for join");

right_query.source = build_pipeline.firstStream();
right_query.join = join_ptr;
@@ -492,7 +492,7 @@ void DAGQueryBlockInterpreter::handleExchangeReceiver(DAGPipeline & pipeline)
// for tests, we need to mock ExchangeReceiver blockInputStream as the source stream.
void DAGQueryBlockInterpreter::handleMockExchangeReceiver(DAGPipeline & pipeline)
{
if (context.getDAGContext()->columnsForTestEmpty() || context.getDAGContext()->columnsForTest(query_block.source_name).empty())
if (context.columnsForTestEmpty() || context.columnsForTest(query_block.source_name).empty())
{
for (size_t i = 0; i < max_streams; ++i)
{
@@ -590,10 +590,14 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline)
}
else if (query_block.source->tp() == tipb::ExecType::TypeExchangeReceiver)
{
if (unlikely(dagContext().isTest()))
if (unlikely(context.isExecutorTest()))
handleMockExchangeReceiver(pipeline);
else
{
// for MPP test, we can use the real exchangeReceiver to run a query across different compute nodes
// or use one compute node to simulate the MPP process.
handleExchangeReceiver(pipeline);
}
recordProfileStreams(pipeline, query_block.source_name);
}
else if (query_block.source->tp() == tipb::ExecType::TypeProjection)
@@ -604,7 +608,7 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline)
else if (query_block.isTableScanSource())
{
TiDBTableScan table_scan(query_block.source, query_block.source_name, dagContext());
if (unlikely(dagContext().isTest()))
if (unlikely(context.isTest()))
handleMockTableScan(table_scan, pipeline);
else
handleTableScan(table_scan, pipeline);
@@ -685,10 +689,14 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline)
// execute exchange_sender
if (query_block.exchange_sender)
{
if (unlikely(dagContext().isTest()))
if (unlikely(context.isExecutorTest()))
handleMockExchangeSender(pipeline);
else
{
// for MPP test, we can use the real exchangeSender to run a query across different compute nodes
// or use one compute node to simulate the MPP process.
handleExchangeSender(pipeline);
}
recordProfileStreams(pipeline, query_block.exchange_sender_name);
}
}
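The interpreter now distinguishes the two test flavours: only executor tests (context.isExecutorTest()) take the mock exchange receiver/sender branch, while MPP tests go through the real exchange path against the mocked compute service, just like production. A compilable toy illustration of that dispatch; the names are illustrative stand-ins, not the interpreter's API:

#include <iostream>

// Toy model of the branch added in executeImpl(): the mock path is reserved
// for executor tests; MPP tests and production share the real exchange path.
enum class TestMode
{
    None,
    Executor,
    MPP
};

void buildMockExchangeReceiver() { std::cout << "mock receiver: rows come from columnsForTest()\n"; }
void buildRealExchangeReceiver() { std::cout << "real receiver: gRPC exchange, also used by MPP tests\n"; }

void handleExchangeReceiverSketch(TestMode mode)
{
    if (mode == TestMode::Executor)
        buildMockExchangeReceiver();
    else
        buildRealExchangeReceiver();
}

int main()
{
    handleExchangeReceiverSketch(TestMode::Executor); // mock branch
    handleExchangeReceiverSketch(TestMode::MPP);      // real branch, mocked service behind it
    return 0;
}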
4 changes: 3 additions & 1 deletion dbms/src/Flash/Coprocessor/InterpreterUtils.cpp
@@ -175,8 +175,10 @@ void executeCreatingSets(
{
DAGContext & dag_context = *context.getDAGContext();
/// add union to run in parallel if needed
if (unlikely(dag_context.isTest()))
if (unlikely(context.isExecutorTest()))
executeUnion(pipeline, max_streams, log, /*ignore_block=*/false, "for test");
else if (context.isMPPTest())
executeUnion(pipeline, max_streams, log, /*ignore_block=*/true, "for mpp test");
else if (dag_context.isMPPTask())
/// MPPTask do not need the returned blocks.
executeUnion(pipeline, max_streams, log, /*ignore_block=*/true, "for mpp");
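The union at the end of executeCreatingSets now has three cases: executor tests keep the returned blocks so assertions can read them, MPP tests discard them just as a real MPP task does, and real MPP tasks keep their existing behaviour. A small sketch of that decision as a pure function, with plain bools standing in for the context/dag_context queries; the final fallback is an assumption, since the remaining else branch is cut off in the hunk above:

// Sketch of the ignore_block decision used when adding the union stream.
bool ignoreBlocksForUnion(bool is_executor_test, bool is_mpp_test, bool is_mpp_task)
{
    if (is_executor_test)
        return false; // "for test": the test harness reads the result blocks
    if (is_mpp_test || is_mpp_task)
        return true;  // "for mpp (test)": results leave through exchange senders instead
    return false;     // assumed default for ordinary non-MPP queries
}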
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/MockSourceStream.h
@@ -28,7 +28,7 @@ std::pair<NamesAndTypes, std::vector<std::shared_ptr<SourceType>>> mockSourceStr
NamesAndTypes names_and_types;
size_t rows = 0;
std::vector<std::shared_ptr<SourceType>> mock_source_streams;
columns_with_type_and_name = context.getDAGContext()->columnsForTest(executor_id);
columns_with_type_and_name = context.columnsForTest(executor_id);
for (const auto & col : columns_with_type_and_name)
{
if (rows == 0)
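mockSourceStream now pulls its columns from context.columnsForTest(executor_id) instead of the DAGContext. A standalone sketch of what the visible part of the loop does: take the row count from the first mock column and, presumably, validate the remaining columns against it before the data is cut into per-stream chunks. The types here are simplified stand-ins, not the real ColumnsWithTypeAndName:

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

struct MockColumnSketch { std::string name; std::vector<int64_t> values; };

size_t mockRowCount(const std::vector<MockColumnSketch> & columns_for_test)
{
    size_t rows = 0;
    for (const auto & col : columns_for_test)
    {
        if (rows == 0)
            rows = col.values.size();               // first column fixes the row count
        else if (col.values.size() != rows)
            throw std::runtime_error("mock columns for one executor must have equal length");
    }
    return rows;
}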
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/TablesRegionsInfo.cpp
@@ -59,7 +59,7 @@ static void insertRegionInfoToTablesRegionInfo(const google::protobuf::RepeatedP
auto & table_region_info = tables_region_infos.getOrCreateTableRegionInfoByTableID(table_id);
for (const auto & r : regions)
{
RegionInfo region_info(r.region_id(), r.region_epoch().version(), r.region_epoch().conf_ver(), CoprocessorHandler::GenCopKeyRange(r.ranges()), nullptr);
RegionInfo region_info(r.region_id(), r.region_epoch().version(), r.region_epoch().conf_ver(), CoprocessorHandler::genCopKeyRange(r.ranges()), nullptr);
if (region_info.key_ranges.empty())
{
throw TiFlashException(
4 changes: 2 additions & 2 deletions dbms/src/Flash/CoprocessorHandler.cpp
@@ -53,7 +53,7 @@ CoprocessorHandler::CoprocessorHandler(
, log(&Poco::Logger::get("CoprocessorHandler"))
{}

std::vector<std::pair<DecodedTiKVKeyPtr, DecodedTiKVKeyPtr>> CoprocessorHandler::GenCopKeyRange(
std::vector<std::pair<DecodedTiKVKeyPtr, DecodedTiKVKeyPtr>> CoprocessorHandler::genCopKeyRange(
const ::google::protobuf::RepeatedPtrField<::coprocessor::KeyRange> & ranges)
{
std::vector<std::pair<DecodedTiKVKeyPtr, DecodedTiKVKeyPtr>> key_ranges;
@@ -100,7 +100,7 @@ grpc::Status CoprocessorHandler::execute()
cop_context.kv_context.region_id(),
cop_context.kv_context.region_epoch().version(),
cop_context.kv_context.region_epoch().conf_ver(),
GenCopKeyRange(cop_request->ranges()),
genCopKeyRange(cop_request->ranges()),
&bypass_lock_ts));

DAGContext dag_context(dag_request);
2 changes: 1 addition & 1 deletion dbms/src/Flash/CoprocessorHandler.h
@@ -53,7 +53,7 @@ class CoprocessorHandler

virtual grpc::Status execute();

static std::vector<std::pair<DecodedTiKVKeyPtr, DecodedTiKVKeyPtr>> GenCopKeyRange(
static std::vector<std::pair<DecodedTiKVKeyPtr, DecodedTiKVKeyPtr>> genCopKeyRange(
const ::google::protobuf::RepeatedPtrField<::coprocessor::KeyRange> & ranges);

protected:
9 changes: 6 additions & 3 deletions dbms/src/Flash/FlashService.cpp
@@ -157,7 +157,8 @@ ::grpc::Status FlashService::DispatchMPPTask(
{
CPUAffinityManager::getInstance().bindSelfGrpcThread();
LOG_FMT_DEBUG(log, "Handling mpp dispatch request: {}", request->DebugString());
if (!security_config.checkGrpcContext(grpc_context))
// For MPP test, we don't care about security config.
if (!context.isMPPTest() && !security_config.checkGrpcContext(grpc_context))
{
return grpc::Status(grpc::PERMISSION_DENIED, tls_err_msg);
}
@@ -380,7 +381,9 @@ std::tuple<ContextPtr, grpc::Status> FlashService::createDBContext(const grpc::S
std::string client_ip = peer.substr(pos + 1);
Poco::Net::SocketAddress client_address(client_ip);

tmp_context->setUser(user, password, client_address, quota_key);
// For MPP test, we don't care about security config.
if (!context.isMPPTest())
tmp_context->setUser(user, password, client_address, quota_key);

String query_id = getClientMetaVarWithDefault(grpc_context, "query_id", "");
tmp_context->setCurrentQueryId(query_id);
@@ -436,4 +439,4 @@ ::grpc::Status FlashService::Compact(::grpc::ServerContext * grpc_context, const
return manual_compact_manager->handleRequest(request, response);
}

} // namespace DB
} // namespace DB
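DispatchMPPTask and createDBContext now skip the TLS check and user authentication when the server runs as a mocked compute service for MPP tests (context.isMPPTest()). A self-contained illustration of that guard; the types are simplified stand-ins, not the real FlashService or grpc classes:

#include <string>

struct GrpcStatusSketch { bool ok; std::string message; };

// Mirrors the new condition: the permission check only applies outside MPP tests.
GrpcStatusSketch checkDispatchSecurity(bool is_mpp_test, bool grpc_context_ok)
{
    if (!is_mpp_test && !grpc_context_ok)
        return {false, "TLS check failed"}; // grpc::PERMISSION_DENIED in the real code
    return {true, ""};
}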
2 changes: 1 addition & 1 deletion dbms/src/Flash/FlashService.h
@@ -122,4 +122,4 @@ class AsyncFlashService final : public FlashService
}
};

} // namespace DB
} // namespace DB
1 change: 1 addition & 0 deletions dbms/src/Flash/Mpp/MPPTask.cpp
@@ -278,6 +278,7 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)
dag_context->log = log;
dag_context->tables_regions_info = std::move(tables_regions_info);
dag_context->tidb_host = context->getClientInfo().current_address.toString();

context->setDAGContext(dag_context.get());

if (dag_context->isRootMPPTask())
4 changes: 2 additions & 2 deletions dbms/src/Flash/tests/gtest_aggregation_executor.cpp
@@ -325,8 +325,8 @@ try
.aggregation({}, {col("s1")})
.build(context);
{
ASSERT_COLUMNS_EQ_R(executeStreams(request),
createColumns({toNullableVec<String>("s1", {{}, "banana"})}));
ASSERT_COLUMNS_EQ_UR(executeStreams(request),
createColumns({toNullableVec<String>("s1", {{}, "banana"})}));
}
}
CATCH
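The aggregation test switches from ASSERT_COLUMNS_EQ_R to ASSERT_COLUMNS_EQ_UR, presumably an order-insensitive comparison: once the request can run through the MPP-style pipeline, the output order of the two groups is no longer deterministic. A minimal sketch of what an unordered comparison over a single column boils down to (the real macro is defined in the test utilities, not shown in this diff):

#include <algorithm>
#include <string>
#include <vector>

// Compare two single-column results while ignoring row order.
bool equalUnordered(std::vector<std::string> actual, std::vector<std::string> expected)
{
    std::sort(actual.begin(), actual.end());
    std::sort(expected.begin(), expected.end());
    return actual == expected;
}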
109 changes: 109 additions & 0 deletions dbms/src/Flash/tests/gtest_compute_server.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright 2022 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <TestUtils/MPPTaskTestUtils.h>

namespace DB
{
namespace tests
{
class ComputeServerRunner : public DB::tests::MPPTaskTestUtils
{
public:
void initializeContext() override
{
ExecutorTest::initializeContext();
/// for agg
context.addMockTable(
{"test_db", "test_table_1"},
{{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}},
{toNullableVec<Int32>("s1", {1, {}, 10000000}), toNullableVec<String>("s2", {"apple", {}, "banana"}), toNullableVec<String>("s3", {"apple", {}, "banana"})});

/// for join
context.addMockTable(
{"test_db", "l_table"},
{{"s", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}},
{toNullableVec<String>("s", {"banana", {}, "banana"}), toNullableVec<String>("join_c", {"apple", {}, "banana"})});
context.addMockTable(
{"test_db", "r_table"},
{{"s", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}},
{toNullableVec<String>("s", {"banana", {}, "banana"}), toNullableVec<String>("join_c", {"apple", {}, "banana"})});
}
};

TEST_F(ComputeServerRunner, runAggTasks)
try
{
{
auto tasks = context.scan("test_db", "test_table_1")
.aggregation({Max(col("s1"))}, {col("s2"), col("s3")})
.project({"max(s1)"})
.buildMPPTasks(context);

size_t task_size = tasks.size();

std::vector<String> expected_strings = {
"exchange_sender_5 | type:Hash, {<0, Long>, <1, String>, <2, String>}\n"
" aggregation_4 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)}\n"
" table_scan_0 | {<0, Long>, <1, String>, <2, String>}\n",
"exchange_sender_3 | type:PassThrough, {<0, Long>}\n"
" project_2 | {<0, Long>}\n"
" aggregation_1 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)}\n"
" exchange_receiver_6 | type:PassThrough, {<0, Long>, <1, String>, <2, String>}\n"};
for (size_t i = 0; i < task_size; ++i)
{
ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request);
}

auto expected_cols = {toNullableVec<Int32>({1, {}, 10000000})};
ASSERT_MPPTASK_EQUAL(tasks, expected_cols);
}
}
CATCH

TEST_F(ComputeServerRunner, runJoinTasks)
try
{
auto tasks = context
.scan("test_db", "l_table")
.join(context.scan("test_db", "r_table"), {col("join_c")}, tipb::JoinType::TypeLeftOuterJoin)
.topN("join_c", false, 2)
.buildMPPTasks(context);

size_t task_size = tasks.size();
std::vector<String> expected_strings = {
"exchange_sender_6 | type:Hash, {<0, String>}\n"
" table_scan_1 | {<0, String>}",
"exchange_sender_5 | type:Hash, {<0, String>, <1, String>}\n"
" table_scan_0 | {<0, String>, <1, String>}",
"exchange_sender_4 | type:PassThrough, {<0, String>, <1, String>, <2, String>}\n"
" topn_3 | order_by: {(<1, String>, desc: false)}, limit: 2\n"
" Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>}\n"
" exchange_receiver_7 | type:PassThrough, {<0, String>, <1, String>}\n"
" exchange_receiver_8 | type:PassThrough, {<0, String>}"};
for (size_t i = 0; i < task_size; ++i)
{
ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request);
}

auto expected_cols = {
toNullableVec<String>({{}, "banana"}),
toNullableVec<String>({{}, "apple"}),
toNullableVec<String>({{}, {}})};
ASSERT_MPPTASK_EQUAL(tasks, expected_cols);
}
CATCH

} // namespace tests
} // namespace DB
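The new harness makes it cheap to add more MPP cases. A hypothetical further test in the same style (not part of this commit), reusing the mocked test_table_1 and the same buildMPPTasks / ASSERT_MPPTASK_EQUAL flow; the row order of the merged result is assumed to match the mock data:

TEST_F(ComputeServerRunner, runProjectionTasks)
try
{
    // Hypothetical example: project two string columns of the mocked table
    // and run the plan as MPP tasks against the mocked compute service.
    auto tasks = context
                     .scan("test_db", "test_table_1")
                     .project({"s2", "s3"})
                     .buildMPPTasks(context);

    auto expected_cols = {
        toNullableVec<String>({"apple", {}, "banana"}),
        toNullableVec<String>({"apple", {}, "banana"})};
    ASSERT_MPPTASK_EQUAL(tasks, expected_cols);
}
CATCH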