Skip to content

Commit

Permalink
refine handle_other_conditions in join
Browse files Browse the repository at this point in the history
Signed-off-by: xufei <[email protected]>
  • Loading branch information
windtalker committed Jan 8, 2024
1 parent 383e1bd commit b261cdd
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 73 deletions.
100 changes: 36 additions & 64 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,8 @@ void Join::cancelRuntimeFilter(const String & reason)
}
}

namespace
{
void mergeNullAndFilterResult(
Block & block,
ColumnVector<UInt8>::Container & filter_column,
Expand Down Expand Up @@ -833,6 +835,27 @@ void mergeNullAndFilterResult(
}
}
}
/**
  * Apply the negated `filter_column` as a null map to the right-table columns of `block`.
  *
  * A column of `block` is treated as a right-table column when `right_columns` has a column
  * with the same name; all other columns are left untouched.
  * NOTE(review): presumably rows whose filter value is 0 (not matched) become NULL - confirm
  * against ColumnNullable::applyNegatedNullMap.
  *
  * @param block          block whose right-table columns are rewritten in place
  * @param right_columns  block used only for name lookup of the right-table columns
  * @param filter_column  UInt8 column handed to ColumnNullable::applyNegatedNullMap
  * @throws Exception if a selected column is not nullable after const expansion
  */
void applyNullToNotMatchedRows(Block & block, const Block & right_columns, const ColumnUInt8 & filter_column)
{
    const size_t total_columns = block.columns();
    for (size_t pos = 0; pos != total_columns; ++pos)
    {
        auto & entry = block.getByPosition(pos);
        /// Skip every column that is not present in right_columns.
        if (!right_columns.has(entry.name))
            continue;

        /// Expand a constant column into a full column before inspecting it.
        auto materialized = entry.column;
        if (materialized->isColumnConst())
            materialized = materialized->convertToFullColumnIfConst();

        if (!materialized->isColumnNullable())
            throw Exception("Should not reach here, the right table column for left join must be nullable");

        /// Take a separate handle so that mutate() consumes it while `materialized` stays valid.
        auto detached = materialized;
        auto nullable_result = (*std::move(detached)).mutate();
        static_cast<ColumnNullable &>(*nullable_result).applyNegatedNullMap(filter_column);
        entry.column = std::move(nullable_result);
    }
}
} // namespace

/**
* handle other join conditions
Expand All @@ -847,11 +870,8 @@ void mergeNullAndFilterResult(
* @param left_table_columns
* @param right_table_columns
*/
void Join::handleOtherConditions(
Block & block,
IColumn::Filter * anti_filter,
IColumn::Offsets * offsets_to_replicate,
const std::vector<size_t> & right_table_columns) const
void Join::handleOtherConditions(Block & block, IColumn::Filter * anti_filter, IColumn::Offsets * offsets_to_replicate)
const
{
/// save block_rows because block.rows() returns the first column's size; after other_cond_expr->execute(block),
/// some columns may be removed, and the first column may be the match_helper_column, which does not have the same size
Expand Down Expand Up @@ -1023,34 +1043,17 @@ void Join::handleOtherConditions(
}
prev_offset = current_offset;
}
erase_useless_column(block);
if (isLeftOuterJoin(kind))
{
/// for left join, convert right column to null if not joined
for (size_t right_table_column : right_table_columns)
{
auto & column = block.getByPosition(right_table_column);
if (output_column_names_set_after_finalize.contains(column.name))
{
auto full_column
= column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column;
if (!full_column->isColumnNullable())
{
throw Exception("Should not reach here, the right table column for left join must be nullable");
}
auto current_column = full_column;
auto result_column = (*std::move(current_column)).mutate();
static_cast<ColumnNullable &>(*result_column).applyNegatedNullMap(*filter_column);
column.column = std::move(result_column);
}
}
erase_useless_column(block);
applyNullToNotMatchedRows(block, sample_block_without_keys, *filter_column);
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1);
return;
}
if (is_semi_family)
{
erase_useless_column(block);
/// for semi/anti join, filter out not matched rows
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1);
Expand Down Expand Up @@ -1123,10 +1126,10 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
matched_row_count_in_current_block = countBytesInFilter(filter);
probe_process_info.cross_join_data->has_row_matched |= matched_row_count_in_current_block != 0;
}
erase_useless_column(block);
/// case 1, inner join
if (kind == ASTTableJoin::Kind::Cross)
{
erase_useless_column(block);
if (matched_row_count_in_current_block > 0)
{
for (size_t i = 0; i < block.columns(); ++i)
Expand All @@ -1144,7 +1147,6 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
{
if (matched_row_count_in_current_block > 0)
{
erase_useless_column(block);
for (size_t i = 0; i < block.columns(); ++i)
block.safeGetByPosition(i).column
= block.safeGetByPosition(i).column->filter(filter, matched_row_count_in_current_block);
Expand All @@ -1155,36 +1157,17 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->cut(0, 1);
filter.resize(1);
for (size_t right_table_column : probe_process_info.cross_join_data->right_column_index_in_result_block)
{
auto & column = block.getByPosition(right_table_column);
if (output_column_names_set_after_finalize.contains(column.name))
{
auto full_column
= column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column;
if (!full_column->isColumnNullable())
{
throw Exception("Should not reach here, the right table column for left join must be nullable");
}
auto current_column = full_column;
auto result_column = (*std::move(current_column)).mutate();
static_cast<ColumnNullable &>(*result_column).applyNegatedNullMap(*filter_column);
column.column = std::move(result_column);
}
}
erase_useless_column(block);
applyNullToNotMatchedRows(block, sample_block_without_keys, *filter_column);
}
else
{
erase_useless_column(block);
block = block.cloneEmpty();
}
return;
}
/// case 3, semi join
if (kind == ASTTableJoin::Kind::Cross_Semi)
{
erase_useless_column(block);
if (probe_process_info.cross_join_data->has_row_matched)
{
/// has matched rows, return the first row, and set the current row probe done
Expand All @@ -1202,7 +1185,6 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
/// case 4, anti join
if (kind == ASTTableJoin::Kind::Cross_Anti)
{
erase_useless_column(block);
if (probe_process_info.cross_join_data->has_row_matched)
{
block = block.cloneEmpty();
Expand All @@ -1222,7 +1204,6 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo &
/// case 5, left outer semi join
if (isLeftOuterSemiFamily(kind))
{
erase_useless_column(block);
if (probe_process_info.cross_join_data->has_row_matched || probe_process_info.isCurrentProbeRowFinished())
{
for (size_t i = 0; i < block.columns(); ++i)
Expand Down Expand Up @@ -1284,14 +1265,6 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui
num_columns_to_add.push_back(i);
}

std::vector<size_t> right_table_column_indexes;
right_table_column_indexes.reserve(num_columns_to_add.size());

for (size_t i = 0; i < num_columns_to_add.size(); ++i)
{
right_table_column_indexes.push_back(i + existing_columns);
}

MutableColumns added_columns;
added_columns.reserve(num_columns_to_add.size());

Expand Down Expand Up @@ -1389,7 +1362,7 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui
if (has_other_condition)
{
assert(offsets_to_replicate != nullptr);
handleOtherConditions(block, nullptr, offsets_to_replicate.get(), right_table_column_indexes);
handleOtherConditions(block, nullptr, offsets_to_replicate.get());

if (useRowFlaggedHashMap(kind, has_other_condition))
{
Expand Down Expand Up @@ -1484,8 +1457,7 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const
handleOtherConditions(
block,
probe_process_info.filter.get(),
probe_process_info.offsets_to_replicate.get(),
probe_process_info.cross_join_data->right_column_index_in_result_block);
probe_process_info.offsets_to_replicate.get());
}
return block;
}
Expand Down Expand Up @@ -1528,8 +1500,7 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const
handleOtherConditions(
block,
probe_process_info.filter.get(),
probe_process_info.offsets_to_replicate.get(),
probe_process_info.cross_join_data->right_column_index_in_result_block);
probe_process_info.offsets_to_replicate.get());
}
return block;
}
Expand Down Expand Up @@ -2647,16 +2618,17 @@ void Join::finalize(const Names & parent_require)
updated_require.push_back(non_equal_conditions.other_eq_cond_from_in_name);
if (!non_equal_conditions.other_cond_name.empty())
updated_require.push_back(non_equal_conditions.other_cond_name);
auto keep_used_input_columns
= !isCrossJoin(kind) && (isNullAwareSemiFamily(kind) || isSemiFamily(kind) || isLeftOuterSemiFamily(kind));
/// nullaware/semi join will reuse the input columns so need to let finalize keep the input columns
if (non_equal_conditions.null_aware_eq_cond_expr != nullptr)
{
non_equal_conditions.null_aware_eq_cond_expr->finalize(updated_require, true);
non_equal_conditions.null_aware_eq_cond_expr->finalize(updated_require, keep_used_input_columns);
updated_require = non_equal_conditions.null_aware_eq_cond_expr->getRequiredColumns();
}
if (non_equal_conditions.other_cond_expr != nullptr)
{
/// todo don't keep input columns for non-semi/non-nullaware join
non_equal_conditions.other_cond_expr->finalize(updated_require, true);
non_equal_conditions.other_cond_expr->finalize(updated_require, keep_used_input_columns);
updated_require = non_equal_conditions.other_cond_expr->getRequiredColumns();
}
/// remove duplicated column
Expand Down
6 changes: 1 addition & 5 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,11 +481,7 @@ class Join
*
* @param block
*/
void handleOtherConditions(
Block & block,
IColumn::Filter * filter,
IColumn::Offsets * offsets_to_replicate,
const std::vector<size_t> & right_table_column) const;
void handleOtherConditions(Block & block, IColumn::Filter * filter, IColumn::Offsets * offsets_to_replicate) const;

void handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & probe_process_info) const;

Expand Down
3 changes: 0 additions & 3 deletions dbms/src/Interpreters/ProbeProcessInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ void ProbeProcessInfo::prepareForCrossProbe(
cross_join_data->right_column_index_in_right_block.push_back(i);
}
}
auto offset = cross_join_data->left_column_index_in_left_block.size();
for (size_t i = 0; i < cross_join_data->right_column_index_in_right_block.size(); ++i)
cross_join_data->right_column_index_in_result_block.push_back(offset + i);
}
if (cross_join_data->cross_probe_mode == CrossProbeMode::SHALLOW_COPY_RIGHT_BLOCK && null_map != nullptr)
cross_join_data->row_num_filtered_by_left_condition = countBytesInFilter(*null_map);
Expand Down
1 change: 0 additions & 1 deletion dbms/src/Interpreters/ProbeProcessInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ struct HashJoinProbeProcessData
struct CrossJoinProbeProcessData
{
Block result_block_schema;
std::vector<size_t> right_column_index_in_result_block;
std::vector<size_t> right_column_index_in_right_block;
std::vector<size_t> left_column_index_in_left_block;
size_t right_rows_to_be_added_when_matched = 0;
Expand Down

0 comments on commit b261cdd

Please sign in to comment.