Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine handleOtherConditions in join #8642

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 37 additions & 72 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,8 @@ void Join::cancelRuntimeFilter(const String & reason)
}
}

namespace
{
void mergeNullAndFilterResult(
Block & block,
ColumnVector<UInt8>::Container & filter_column,
Expand Down Expand Up @@ -833,6 +835,24 @@ void mergeNullAndFilterResult(
}
}
}
/// Apply the negated `filter_column` to the null map of every column in `block`
/// whose name is also present in `right_columns`.
///
/// Used after evaluating the other join conditions of a left join, where the
/// right-table columns of rows that did not pass the filter have to become NULL
/// (presumably rows whose filter value is 0 are the ones turned into NULL; see
/// ColumnNullable::applyNegatedNullMap to confirm).
///
/// @param block          block being produced by the join; matching columns are replaced in place.
/// @param right_columns  block whose column names identify the right-table columns.
/// @param filter_column  UInt8 filter produced by the join conditions
///                       (assumed to have one entry per row of `block` - TODO confirm against callers).
///
/// Fails the RUNTIME_CHECK if a selected column is not nullable after const materialization.
void applyNullToNotMatchedRows(Block & block, const Block & right_columns, const ColumnUInt8 & filter_column)
{
    for (size_t i = 0; i < block.columns(); ++i)
    {
        auto & column = block.getByPosition(i);
        if (!right_columns.has(column.name))
            continue;

        /// A const column cannot carry a per-row null map, so materialize it first.
        auto full_column
            = column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column;
        RUNTIME_CHECK_MSG(full_column->isColumnNullable(), "the right table column for left join must be nullable");

        /// Mutate through `full_column` directly instead of through another copy of the pointer:
        /// every extra owning handle keeps the column shared, which (with copy-on-write columns)
        /// makes mutate() clone it. A freshly materialized const column has no other owner, so it
        /// can be modified in place; a column still referenced by `column.column` is cloned as before.
        auto result_column = (*std::move(full_column)).mutate();
        static_cast<ColumnNullable &>(*result_column).applyNegatedNullMap(filter_column);
        column.column = std::move(result_column);
    }
}
} // namespace

/**
* handle other join conditions
Expand All @@ -847,11 +867,8 @@ void mergeNullAndFilterResult(
* @param left_table_columns
* @param right_table_columns
*/
void Join::handleOtherConditions(
Block & block,
IColumn::Filter * anti_filter,
IColumn::Offsets * offsets_to_replicate,
const std::vector<size_t> & right_table_columns) const
void Join::handleOtherConditions(Block & block, IColumn::Filter * anti_filter, IColumn::Offsets * offsets_to_replicate)
const
{
/// save block_rows because block.rows() returns the first column's size, after other_cond_expr->execute(block),
/// some column maybe removed, and the first column maybe the match_helper_column which does not have the same size
Expand All @@ -870,8 +887,7 @@ void Join::handleOtherConditions(
{
auto & col_name = input_block.getByPosition(i).name;
if ((!flag_mapped_entry_helper_name.empty() && col_name == flag_mapped_entry_helper_name)
|| output_column_names_set_after_finalize.find(col_name)
!= output_column_names_set_after_finalize.end())
|| output_column_names_set_after_finalize.contains(col_name))
++i;
else
input_block.erase(i);
Expand All @@ -885,10 +901,8 @@ void Join::handleOtherConditions(
assert(filter.empty());
filter.assign(block_rows, static_cast<UInt8>(1));
}
auto helper_pos = block.getPositionByName(match_helper_name);

const auto * old_match_nullable
= checkAndGetColumn<ColumnNullable>(block.safeGetByPosition(helper_pos).column.get());
= checkAndGetColumn<ColumnNullable>(block.getByName(match_helper_name).column.get());
const auto & old_match_vec
= static_cast<const ColumnVector<Int8> *>(old_match_nullable->getNestedColumnPtr().get())->getData();

Expand Down Expand Up @@ -959,7 +973,7 @@ void Join::handleOtherConditions(
}

erase_useless_column(block);
helper_pos = block.getPositionByName(match_helper_name);
auto helper_pos = block.getPositionByName(match_helper_name);
for (size_t i = 0; i < block.columns(); ++i)
if (i != helper_pos)
block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1);
Expand Down Expand Up @@ -1023,34 +1037,17 @@ void Join::handleOtherConditions(
}
prev_offset = current_offset;
}
erase_useless_column(block);
if (isLeftOuterJoin(kind))
{
/// for left join, convert right column to null if not joined
for (size_t right_table_column : right_table_columns)
{
auto & column = block.getByPosition(right_table_column);
if (output_column_names_set_after_finalize.contains(column.name))
{
auto full_column
= column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column;
if (!full_column->isColumnNullable())
{
throw Exception("Should not reach here, the right table column for left join must be nullable");
}
auto current_column = full_column;
auto result_column = (*std::move(current_column)).mutate();
static_cast<ColumnNullable &>(*result_column).applyNegatedNullMap(*filter_column);
column.column = std::move(result_column);
}
}
erase_useless_column(block);
applyNullToNotMatchedRows(block, sample_block_without_keys, *filter_column);
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1);
return;
}
if (is_semi_family)
{
erase_useless_column(block);
/// for semi/anti join, filter out not matched rows
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1);
Expand Down Expand Up @@ -1080,8 +1077,7 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
{
auto & col_name = input_block.getByPosition(i).name;
if ((!flag_mapped_entry_helper_name.empty() && col_name == flag_mapped_entry_helper_name)
|| output_column_names_set_after_finalize.find(col_name)
!= output_column_names_set_after_finalize.end())
|| output_column_names_set_after_finalize.contains(col_name))
++i;
else
input_block.erase(i);
Expand Down Expand Up @@ -1123,10 +1119,10 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
matched_row_count_in_current_block = countBytesInFilter(filter);
probe_process_info.cross_join_data->has_row_matched |= matched_row_count_in_current_block != 0;
}
erase_useless_column(block);
/// case 1, inner join
if (kind == ASTTableJoin::Kind::Cross)
{
erase_useless_column(block);
if (matched_row_count_in_current_block > 0)
{
for (size_t i = 0; i < block.columns(); ++i)
Expand All @@ -1144,7 +1140,6 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
{
if (matched_row_count_in_current_block > 0)
{
erase_useless_column(block);
for (size_t i = 0; i < block.columns(); ++i)
block.safeGetByPosition(i).column
= block.safeGetByPosition(i).column->filter(filter, matched_row_count_in_current_block);
Expand All @@ -1155,36 +1150,17 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->cut(0, 1);
filter.resize(1);
for (size_t right_table_column : probe_process_info.cross_join_data->right_column_index_in_result_block)
{
auto & column = block.getByPosition(right_table_column);
if (output_column_names_set_after_finalize.contains(column.name))
{
auto full_column
= column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column;
if (!full_column->isColumnNullable())
{
throw Exception("Should not reach here, the right table column for left join must be nullable");
}
auto current_column = full_column;
auto result_column = (*std::move(current_column)).mutate();
static_cast<ColumnNullable &>(*result_column).applyNegatedNullMap(*filter_column);
column.column = std::move(result_column);
}
}
erase_useless_column(block);
applyNullToNotMatchedRows(block, sample_block_without_keys, *filter_column);
}
else
{
erase_useless_column(block);
block = block.cloneEmpty();
}
return;
}
/// case 3, semi join
if (kind == ASTTableJoin::Kind::Cross_Semi)
{
erase_useless_column(block);
if (probe_process_info.cross_join_data->has_row_matched)
{
/// has matched rows, return the first row, and set the current row probe done
Expand All @@ -1202,7 +1178,6 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
/// case 4, anti join
if (kind == ASTTableJoin::Kind::Cross_Anti)
{
erase_useless_column(block);
if (probe_process_info.cross_join_data->has_row_matched)
{
block = block.cloneEmpty();
Expand All @@ -1222,7 +1197,6 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
/// case 5, left outer semi join
if (isLeftOuterSemiFamily(kind))
{
erase_useless_column(block);
if (probe_process_info.cross_join_data->has_row_matched || probe_process_info.isCurrentProbeRowFinished())
{
for (size_t i = 0; i < block.columns(); ++i)
Expand Down Expand Up @@ -1284,14 +1258,6 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui
num_columns_to_add.push_back(i);
}

std::vector<size_t> right_table_column_indexes;
right_table_column_indexes.reserve(num_columns_to_add.size());

for (size_t i = 0; i < num_columns_to_add.size(); ++i)
{
right_table_column_indexes.push_back(i + existing_columns);
}

MutableColumns added_columns;
added_columns.reserve(num_columns_to_add.size());

Expand Down Expand Up @@ -1389,7 +1355,7 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui
if (has_other_condition)
{
assert(offsets_to_replicate != nullptr);
handleOtherConditions(block, nullptr, offsets_to_replicate.get(), right_table_column_indexes);
handleOtherConditions(block, nullptr, offsets_to_replicate.get());

if (useRowFlaggedHashMap(kind, has_other_condition))
{
Expand Down Expand Up @@ -1484,8 +1450,7 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const
handleOtherConditions(
block,
probe_process_info.filter.get(),
probe_process_info.offsets_to_replicate.get(),
probe_process_info.cross_join_data->right_column_index_in_result_block);
probe_process_info.offsets_to_replicate.get());
}
return block;
}
Expand Down Expand Up @@ -1528,8 +1493,7 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const
handleOtherConditions(
block,
probe_process_info.filter.get(),
probe_process_info.offsets_to_replicate.get(),
probe_process_info.cross_join_data->right_column_index_in_result_block);
probe_process_info.offsets_to_replicate.get());
}
return block;
}
Expand Down Expand Up @@ -2647,16 +2611,17 @@ void Join::finalize(const Names & parent_require)
updated_require.push_back(non_equal_conditions.other_eq_cond_from_in_name);
if (!non_equal_conditions.other_cond_name.empty())
updated_require.push_back(non_equal_conditions.other_cond_name);
auto keep_used_input_columns
= !isCrossJoin(kind) && (isNullAwareSemiFamily(kind) || isSemiFamily(kind) || isLeftOuterSemiFamily(kind));
/// nullaware/semi join will reuse the input columns so need to let finalize keep the input columns
if (non_equal_conditions.null_aware_eq_cond_expr != nullptr)
{
non_equal_conditions.null_aware_eq_cond_expr->finalize(updated_require, true);
non_equal_conditions.null_aware_eq_cond_expr->finalize(updated_require, keep_used_input_columns);
updated_require = non_equal_conditions.null_aware_eq_cond_expr->getRequiredColumns();
}
if (non_equal_conditions.other_cond_expr != nullptr)
{
/// todo don't keep input columns for non-semi/non-nullaware join
non_equal_conditions.other_cond_expr->finalize(updated_require, true);
non_equal_conditions.other_cond_expr->finalize(updated_require, keep_used_input_columns);
updated_require = non_equal_conditions.other_cond_expr->getRequiredColumns();
}
/// remove duplicated column
Expand Down
6 changes: 1 addition & 5 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,11 +481,7 @@ class Join
*
* @param block
*/
void handleOtherConditions(
Block & block,
IColumn::Filter * filter,
IColumn::Offsets * offsets_to_replicate,
const std::vector<size_t> & right_table_column) const;
void handleOtherConditions(Block & block, IColumn::Filter * filter, IColumn::Offsets * offsets_to_replicate) const;

void handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & probe_process_info) const;

Expand Down
3 changes: 0 additions & 3 deletions dbms/src/Interpreters/ProbeProcessInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ void ProbeProcessInfo::prepareForCrossProbe(
cross_join_data->right_column_index_in_right_block.push_back(i);
}
}
auto offset = cross_join_data->left_column_index_in_left_block.size();
for (size_t i = 0; i < cross_join_data->right_column_index_in_right_block.size(); ++i)
cross_join_data->right_column_index_in_result_block.push_back(offset + i);
}
if (cross_join_data->cross_probe_mode == CrossProbeMode::SHALLOW_COPY_RIGHT_BLOCK && null_map != nullptr)
cross_join_data->row_num_filtered_by_left_condition = countBytesInFilter(*null_map);
Expand Down
1 change: 0 additions & 1 deletion dbms/src/Interpreters/ProbeProcessInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ struct HashJoinProbeProcessData
struct CrossJoinProbeProcessData
{
Block result_block_schema;
std::vector<size_t> right_column_index_in_result_block;
std::vector<size_t> right_column_index_in_right_block;
std::vector<size_t> left_column_index_in_left_block;
size_t right_rows_to_be_added_when_matched = 0;
Expand Down