Allow using 'serverReturnFinalResult' to optimize server partitioned table #13208

Merged
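For context, a rough usage sketch of the option pair this PR wires up (not part of the diff; the SET syntax is the usual Pinot query-option mechanism, and the option key strings are assumed to mirror the QueryOptionKey constants in the changes below). serverReturnFinalResult targets the case where the group-by key is the partition column, so each server can fully finish its groups; the new serverReturnFinalResultKeyUnpartitioned variant appears to target the case where the aggregated column is partitioned but the group-by key is not, so servers still return final results and the broker merges them per key via the new mergeFinalResult hook.

// Illustrative queries only; the table and column names are made up.
// Group-by key equals the partition column: servers may sort, trim and finalize.
String keyPartitionedQuery =
    "SET serverReturnFinalResult=true; "
        + "SELECT memberId, DISTINCTCOUNT(eventId) FROM myTable GROUP BY memberId";

// Group-by key is not the partition column, but the DISTINCTCOUNT argument is:
// servers return final results and the broker merges them with mergeFinalResult.
String keyUnpartitionedQuery =
    "SET serverReturnFinalResultKeyUnpartitioned=true; "
        + "SELECT country, DISTINCTCOUNT(memberId) FROM myTable GROUP BY country";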
@@ -236,6 +236,10 @@ public static boolean isServerReturnFinalResult(Map<String, String> queryOptions
return Boolean.parseBoolean(queryOptions.get(QueryOptionKey.SERVER_RETURN_FINAL_RESULT));
}

public static boolean isServerReturnFinalResultKeyUnpartitioned(Map<String, String> queryOptions) {
return Boolean.parseBoolean(queryOptions.get(QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED));
}

@Nullable
public static String getOrderByAlgorithm(Map<String, String> queryOptions) {
return queryOptions.get(QueryOptionKey.ORDER_BY_ALGORITHM);
@@ -25,6 +25,7 @@
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import java.io.Closeable;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
@@ -42,7 +43,7 @@
import org.slf4j.LoggerFactory;


public class GrpcQueryClient {
public class GrpcQueryClient implements Closeable {
private static final Logger LOGGER = LoggerFactory.getLogger(GrpcQueryClient.class);
private static final int DEFAULT_CHANNEL_SHUTDOWN_TIMEOUT_SECOND = 10;
// the key is the hashCode of the TlsConfig, the value is the SslContext
@@ -74,9 +75,8 @@ private SslContext buildSslContext(TlsConfig tlsConfig) {
LOGGER.info("Building gRPC SSL context");
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
SSLFactory sslFactory =
RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(
tlsConfig, PinotInsecureMode::isPinotInInsecureMode);
SSLFactory sslFactory = RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(tlsConfig,
PinotInsecureMode::isPinotInInsecureMode);
SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
sslFactory.getKeyManagerFactory().ifPresent(sslContextBuilder::keyManager);
sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
@@ -98,6 +98,7 @@ public Iterator<Server.ServerResponse> submit(Server.ServerRequest request) {
return _blockingStub.submit(request);
}

@Override
public void close() {
if (!_managedChannel.isShutdown()) {
try {
@@ -34,7 +34,12 @@ public class ConcurrentIndexedTable extends IndexedTable {

public ConcurrentIndexedTable(DataSchema dataSchema, QueryContext queryContext, int resultSize, int trimSize,
int trimThreshold) {
super(dataSchema, queryContext, resultSize, trimSize, trimThreshold, new ConcurrentHashMap<>());
this(dataSchema, false, queryContext, resultSize, trimSize, trimThreshold);
}

public ConcurrentIndexedTable(DataSchema dataSchema, boolean hasFinalInput, QueryContext queryContext, int resultSize,
int trimSize, int trimThreshold) {
super(dataSchema, hasFinalInput, queryContext, resultSize, trimSize, trimThreshold, new ConcurrentHashMap<>());
}

/**
@@ -38,6 +38,7 @@
@SuppressWarnings({"rawtypes", "unchecked"})
public abstract class IndexedTable extends BaseTable {
protected final Map<Key, Record> _lookupMap;
protected final boolean _hasFinalInput;
Contributor:
Since this IndexedTable has multiple aggregations, does _hasFinalInput also need to be per aggregation function, or is it just a single indicator?
E.g. SELECT distinctCount(pk), avg(other_column) from myTable

Contributor (author):
In the V1 engine we don't have a way to pass a per-function hint. We can make it smarter in V2, which can carry more information.

protected final int _resultSize;
protected final int _numKeyColumns;
protected final AggregationFunction[] _aggregationFunctions;
@@ -54,16 +55,18 @@ public abstract class IndexedTable extends BaseTable {
* Constructor for the IndexedTable.
*
* @param dataSchema Data schema of the table
* @param hasFinalInput Whether the input is the final aggregate result
* @param queryContext Query context
* @param resultSize Number of records to keep in the final result after calling {@link #finish(boolean, boolean)}
* @param trimSize Number of records to keep when trimming the table
* @param trimThreshold Trim the table when the number of records exceeds the threshold
* @param lookupMap Map from keys to records
*/
protected IndexedTable(DataSchema dataSchema, QueryContext queryContext, int resultSize, int trimSize,
int trimThreshold, Map<Key, Record> lookupMap) {
protected IndexedTable(DataSchema dataSchema, boolean hasFinalInput, QueryContext queryContext, int resultSize,
int trimSize, int trimThreshold, Map<Key, Record> lookupMap) {
super(dataSchema);
_lookupMap = lookupMap;
_hasFinalInput = hasFinalInput;
_resultSize = resultSize;

List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions();
@@ -74,7 +77,7 @@ protected IndexedTable(DataSchema dataSchema, QueryContext queryContext, int res
if (orderByExpressions != null) {
// GROUP BY with ORDER BY
_hasOrderBy = true;
_tableResizer = new TableResizer(dataSchema, queryContext);
_tableResizer = new TableResizer(dataSchema, hasFinalInput, queryContext);
// NOTE: trimSize is bounded by trimThreshold/2 to protect the server from using too much memory.
// TODO: Re-evaluate it as it can lead to in-accurate results
_trimSize = Math.min(trimSize, trimThreshold / 2);
@@ -102,34 +105,32 @@ public boolean upsert(Record record) {
* Adds a record with new key or updates a record with existing key.
*/
protected void addOrUpdateRecord(Key key, Record newRecord) {
_lookupMap.compute(key, (k, v) -> {
if (v == null) {
return newRecord;
} else {
Object[] existingValues = v.getValues();
Object[] newValues = newRecord.getValues();
int aggNum = 0;
for (int i = _numKeyColumns; i < _numColumns; i++) {
existingValues[i] = _aggregationFunctions[aggNum++].merge(existingValues[i], newValues[i]);
}
return v;
}
});
_lookupMap.compute(key, (k, v) -> v == null ? newRecord : updateRecord(v, newRecord));
}

/**
* Updates a record with existing key. Record with new key will be ignored.
*/
protected void updateExistingRecord(Key key, Record newRecord) {
_lookupMap.computeIfPresent(key, (k, v) -> {
Object[] existingValues = v.getValues();
Object[] newValues = newRecord.getValues();
int aggNum = 0;
for (int i = _numKeyColumns; i < _numColumns; i++) {
existingValues[i] = _aggregationFunctions[aggNum++].merge(existingValues[i], newValues[i]);
_lookupMap.computeIfPresent(key, (k, v) -> updateRecord(v, newRecord));
}

private Record updateRecord(Record existingRecord, Record newRecord) {
Object[] existingValues = existingRecord.getValues();
Object[] newValues = newRecord.getValues();
int numAggregations = _aggregationFunctions.length;
int index = _numKeyColumns;
if (!_hasFinalInput) {
for (int i = 0; i < numAggregations; i++, index++) {
existingValues[index] = _aggregationFunctions[i].merge(existingValues[index], newValues[index]);
}
return v;
});
} else {
for (int i = 0; i < numAggregations; i++, index++) {
existingValues[index] = _aggregationFunctions[i].mergeFinalResult((Comparable) existingValues[index],
(Comparable) newValues[index]);
}
}
return existingRecord;
}

/**
@@ -156,7 +157,8 @@ public void finish(boolean sort, boolean storeFinalResult) {
_topRecords = _lookupMap.values();
}
// TODO: Directly return final result in _tableResizer.getTopRecords to avoid extracting final result multiple times
if (storeFinalResult) {
assert !(_hasFinalInput && !storeFinalResult);
if (storeFinalResult && !_hasFinalInput) {
ColumnDataType[] columnDataTypes = _dataSchema.getColumnDataTypes();
int numAggregationFunctions = _aggregationFunctions.length;
for (int i = 0; i < numAggregationFunctions; i++) {
@@ -32,7 +32,12 @@ public class SimpleIndexedTable extends IndexedTable {

public SimpleIndexedTable(DataSchema dataSchema, QueryContext queryContext, int resultSize, int trimSize,
int trimThreshold) {
super(dataSchema, queryContext, resultSize, trimSize, trimThreshold, new HashMap<>());
this(dataSchema, false, queryContext, resultSize, trimSize, trimThreshold);
}

public SimpleIndexedTable(DataSchema dataSchema, boolean hasFinalInput, QueryContext queryContext, int resultSize,
int trimSize, int trimThreshold) {
super(dataSchema, hasFinalInput, queryContext, resultSize, trimSize, trimThreshold, new HashMap<>());
}

/**
@@ -50,6 +50,7 @@
@SuppressWarnings({"rawtypes", "unchecked"})
public class TableResizer {
private final DataSchema _dataSchema;
private final boolean _hasFinalInput;
private final int _numGroupByExpressions;
private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap;
private final AggregationFunction[] _aggregationFunctions;
@@ -61,7 +62,12 @@ public class TableResizer {
private final Comparator<IntermediateRecord> _intermediateRecordComparator;

public TableResizer(DataSchema dataSchema, QueryContext queryContext) {
this(dataSchema, false, queryContext);
}

public TableResizer(DataSchema dataSchema, boolean hasFinalInput, QueryContext queryContext) {
_dataSchema = dataSchema;
_hasFinalInput = hasFinalInput;

// NOTE: The data schema will always have group-by expressions in the front, followed by aggregation functions of
// the same order as in the query context. This is handled in AggregationGroupByOrderByOperator.
@@ -144,16 +150,20 @@ private OrderByValueExtractor getOrderByValueExtractor(ExpressionContext express
expression);
if (function.getType() == FunctionContext.Type.AGGREGATION) {
// Aggregation function
return new AggregationFunctionExtractor(_aggregationFunctionIndexMap.get(function));
} else if (function.getType() == FunctionContext.Type.TRANSFORM
&& "FILTER".equalsIgnoreCase(function.getFunctionName())) {
int index = _aggregationFunctionIndexMap.get(function);
// For final aggregate result, we can handle it the same way as group key
return _hasFinalInput ? new GroupByExpressionExtractor(_numGroupByExpressions + index)
: new AggregationFunctionExtractor(index);
} else if (function.getType() == FunctionContext.Type.TRANSFORM && "FILTER".equalsIgnoreCase(
function.getFunctionName())) {
// Filtered aggregation
FunctionContext aggregation = function.getArguments().get(0).getFunction();
ExpressionContext filterExpression = function.getArguments().get(1);
FilterContext filter = RequestContextUtils.getFilter(filterExpression);

int functionIndex = _filteredAggregationIndexMap.get(Pair.of(aggregation, filter));
AggregationFunction aggregationFunction = _filteredAggregationFunctions.get(functionIndex).getLeft();
return new AggregationFunctionExtractor(functionIndex, aggregationFunction);
int index = _filteredAggregationIndexMap.get(Pair.of(aggregation, filter));
// For final aggregate result, we can handle it the same way as group key
return _hasFinalInput ? new GroupByExpressionExtractor(_numGroupByExpressions + index)
: new AggregationFunctionExtractor(index, _filteredAggregationFunctions.get(index).getLeft());
} else {
// Post-aggregation function
return new PostAggregationFunctionExtractor(function);
@@ -36,7 +36,12 @@
public class UnboundedConcurrentIndexedTable extends ConcurrentIndexedTable {

public UnboundedConcurrentIndexedTable(DataSchema dataSchema, QueryContext queryContext, int resultSize) {
super(dataSchema, queryContext, resultSize, Integer.MAX_VALUE, Integer.MAX_VALUE);
this(dataSchema, false, queryContext, resultSize);
}

public UnboundedConcurrentIndexedTable(DataSchema dataSchema, boolean hasFinalInput, QueryContext queryContext,
int resultSize) {
super(dataSchema, hasFinalInput, queryContext, resultSize, Integer.MAX_VALUE, Integer.MAX_VALUE);
}

@Override
@@ -239,10 +239,12 @@ public BaseResultsBlock mergeResults()
}

IndexedTable indexedTable = _indexedTable;
if (!_queryContext.isServerReturnFinalResult()) {
indexedTable.finish(false);
} else {
if (_queryContext.isServerReturnFinalResult()) {
indexedTable.finish(true, true);
} else if (_queryContext.isServerReturnFinalResultKeyUnpartitioned()) {
indexedTable.finish(false, true);
} else {
indexedTable.finish(false);
}
GroupByResultsBlock mergedBlock = new GroupByResultsBlock(indexedTable, _queryContext);
mergedBlock.setNumGroupsLimitReached(_numGroupsLimitReached);
@@ -244,10 +244,12 @@ private BaseResultsBlock getFinalResult()
}

IndexedTable indexedTable = _indexedTable;
if (!_queryContext.isServerReturnFinalResult()) {
indexedTable.finish(false);
} else {
if (_queryContext.isServerReturnFinalResult()) {
indexedTable.finish(true, true);
} else if (_queryContext.isServerReturnFinalResultKeyUnpartitioned()) {
indexedTable.finish(false, true);
} else {
indexedTable.finish(false);
}
GroupByResultsBlock mergedBlock = new GroupByResultsBlock(indexedTable, _queryContext);
mergedBlock.setNumGroupsLimitReached(_numGroupsLimitReached);
@@ -124,6 +124,14 @@ void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder
*/
FinalResult extractFinalResult(IntermediateResult intermediateResult);

/**
* Merges two final results. This can be used to optimize certain functions (e.g. DISTINCT_COUNT) when data is
* partitioned on each server, where we may directly request servers to return the final result and merge the
* results on the broker.
*/
default FinalResult mergeFinalResult(FinalResult finalResult1, FinalResult finalResult2) {
throw new UnsupportedOperationException("Cannot merge final results for function: " + getType());
}

/** @return Description of this operator for Explain Plan */
String toExplainString();
}
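To make the intent of the new default method concrete, here is a minimal broker-side sketch (not from this PR; the helper, its java.util.List parameter, and the simplified generics are assumptions): once every server has produced a final result for a given group key, the broker folds those values with mergeFinalResult instead of merging intermediate results.

// Hypothetical broker-side helper: folds per-server *final* results for one group key.
// Functions that do not support the optimization surface the default
// UnsupportedOperationException thrown by mergeFinalResult.
static <F extends Comparable<F>> F mergeServerFinalResults(
    AggregationFunction<?, F> aggregationFunction, List<F> perServerFinalResults) {
  F merged = perServerFinalResults.get(0);
  for (int i = 1; i < perServerFinalResults.size(); i++) {
    merged = aggregationFunction.mergeFinalResult(merged, perServerFinalResults.get(i));
  }
  return merged;
}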
@@ -18,6 +18,11 @@
*/
package org.apache.pinot.core.query.aggregation.function;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.floats.FloatArrayList;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Collections;
@@ -141,7 +146,7 @@ public static Map<ExpressionContext, BlockValSet> getBlockValSetMap(
* TODO: Move ser/de into AggregationFunction interface
*/
public static Object getIntermediateResult(DataTable dataTable, ColumnDataType columnDataType, int rowId, int colId) {
switch (columnDataType) {
switch (columnDataType.getStoredType()) {
case INT:
return dataTable.getInt(rowId, colId);
case LONG:
@@ -156,9 +161,43 @@ public static Object getIntermediateResult(DataTable dataTable, ColumnDataType c
}
}

/**
* Reads the final result from the {@link DataTable}.
*/
public static Comparable getFinalResult(DataTable dataTable, ColumnDataType columnDataType, int rowId, int colId) {
switch (columnDataType.getStoredType()) {
case INT:
return dataTable.getInt(rowId, colId);
case LONG:
return dataTable.getLong(rowId, colId);
case FLOAT:
return dataTable.getFloat(rowId, colId);
case DOUBLE:
return dataTable.getDouble(rowId, colId);
case BIG_DECIMAL:
return dataTable.getBigDecimal(rowId, colId);
case STRING:
return dataTable.getString(rowId, colId);
case BYTES:
return dataTable.getBytes(rowId, colId);
case INT_ARRAY:
return IntArrayList.wrap(dataTable.getIntArray(rowId, colId));
case LONG_ARRAY:
return LongArrayList.wrap(dataTable.getLongArray(rowId, colId));
case FLOAT_ARRAY:
return FloatArrayList.wrap(dataTable.getFloatArray(rowId, colId));
case DOUBLE_ARRAY:
return DoubleArrayList.wrap(dataTable.getDoubleArray(rowId, colId));
case STRING_ARRAY:
return ObjectArrayList.wrap(dataTable.getStringArray(rowId, colId));
default:
throw new IllegalStateException("Illegal column data type in final result: " + columnDataType);
}
}

/**
* Reads the converted final result from the {@link DataTable}. It should be equivalent to running
* {@link AggregationFunction#extractFinalResult(Object)} and {@link ColumnDataType#convert(Object)}.
* {@link #getFinalResult} and {@link ColumnDataType#convert}.
*/
public static Object getConvertedFinalResult(DataTable dataTable, ColumnDataType columnDataType, int rowId,
int colId) {
@@ -246,6 +246,11 @@ public Integer extractFinalResult(Integer intermediateResult) {
return intermediateResult;
}

@Override
public Integer mergeFinalResult(Integer finalResult1, Integer finalResult2) {
return merge(finalResult1, finalResult2);
}

private int getInt(Integer val) {
return val == null ? _merger.getDefaultValue() : val;
}
@@ -119,6 +119,11 @@ public final Long extractFinalResult(Long longValue) {
return 0L;
}

@Override
public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
return 0L;
}

/**
* The name of the column as follows:
* CHILD_AGGREGATION_NAME_PREFIX + actual function type + operands + CHILD_AGGREGATION_SEPERATOR
@@ -204,6 +204,11 @@ public Long extractFinalResult(Long intermediateResult) {
return intermediateResult;
}

@Override
public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
return finalResult1 + finalResult2;
}

@Override
public String toExplainString() {
StringBuilder stringBuilder = new StringBuilder(getType().getName()).append('(');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,9 @@ public ColumnDataType getFinalResultColumnType() {
public Integer extractFinalResult(Set intermediateResult) {
return intermediateResult.size();
}

@Override
public Integer mergeFinalResult(Integer finalResult1, Integer finalResult2) {
return finalResult1 + finalResult2;
}
}
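A quick worked check on the summing merge above (numbers made up): when the DISTINCTCOUNT argument is the partition column, the value sets seen by different servers for the same group are disjoint, so adding their final counts is exact.

// Made-up figures: both servers report a final DISTINCTCOUNT(memberId) for group "US".
int serverOneFinal = 120;  // distinct memberIds on server 1's partitions
int serverTwoFinal = 80;   // distinct memberIds on server 2's partitions (disjoint set)
int brokerMerged = serverOneFinal + serverTwoFinal;  // 200 -- exact only because partitioning keeps the sets disjoint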