diff --git a/pom.xml b/pom.xml index f641b5dbcea09..8180462d9dee3 100644 --- a/pom.xml +++ b/pom.xml @@ -2197,6 +2197,18 @@ opentelemetry-semconv 1.19.0-alpha + + + org.apache.datasketches + datasketches-memory + 2.2.0 + + + + org.apache.datasketches + datasketches-java + 4.2.0 + diff --git a/presto-docs/src/main/sphinx/functions.rst b/presto-docs/src/main/sphinx/functions.rst index e3a63772292fa..7ac3846d4ccd8 100644 --- a/presto-docs/src/main/sphinx/functions.rst +++ b/presto-docs/src/main/sphinx/functions.rst @@ -35,3 +35,4 @@ Functions and Operators functions/teradata functions/internationalization functions/setdigest + functions/sketch diff --git a/presto-docs/src/main/sphinx/functions/sketch.rst b/presto-docs/src/main/sphinx/functions/sketch.rst new file mode 100644 index 0000000000000..92f1d9307f679 --- /dev/null +++ b/presto-docs/src/main/sphinx/functions/sketch.rst @@ -0,0 +1,32 @@ +=========================== +Sketch Functions +=========================== + +Sketches are data structures that can approximately answer particular questions +about a dataset when full accuracy is not required. The benefit of approximate +answers is that they are often faster and more efficient to compute than +functions that return exact results. + +Presto provides support for computing some sketches available in the `Apache +DataSketches`_ library. + +.. function:: sketch_theta(data) -> varbinary + + Computes a `theta sketch`_ from an input dataset. The output from + this function can be used as an input to any other function in the + ``sketch_theta_*`` family. + +.. function:: sketch_theta_estimate(sketch) -> double + + Returns the estimate of distinct values from the input sketch. + +.. 
function:: sketch_theta_summary(sketch) -> row(estimate double, theta double, upper_bound_std double, lower_bound_std double, retained_entries int) + + Returns a summary of the input sketch which includes the distinct values + estimate alongside other useful information such as the sketch theta + parameter, current error bounds corresponding to 1 standard deviation, and + the number of retained entries in the sketch. + + +.. _Apache DataSketches: https://datasketches.apache.org/ +.. _theta sketch: https://datasketches.apache.org/docs/Theta/ThetaSketchFramework.html \ No newline at end of file diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java index 536f934e83587..682aceb949570 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java @@ -3201,7 +3201,7 @@ private List getColumnStatisticMetadataForTemporaryTabl private List getColumnStatisticMetadata(String columnName, Set statisticTypes) { return statisticTypes.stream() - .map(type -> new ColumnStatisticMetadata(columnName, type)) + .map(type -> type.getColumnStatisticMetadata(columnName)) .collect(toImmutableList()); } diff --git a/presto-iceberg/pom.xml b/presto-iceberg/pom.xml index c900e0615bb06..23909cc57d0d4 100644 --- a/presto-iceberg/pom.xml +++ b/presto-iceberg/pom.xml @@ -453,6 +453,15 @@ provided + + org.apache.datasketches + datasketches-java + + + org.apache.datasketches + datasketches-memory + + com.facebook.presto diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java index b86d413ee51a2..bf001e68eafb8 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java 
@@ -21,6 +21,7 @@ import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.HivePrivilegeInfo; import com.facebook.presto.hive.metastore.MetastoreContext; +import com.facebook.presto.hive.metastore.PartitionStatistics; import com.facebook.presto.hive.metastore.PrestoTableType; import com.facebook.presto.hive.metastore.PrincipalPrivileges; import com.facebook.presto.hive.metastore.StorageFormat; @@ -305,7 +306,11 @@ public void commit(@Nullable TableMetadata base, TableMetadata metadata) metastore.createTable(metastoreContext, table, privileges); } else { + PartitionStatistics tableStats = metastore.getTableStatistics(metastoreContext, database, tableName); metastore.replaceTable(metastoreContext, database, tableName, table, privileges); + + // attempt to put back previous table statistics + metastore.updateTableStatistics(metastoreContext, database, tableName, oldStats -> tableStats); } } finally { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java index 78516400a9948..6e0dbb93c6421 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java @@ -18,6 +18,7 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.hive.HiveWrittenPartitions; +import com.facebook.presto.hive.NodeVersion; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorInsertTableHandle; @@ -37,11 +38,15 @@ import com.facebook.presto.spi.TableNotFoundException; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.connector.ConnectorOutputMetadata; +import 
com.facebook.presto.spi.statistics.ColumnStatisticMetadata; import com.facebook.presto.spi.statistics.ComputedStatistics; +import com.facebook.presto.spi.statistics.TableStatisticType; import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.statistics.TableStatisticsMetadata; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import org.apache.iceberg.AppendFiles; import org.apache.iceberg.BaseTable; @@ -65,6 +70,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -89,11 +95,13 @@ import static com.facebook.presto.iceberg.PartitionFields.getPartitionColumnName; import static com.facebook.presto.iceberg.PartitionFields.getTransformTerm; import static com.facebook.presto.iceberg.PartitionFields.toPartitionFields; +import static com.facebook.presto.iceberg.TableStatisticsMaker.getSupportedColumnStatistics; import static com.facebook.presto.iceberg.TableType.DATA; import static com.facebook.presto.iceberg.TypeConverter.toIcebergType; import static com.facebook.presto.iceberg.TypeConverter.toPrestoType; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.statistics.TableStatisticType.ROW_COUNT; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -110,13 +118,15 @@ public abstract class IcebergAbstractMetadata protected final TypeManager typeManager; protected final JsonCodec commitTaskCodec; + protected final NodeVersion nodeVersion; protected 
Transaction transaction; - public IcebergAbstractMetadata(TypeManager typeManager, JsonCodec commitTaskCodec) + public IcebergAbstractMetadata(TypeManager typeManager, JsonCodec commitTaskCodec, NodeVersion nodeVersion) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); + this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); } protected abstract Table getIcebergTable(ConnectorSession session, SchemaTableName schemaTableName); @@ -342,6 +352,38 @@ protected static Schema toIcebergSchema(List columns) return new Schema(icebergSchema.asStructType().fields()); } + @Override + public ConnectorTableHandle getTableHandleForStatisticsCollection(ConnectorSession session, SchemaTableName tableName, Map analyzeProperties) + { + return getTableHandle(session, tableName); + } + + @Override + public TableStatisticsMetadata getStatisticsCollectionMetadata(ConnectorSession session, ConnectorTableMetadata tableMetadata) + { + Set columnStatistics = tableMetadata.getColumns().stream() + .filter(column -> !column.isHidden()) + .flatMap(meta -> getSupportedColumnStatistics(meta.getName(), meta.getType()).stream()) + .collect(toImmutableSet()); + + Set tableStatistics = ImmutableSet.of(ROW_COUNT); + return new TableStatisticsMetadata(columnStatistics, tableStatistics, Collections.emptyList()); + } + + @Override + public ConnectorTableHandle beginStatisticsCollection(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return tableHandle; + } + + @Override + public void finishStatisticsCollection(ConnectorSession session, ConnectorTableHandle tableHandle, Collection computedStatistics) + { + IcebergTableHandle icebergTableHandle = (IcebergTableHandle) tableHandle; + Table icebergTable = getIcebergTable(session, icebergTableHandle.getSchemaTableName()); + TableStatisticsMaker.writeTableStatistics(nodeVersion, typeManager, icebergTableHandle, icebergTable, 
session, computedStatistics); + } + public void rollback() { // TODO: cleanup open transaction @@ -419,7 +461,7 @@ public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTab { IcebergTableHandle handle = (IcebergTableHandle) tableHandle; Table icebergTable = getIcebergTable(session, handle.getSchemaTableName()); - return TableStatisticsMaker.getTableStatistics(typeManager, constraint, handle, icebergTable, columnHandles.stream().map(IcebergColumnHandle.class::cast).collect(Collectors.toList())); + return TableStatisticsMaker.getTableStatistics(session, typeManager, constraint, handle, icebergTable, columnHandles.stream().map(IcebergColumnHandle.class::cast).collect(Collectors.toList())); } @Override diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java index 293da507a3aab..25e11cb572451 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java @@ -44,6 +44,7 @@ public class IcebergConfig private double minimumAssignedSplitWeight = 0.05; private boolean parquetDereferencePushdownEnabled = true; private boolean mergeOnReadModeEnabled; + private double statisticSnapshotRecordDifferenceWeight; private HiveStatisticsMergeStrategy hiveStatisticsMergeStrategy = HiveStatisticsMergeStrategy.NONE; @@ -196,4 +197,19 @@ public HiveStatisticsMergeStrategy getHiveStatisticsMergeStrategy() { return hiveStatisticsMergeStrategy; } + + @Config("iceberg.statistic-snapshot-record-difference-weight") + @ConfigDescription("the amount that the difference in total record count matters when " + + "calculating the closest snapshot when picking statistics. 
A value of 1 means a single " + + "record is equivalent to 1 millisecond of time difference.") + public IcebergConfig setStatisticSnapshotRecordDifferenceWeight(double weight) + { + this.statisticSnapshotRecordDifferenceWeight = weight; + return this; + } + + public double getStatisticSnapshotRecordDifferenceWeight() + { + return statisticSnapshotRecordDifferenceWeight; + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java index 94ce43fed7065..99b02b8080e74 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java @@ -23,6 +23,7 @@ import com.facebook.presto.hive.HiveColumnConverterProvider; import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.hive.HiveTypeTranslator; +import com.facebook.presto.hive.NodeVersion; import com.facebook.presto.hive.TableAlreadyExistsException; import com.facebook.presto.hive.ViewAlreadyExistsException; import com.facebook.presto.hive.metastore.Column; @@ -135,7 +136,6 @@ public class IcebergHiveMetadata private static final Logger log = Logger.get(IcebergAbstractMetadata.class); private final ExtendedHiveMetastore metastore; private final HdfsEnvironment hdfsEnvironment; - private final String prestoVersion; private final DateTimeZone timeZone = DateTimeZone.forTimeZone(TimeZone.getTimeZone(ZoneId.of(TimeZone.getDefault().getID()))); private final FilterStatsCalculatorService filterStatsCalculatorService; @@ -146,14 +146,13 @@ public IcebergHiveMetadata( HdfsEnvironment hdfsEnvironment, TypeManager typeManager, JsonCodec commitTaskCodec, - String prestoVersion, + NodeVersion nodeVersion, FilterStatsCalculatorService filterStatsCalculatorService, RowExpressionService rowExpressionService) { - super(typeManager, commitTaskCodec); + super(typeManager, 
commitTaskCodec, nodeVersion); this.metastore = requireNonNull(metastore, "metastore is null"); this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); - this.prestoVersion = requireNonNull(prestoVersion, "prestoVersion is null"); this.filterStatsCalculatorService = requireNonNull(filterStatsCalculatorService, "filterStatsCalculatorService is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); } @@ -364,7 +363,7 @@ public void createView(ConnectorSession session, ConnectorTableMetadata viewMeta Table table = createTableObjectForViewCreation( session, viewMetadata, - createIcebergViewProperties(session, prestoVersion), + createIcebergViewProperties(session, nodeVersion.toString()), new HiveTypeTranslator(), metastoreContext, encodeViewData(viewData)); @@ -457,7 +456,7 @@ public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTab { IcebergTableHandle handle = (IcebergTableHandle) tableHandle; org.apache.iceberg.Table icebergTable = getHiveIcebergTable(metastore, hdfsEnvironment, session, handle.getSchemaTableName()); - TableStatistics icebergStatistics = TableStatisticsMaker.getTableStatistics(typeManager, constraint, handle, icebergTable, columnHandles.stream().map(IcebergColumnHandle.class::cast).collect(Collectors.toList())); + TableStatistics icebergStatistics = TableStatisticsMaker.getTableStatistics(session, typeManager, constraint, handle, icebergTable, columnHandles.stream().map(IcebergColumnHandle.class::cast).collect(Collectors.toList())); HiveStatisticsMergeStrategy mergeStrategy = getHiveStatisticsMergeStrategy(session); return tableLayoutHandle.map(IcebergTableLayoutHandle.class::cast).map(layoutHandle -> { TupleDomain predicate = layoutHandle.getTupleDomain().transform(icebergLayout -> { @@ -513,7 +512,7 @@ public TableStatisticsMetadata getStatisticsCollectionMetadata(ConnectorSession .filter(column -> !column.isHidden()) .flatMap(meta -> 
metastore.getSupportedColumnStatistics(getMetastoreContext(session), meta.getType()) .stream() - .map(statType -> new ColumnStatisticMetadata(meta.getName(), statType))) + .map(statType -> statType.getColumnStatisticMetadata(meta.getName()))) .collect(toImmutableSet()); Set tableStatistics = ImmutableSet.of(ROW_COUNT); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java index 51a0c553a0c07..dcc833320b9c5 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java @@ -33,7 +33,7 @@ public class IcebergHiveMetadataFactory final HdfsEnvironment hdfsEnvironment; final TypeManager typeManager; final JsonCodec commitTaskCodec; - final String prestoVersion; + final NodeVersion nodeVersion; final FilterStatsCalculatorService filterStatsCalculatorService; final RowExpressionService rowExpressionService; @@ -52,8 +52,7 @@ public IcebergHiveMetadataFactory( this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); - requireNonNull(nodeVersion, "nodeVersion is null"); - this.prestoVersion = nodeVersion.toString(); + this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); this.filterStatsCalculatorService = requireNonNull(filterStatsCalculatorService, "filterStatsCalculatorService is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); requireNonNull(config, "config is null"); @@ -61,6 +60,6 @@ public IcebergHiveMetadataFactory( public ConnectorMetadata create() { - return new IcebergHiveMetadata(metastore, hdfsEnvironment, typeManager, commitTaskCodec, 
prestoVersion, filterStatsCalculatorService, rowExpressionService); + return new IcebergHiveMetadata(metastore, hdfsEnvironment, typeManager, commitTaskCodec, nodeVersion, filterStatsCalculatorService, rowExpressionService); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java index 60d46f1547d82..622b22cddf3c4 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java @@ -15,6 +15,7 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.hive.NodeVersion; import com.facebook.presto.hive.TableAlreadyExistsException; import com.facebook.presto.iceberg.util.IcebergPrestoModelConverters; import com.facebook.presto.spi.ColumnMetadata; @@ -74,9 +75,10 @@ public IcebergNativeMetadata( IcebergResourceFactory resourceFactory, TypeManager typeManager, JsonCodec commitTaskCodec, - CatalogType catalogType) + CatalogType catalogType, + NodeVersion nodeVersion) { - super(typeManager, commitTaskCodec); + super(typeManager, commitTaskCodec, nodeVersion); this.resourceFactory = requireNonNull(resourceFactory, "resourceFactory is null"); this.catalogType = requireNonNull(catalogType, "catalogType is null"); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java index 156e4cdbc29a1..5991f0edc3224 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java @@ -15,6 +15,7 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.presto.common.type.TypeManager; +import 
com.facebook.presto.hive.NodeVersion; import com.facebook.presto.spi.connector.ConnectorMetadata; import javax.inject.Inject; @@ -28,23 +29,26 @@ public class IcebergNativeMetadataFactory final JsonCodec commitTaskCodec; final IcebergResourceFactory resourceFactory; final CatalogType catalogType; + final NodeVersion nodeVersion; @Inject public IcebergNativeMetadataFactory( IcebergConfig config, IcebergResourceFactory resourceFactory, TypeManager typeManager, - JsonCodec commitTaskCodec) + JsonCodec commitTaskCodec, + NodeVersion nodeVersion) { this.resourceFactory = requireNonNull(resourceFactory, "resourceFactory is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); + this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); requireNonNull(config, "config is null"); this.catalogType = config.getCatalogType(); } public ConnectorMetadata create() { - return new IcebergNativeMetadata(resourceFactory, typeManager, commitTaskCodec, catalogType); + return new IcebergNativeMetadata(resourceFactory, typeManager, commitTaskCodec, catalogType, nodeVersion); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSessionProperties.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSessionProperties.java index f70f24ffdb478..6077ad8371825 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSessionProperties.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSessionProperties.java @@ -84,6 +84,8 @@ public final class IcebergSessionProperties public static final String PARQUET_DEREFERENCE_PUSHDOWN_ENABLED = "parquet_dereference_pushdown_enabled"; public static final String MERGE_ON_READ_MODE_ENABLED = "merge_on_read_enabled"; public static final String HIVE_METASTORE_STATISTICS_MERGE_STRATEGY = "hive_statistics_merge_strategy"; + public static final String 
STATISTIC_SNAPSHOT_RECORD_DIFFERENCE_WEIGHT = "statistic_snapshot_record_difference_weight"; + private final List> sessionProperties; @Inject @@ -303,7 +305,14 @@ public IcebergSessionProperties( icebergConfig.getHiveStatisticsMergeStrategy(), false, val -> HiveStatisticsMergeStrategy.valueOf((String) val), - HiveStatisticsMergeStrategy::name)); + HiveStatisticsMergeStrategy::name), + doubleProperty(STATISTIC_SNAPSHOT_RECORD_DIFFERENCE_WEIGHT, + "the amount that the difference in total record count matters " + + "when calculating the closest snapshot when picking statistics. A " + + "value of 1 means a single record is equivalent to 1 millisecond of " + + "time difference.", + icebergConfig.getStatisticSnapshotRecordDifferenceWeight(), + false)); } public List> getSessionProperties() @@ -490,4 +499,9 @@ public static HiveStatisticsMergeStrategy getHiveStatisticsMergeStrategy(Connect { return session.getProperty(HIVE_METASTORE_STATISTICS_MERGE_STRATEGY, HiveStatisticsMergeStrategy.class); } + + public static double getStatisticSnapshotRecordDifferenceWeight(ConnectorSession session) + { + return session.getProperty(STATISTIC_SNAPSHOT_RECORD_DIFFERENCE_WEIGHT, Double.class); + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/Partition.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/Partition.java index b39f71e57cde2..d99e3a2945169 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/Partition.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/Partition.java @@ -235,6 +235,10 @@ public static Map toMap(Map idToTy ImmutableMap.Builder map = ImmutableMap.builder(); idToMetricMap.forEach((id, value) -> { Type.PrimitiveType type = idToTypeMapping.get(id); + if (type == null) { + // may occur for non-primitive types such as row-types + return; + } map.put(id, Conversions.fromByteBuffer(type, value)); }); return map.build(); diff --git 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java index 1331765b98dbf..77d6125e6d6b9 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java @@ -16,30 +16,47 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.hive.NodeVersion; +import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; import com.facebook.presto.spi.statistics.ColumnStatistics; +import com.facebook.presto.spi.statistics.ComputedStatistics; import com.facebook.presto.spi.statistics.DoubleRange; import com.facebook.presto.spi.statistics.Estimate; import com.facebook.presto.spi.statistics.TableStatistics; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.apache.iceberg.BlobMetadata; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.theta.CompactSketch; import org.apache.iceberg.DataFile; import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.HasTableOperations; import org.apache.iceberg.PartitionField; import org.apache.iceberg.Snapshot; import org.apache.iceberg.StatisticsFile; import org.apache.iceberg.Table; import org.apache.iceberg.TableScan; import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.Puffin; +import 
org.apache.iceberg.puffin.PuffinWriter; import org.apache.iceberg.types.Comparators; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; import java.io.IOException; import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Collection; +import java.util.Collections; import java.util.Comparator; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -49,38 +66,50 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; +import static com.facebook.presto.common.type.TypeUtils.isNumericType; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.Varchars.isVarcharType; import static com.facebook.presto.iceberg.ExpressionConverter.toIcebergExpression; +import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_FILESYSTEM_ERROR; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_METADATA; +import static com.facebook.presto.iceberg.IcebergSessionProperties.getStatisticSnapshotRecordDifferenceWeight; import static com.facebook.presto.iceberg.IcebergUtil.getColumns; import static com.facebook.presto.iceberg.IcebergUtil.getIdentityPartitions; import static com.facebook.presto.iceberg.Partition.toMap; +import static com.facebook.presto.spi.statistics.ColumnStatisticType.NUMBER_OF_DISTINCT_VALUES; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Streams.stream; -import static java.lang.Long.parseLong; +import static 
java.lang.Math.abs; import static java.lang.String.format; +import static java.util.UUID.randomUUID; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toSet; -import static org.apache.iceberg.puffin.StandardBlobTypes.APACHE_DATASKETCHES_THETA_V1; +import static org.apache.iceberg.SnapshotSummary.TOTAL_RECORDS_PROP; public class TableStatisticsMaker { private static final Logger log = Logger.get(TableStatisticsMaker.class); - - private static final String ICEBERG_APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY = "ndv"; + private static final String ICEBERG_THETA_SKETCH_BLOB_TYPE_ID = "apache-datasketches-theta-v1"; + private static final String ICEBERG_THETA_SKETCH_BLOB_PROPERTY_NDV_KEY = "ndv"; private final TypeManager typeManager; private final Table icebergTable; + private final ConnectorSession session; - private TableStatisticsMaker(TypeManager typeManager, Table icebergTable) + private TableStatisticsMaker(TypeManager typeManager, Table icebergTable, ConnectorSession session) { this.typeManager = typeManager; this.icebergTable = icebergTable; + this.session = session; } - public static TableStatistics getTableStatistics(TypeManager typeManager, Constraint constraint, IcebergTableHandle tableHandle, Table icebergTable, List columns) + public static TableStatistics getTableStatistics(ConnectorSession session, TypeManager typeManager, Constraint constraint, IcebergTableHandle tableHandle, Table icebergTable, List columns) { - return new TableStatisticsMaker(typeManager, icebergTable).makeTableStatistics(tableHandle, constraint, columns); + return new TableStatisticsMaker(typeManager, icebergTable, session).makeTableStatistics(tableHandle, constraint, columns); } private TableStatistics makeTableStatistics(IcebergTableHandle tableHandle, Constraint constraint, List selectedColumns) @@ -164,48 +193,15 @@ private TableStatistics makeTableStatistics(IcebergTableHandle tableHandle, Cons .build(); } - // get NDVs from statistics 
file(s) - ImmutableMap.Builder ndvByColumnId = ImmutableMap.builder(); - Set remainingColumnIds = new HashSet<>(idToColumnHandle.keySet()); - - getLatestStatisticsFile(icebergTable, tableHandle.getSnapshotId()).ifPresent(statisticsFile -> { - Map thetaBlobsByFieldId = statisticsFile.blobMetadata().stream() - .filter(blobMetadata -> blobMetadata.type().equals(APACHE_DATASKETCHES_THETA_V1)) - .filter(blobMetadata -> { - try { - return remainingColumnIds.contains(getOnlyElement(blobMetadata.fields())); - } - catch (IllegalArgumentException e) { - throw new PrestoException(ICEBERG_INVALID_METADATA, - format("blob metadata for blob type %s in statistics file %s must contain only one field. Found %d fields", - APACHE_DATASKETCHES_THETA_V1, statisticsFile.path(), blobMetadata.fields().size())); - } - }) - .collect(toImmutableMap(blobMetadata -> getOnlyElement(blobMetadata.fields()), identity())); - - for (Map.Entry entry : thetaBlobsByFieldId.entrySet()) { - int fieldId = entry.getKey(); - BlobMetadata blobMetadata = entry.getValue(); - String ndv = blobMetadata.properties().get(ICEBERG_APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY); - if (ndv == null) { - log.debug("Blob %s is missing %s property", blobMetadata.type(), ICEBERG_APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY); - remainingColumnIds.remove(fieldId); - } - else { - remainingColumnIds.remove(fieldId); - ndvByColumnId.put(fieldId, parseLong(ndv)); - } - } - }); - Map ndvById = ndvByColumnId.build(); - double recordCount = summary.getRecordCount(); TableStatistics.Builder result = TableStatistics.builder(); result.setRowCount(Estimate.of(recordCount)); result.setTotalSize(Estimate.of(summary.getSize())); + Map tableStats = getClosestStatisticsFileForSnapshot(tableHandle) + .map(TableStatisticsMaker::loadStatisticsFile).orElseGet(Collections::emptyMap); for (IcebergColumnHandle columnHandle : selectedColumns) { int fieldId = columnHandle.getId(); - ColumnStatistics.Builder columnBuilder = new ColumnStatistics.Builder(); + 
ColumnStatistics.Builder columnBuilder = tableStats.getOrDefault(fieldId, ColumnStatistics.builder()); Long nullCount = summary.getNullCounts().get(fieldId); if (nullCount != null) { columnBuilder.setNullsFraction(Estimate.of(nullCount / recordCount)); @@ -221,12 +217,74 @@ private TableStatistics makeTableStatistics(IcebergTableHandle tableHandle, Cons if (min instanceof Number && max instanceof Number) { columnBuilder.setRange(Optional.of(new DoubleRange(((Number) min).doubleValue(), ((Number) max).doubleValue()))); } - Optional.ofNullable(ndvById.get(fieldId)).ifPresent(ndv -> columnBuilder.setDistinctValuesCount(Estimate.of(ndv))); result.setColumnStatistics(columnHandle, columnBuilder.build()); } return result.build(); } + public static void writeTableStatistics(NodeVersion nodeVersion, TypeManager typeManager, IcebergTableHandle tableHandle, Table icebergTable, ConnectorSession session, Collection computedStatistics) + { + new TableStatisticsMaker(typeManager, icebergTable, session).writeTableStatistics(nodeVersion, tableHandle, computedStatistics); + } + + private void writeTableStatistics(NodeVersion nodeVersion, IcebergTableHandle tableHandle, Collection computedStatistics) + { + Snapshot snapshot = tableHandle.getSnapshotId().map(icebergTable::snapshot).orElseGet(icebergTable::currentSnapshot); + if (snapshot == null) { + // this may occur if the table has not been written to. 
+ return; + } + try (FileIO io = icebergTable.io()) { + String path = ((HasTableOperations) icebergTable).operations().metadataFileLocation(format("%s-%s.stats", session.getQueryId(), randomUUID())); + OutputFile outputFile = io.newOutputFile(path); + try (PuffinWriter writer = Puffin.write(outputFile) + .createdBy("presto-" + nodeVersion) + .build()) { + computedStatistics.stream() + .map(ComputedStatistics::getColumnStatistics) + .filter(Objects::nonNull) + .flatMap(map -> map.entrySet().stream()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + .forEach((key, value) -> { + if (!key.getStatisticType().equals(NUMBER_OF_DISTINCT_VALUES)) { + return; + } + Optional id = Optional.ofNullable(icebergTable.schema().findField(key.getColumnName())).map(Types.NestedField::fieldId); + if (!id.isPresent()) { + log.warn("failed to find column name %s in schema of table %s when writing distinct value statistics", key.getColumnName(), icebergTable.name()); + throw new PrestoException(ICEBERG_INVALID_METADATA, format("failed to find column name %s in schema of table %s when writing distinct value statistics", key.getColumnName(), icebergTable.name())); + } + ByteBuffer raw = VARBINARY.getSlice(value, 0).toByteBuffer(); + CompactSketch sketch = CompactSketch.wrap(Memory.wrap(raw, ByteOrder.nativeOrder())); + writer.add(new Blob( + ICEBERG_THETA_SKETCH_BLOB_TYPE_ID, + ImmutableList.of(id.get()), + snapshot.snapshotId(), + snapshot.sequenceNumber(), + raw, + null, + ImmutableMap.of(ICEBERG_THETA_SKETCH_BLOB_PROPERTY_NDV_KEY, Long.toString((long) sketch.getEstimate())))); + }); + writer.finish(); + icebergTable.updateStatistics().setStatistics( + snapshot.snapshotId(), + new GenericStatisticsFile( + snapshot.snapshotId(), + path, + writer.fileSize(), + writer.footerSize(), + writer.writtenBlobsMetadata().stream() + .map(GenericBlobMetadata::from) + .collect(toImmutableList()))) + .commit(); + } + catch (IOException e) { + log.warn(e, "failed to write table 
statistics file"); + throw new PrestoException(ICEBERG_FILESYSTEM_ERROR, "failed to write statistics file", e); + } + } + } + private static Optional getLatestStatisticsFile(Table table, Optional snapshotId) { if (table.statisticsFiles().isEmpty()) { @@ -327,4 +385,82 @@ private void updatePartitionedStats( } } } + + private Optional getClosestStatisticsFileForSnapshot(IcebergTableHandle handle) + { + Snapshot target = handle.getSnapshotId().map(icebergTable::snapshot).orElseGet(icebergTable::currentSnapshot); + return icebergTable.statisticsFiles() + .stream() + .min((first, second) -> { + if (first == second) { + return 0; + } + if (icebergTable.snapshot(first.snapshotId()) == null) { + return 1; + } + if (icebergTable.snapshot(second.snapshotId()) == null) { + return -1; + } + Snapshot firstSnap = icebergTable.snapshot(first.snapshotId()); + Snapshot secondSnap = icebergTable.snapshot(second.snapshotId()); + long firstDiff = abs(target.timestampMillis() - firstSnap.timestampMillis()); + long secondDiff = abs(target.timestampMillis() - secondSnap.timestampMillis()); + + // check if total-record exists + Optional targetTotalRecords = Optional.ofNullable(target.summary().get(TOTAL_RECORDS_PROP)).map(Long::parseLong); + Optional firstTotalRecords = Optional.ofNullable(firstSnap.summary().get(TOTAL_RECORDS_PROP)) + .map(Long::parseLong); + Optional secondTotalRecords = Optional.ofNullable(secondSnap.summary().get(TOTAL_RECORDS_PROP)) + .map(Long::parseLong); + + if (targetTotalRecords.isPresent() && firstTotalRecords.isPresent() && secondTotalRecords.isPresent()) { + long targetTotal = targetTotalRecords.get(); + double weight = getStatisticSnapshotRecordDifferenceWeight(session); + firstDiff += (long) (weight * abs(firstTotalRecords.get() - targetTotal)); + secondDiff += (long) (weight * abs(secondTotalRecords.get() - targetTotal)); + } + + return Long.compare(firstDiff, secondDiff); + }); + } + + /** + * Builds a map of field ID to ColumnStatistics for a particular 
{@link StatisticsFile}. + * + * @return + */ + private static Map loadStatisticsFile(StatisticsFile file) + { + ImmutableMap.Builder result = ImmutableMap.builder(); + file.blobMetadata().forEach(blob -> { + Integer field = getOnlyElement(blob.fields()); + ColumnStatistics.Builder colStats = ColumnStatistics.builder(); + Optional.ofNullable(blob.properties().get(ICEBERG_THETA_SKETCH_BLOB_PROPERTY_NDV_KEY)) + .ifPresent(ndvProp -> { + try { + long ndv = Long.parseLong(ndvProp); + colStats.setDistinctValuesCount(Estimate.of(ndv)); + } + catch (NumberFormatException e) { + colStats.setDistinctValuesCount(Estimate.unknown()); + log.warn("bad long value when parsing statistics file %s, bad value: %d", file.path(), ndvProp); + } + }); + result.put(field, colStats); + }); + return result.build(); + } + + public static List getSupportedColumnStatistics(String columnName, com.facebook.presto.common.type.Type type) + { + ImmutableList.Builder supportedStatistics = ImmutableList.builder(); + // all types which support being passed to the sketch_theta function + if (isNumericType(type) || type.equals(DATE) || isVarcharType(type) || + type.equals(TIMESTAMP) || + type.equals(TIMESTAMP_WITH_TIME_ZONE)) { + supportedStatistics.add(NUMBER_OF_DISTINCT_VALUES.getColumnStatisticMetadataWithCustomFunction(columnName, "sketch_theta")); + } + + return supportedStatistics.build(); + } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java index b39bf002bf568..464e930bf538c 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java @@ -13,19 +13,34 @@ */ package com.facebook.presto.iceberg; +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import 
com.facebook.presto.common.transaction.TransactionId; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.analyzer.MetadataResolver; +import com.facebook.presto.spi.security.AllowAllAccessControl; +import com.facebook.presto.spi.statistics.ColumnStatistics; +import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestDistributedQueries; import com.google.common.collect.ImmutableMap; -import org.intellij.lang.annotations.Language; import org.testng.annotations.Test; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; +import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.iceberg.IcebergSessionProperties.STATISTIC_SNAPSHOT_RECORD_DIFFERENCE_WEIGHT; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN; import static com.facebook.presto.testing.TestingAccessControlManager.privilege; @@ -38,17 +53,24 @@ public class IcebergDistributedTestBase extends AbstractTestDistributedQueries { private final CatalogType catalogType; + private final Map extraConnectorProperties; - protected IcebergDistributedTestBase(CatalogType catalogType) + protected IcebergDistributedTestBase(CatalogType catalogType, Map extraConnectorProperties) { this.catalogType = requireNonNull(catalogType, "catalogType is null"); + this.extraConnectorProperties = requireNonNull(extraConnectorProperties, "extraConnectorProperties 
is null"); + } + + protected IcebergDistributedTestBase(CatalogType catalogType) + { + this(catalogType, ImmutableMap.of()); } @Override protected QueryRunner createQueryRunner() throws Exception { - return IcebergQueryRunner.createIcebergQueryRunner(ImmutableMap.of(), catalogType); + return IcebergQueryRunner.createIcebergQueryRunner(ImmutableMap.of(), catalogType, extraConnectorProperties); } @Override @@ -317,14 +339,14 @@ public void testPartitionedByVarbinaryType() // validate return data of VarbinaryType List varbinaryColumnDatas = getQueryRunner().execute("select b from test_partition_columns_varbinary order by a asc").getOnlyColumn().collect(Collectors.toList()); assertEquals(varbinaryColumnDatas.size(), 2); - assertEquals(varbinaryColumnDatas.get(0), new byte[]{(byte) 0xbc, (byte) 0xd1}); - assertEquals(varbinaryColumnDatas.get(1), new byte[]{(byte) 0xe3, (byte) 0xbc, (byte) 0xd1}); + assertEquals(varbinaryColumnDatas.get(0), new byte[] {(byte) 0xbc, (byte) 0xd1}); + assertEquals(varbinaryColumnDatas.get(1), new byte[] {(byte) 0xe3, (byte) 0xbc, (byte) 0xd1}); // validate column of VarbinaryType exists in query filter assertEquals(getQueryRunner().execute("select b from test_partition_columns_varbinary where b = X'bcd1'").getOnlyValue(), - new byte[]{(byte) 0xbc, (byte) 0xd1}); + new byte[] {(byte) 0xbc, (byte) 0xd1}); assertEquals(getQueryRunner().execute("select b from test_partition_columns_varbinary where b = X'e3bcd1'").getOnlyValue(), - new byte[]{(byte) 0xe3, (byte) 0xbc, (byte) 0xd1}); + new byte[] {(byte) 0xe3, (byte) 0xbc, (byte) 0xd1}); // validate column of VarbinaryType in system table "partitions" assertEquals(getQueryRunner().execute("select count(*) FROM \"test_partition_columns_varbinary$partitions\"").getOnlyValue(), 2L); @@ -335,9 +357,9 @@ public void testPartitionedByVarbinaryType() assertUpdate("delete from test_partition_columns_varbinary WHERE b = X'bcd1'", 1); varbinaryColumnDatas = getQueryRunner().execute("select b from 
test_partition_columns_varbinary order by a asc").getOnlyColumn().collect(Collectors.toList()); assertEquals(varbinaryColumnDatas.size(), 1); - assertEquals(varbinaryColumnDatas.get(0), new byte[]{(byte) 0xe3, (byte) 0xbc, (byte) 0xd1}); + assertEquals(varbinaryColumnDatas.get(0), new byte[] {(byte) 0xe3, (byte) 0xbc, (byte) 0xd1}); assertEquals(getQueryRunner().execute("select b FROM test_partition_columns_varbinary where b = X'e3bcd1'").getOnlyValue(), - new byte[]{(byte) 0xe3, (byte) 0xbc, (byte) 0xd1}); + new byte[] {(byte) 0xe3, (byte) 0xbc, (byte) 0xd1}); assertEquals(getQueryRunner().execute("select count(*) from \"test_partition_columns_varbinary$partitions\"").getOnlyValue(), 1L); assertEquals(getQueryRunner().execute("select row_count from \"test_partition_columns_varbinary$partitions\" where b = X'e3bcd1'").getOnlyValue(), 1L); @@ -371,10 +393,236 @@ public void testStringFilters() assertQuery("SELECT count(*) FROM test_varcharn_filter WHERE shipmode = 'NONEXIST'", "VALUES (0)"); } - private void assertExplainAnalyze(@Language("SQL") String query) + @Test + public void testReadWriteNDVs() + { + assertUpdate("CREATE TABLE test_stat_ndv (col0 int)"); + assertTrue(getQueryRunner().tableExists(getSession(), "test_stat_ndv")); + assertTableColumnNames("test_stat_ndv", "col0"); + + // test that stats don't exist before analyze + TableStatistics stats = getTableStats("test_stat_ndv"); + assertTrue(stats.getColumnStatistics().isEmpty()); + + // test after simple insert we get a good estimate + assertUpdate("INSERT INTO test_stat_ndv VALUES 1, 2, 3", 3); + getQueryRunner().execute("ANALYZE test_stat_ndv"); + stats = getTableStats("test_stat_ndv"); + assertEquals(stats.getColumnStatistics().values().stream().findFirst().get().getDistinctValuesCount(), Estimate.of(3.0)); + + // test after inserting the same values, we still get the same estimate + assertUpdate("INSERT INTO test_stat_ndv VALUES 1, 2, 3", 3); + stats = getTableStats("test_stat_ndv"); + 
assertEquals(stats.getColumnStatistics().values().stream().findFirst().get().getDistinctValuesCount(), Estimate.of(3.0)); + + // test after ANALYZING with the new inserts that the NDV estimate is the same + getQueryRunner().execute("ANALYZE test_stat_ndv"); + stats = getTableStats("test_stat_ndv"); + assertEquals(stats.getColumnStatistics().values().stream().findFirst().get().getDistinctValuesCount(), Estimate.of(3.0)); + + // test after inserting a new value, but not analyzing, the estimate is the same. + assertUpdate("INSERT INTO test_stat_ndv VALUES 4", 1); + stats = getTableStats("test_stat_ndv"); + assertEquals(stats.getColumnStatistics().values().stream().findFirst().get().getDistinctValuesCount(), Estimate.of(3.0)); + + // test that after analyzing, the updates stats show up. + getQueryRunner().execute("ANALYZE test_stat_ndv"); + stats = getTableStats("test_stat_ndv"); + assertEquals(stats.getColumnStatistics().values().stream().findFirst().get().getDistinctValuesCount(), Estimate.of(4.0)); + + // test adding a null value is successful, and analyze still runs successfully + assertUpdate("INSERT INTO test_stat_ndv VALUES NULL", 1); + assertQuerySucceeds("ANALYZE test_stat_ndv"); + stats = getTableStats("test_stat_ndv"); + assertEquals(stats.getColumnStatistics().values().stream().findFirst().get().getDistinctValuesCount(), Estimate.of(4.0)); + + assertUpdate("DROP TABLE test_stat_ndv"); + } + + @Test + public void testReadWriteNDVsComplexTypes() + { + assertUpdate("CREATE TABLE test_stat_ndv_complex (col0 int, col1 date, col2 varchar, col3 row(c0 int))"); + assertTrue(getQueryRunner().tableExists(getSession(), "test_stat_ndv_complex")); + assertTableColumnNames("test_stat_ndv_complex", "col0", "col1", "col2", "col3"); + + // test that stats don't exist before analyze + TableStatistics stats = getTableStats("test_stat_ndv_complex"); + assertTrue(stats.getColumnStatistics().isEmpty()); + + // test after simple insert we get a good estimate + 
assertUpdate("INSERT INTO test_stat_ndv_complex VALUES (0, current_date, 't1', row(0))", 1); + getQueryRunner().execute("ANALYZE test_stat_ndv_complex"); + stats = getTableStats("test_stat_ndv_complex"); + assertEquals(columnStatsFor(stats, "col0").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col1").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col2").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col3").getDistinctValuesCount(), Estimate.unknown()); + + // test after inserting the same values, we still get the same estimate + assertUpdate("INSERT INTO test_stat_ndv_complex VALUES (0, current_date, 't1', row(0))", 1); + stats = getTableStats("test_stat_ndv_complex"); + assertEquals(columnStatsFor(stats, "col0").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col1").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col2").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col3").getDistinctValuesCount(), Estimate.unknown()); + + // test after ANALYZING with the new inserts that the NDV estimate is the same + getQueryRunner().execute("ANALYZE test_stat_ndv_complex"); + stats = getTableStats("test_stat_ndv_complex"); + assertEquals(columnStatsFor(stats, "col0").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col1").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col2").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col3").getDistinctValuesCount(), Estimate.unknown()); + + // test after inserting a new value, but not analyzing, the estimate is the same. 
+ assertUpdate("INSERT INTO test_stat_ndv_complex VALUES (1, current_date + interval '1' day, 't2', row(1))", 1); + stats = getTableStats("test_stat_ndv_complex"); + assertEquals(columnStatsFor(stats, "col0").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col1").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col2").getDistinctValuesCount(), Estimate.of(1.0)); + assertEquals(columnStatsFor(stats, "col3").getDistinctValuesCount(), Estimate.unknown()); + + // test that after analyzing, the updates stats show up. + getQueryRunner().execute("ANALYZE test_stat_ndv_complex"); + stats = getTableStats("test_stat_ndv_complex"); + assertEquals(columnStatsFor(stats, "col0").getDistinctValuesCount(), Estimate.of(2.0)); + assertEquals(columnStatsFor(stats, "col1").getDistinctValuesCount(), Estimate.of(2.0)); + assertEquals(columnStatsFor(stats, "col2").getDistinctValuesCount(), Estimate.of(2.0)); + assertEquals(columnStatsFor(stats, "col3").getDistinctValuesCount(), Estimate.unknown()); + + // test adding a null value is successful, and analyze still runs successfully + assertUpdate("INSERT INTO test_stat_ndv_complex VALUES (NULL, NULL, NULL, NULL)", 1); + assertQuerySucceeds("ANALYZE test_stat_ndv_complex"); + stats = getTableStats("test_stat_ndv_complex"); + assertEquals(columnStatsFor(stats, "col0").getDistinctValuesCount(), Estimate.of(2.0)); + assertEquals(columnStatsFor(stats, "col1").getDistinctValuesCount(), Estimate.of(2.0)); + assertEquals(columnStatsFor(stats, "col2").getDistinctValuesCount(), Estimate.of(2.0)); + assertEquals(columnStatsFor(stats, "col3").getDistinctValuesCount(), Estimate.unknown()); + + assertUpdate("DROP TABLE test_stat_ndv_complex"); + } + + @Test + public void testNDVsAtSnapshot() { - String value = (String) computeActual(query).getOnlyValue(); + assertUpdate("CREATE TABLE test_stat_snap (col0 int, col1 varchar)"); + assertTrue(getQueryRunner().tableExists(getSession(), 
"test_stat_snap")); + assertTableColumnNames("test_stat_snap", "col0", "col1"); + assertEquals(getTableSnapshots("test_stat_snap").size(), 0); + + assertQuerySucceeds("ANALYZE test_stat_snap"); + assertUpdate("INSERT INTO test_stat_snap VALUES (0, '0')", 1); + assertQuerySucceeds("ANALYZE test_stat_snap"); + assertUpdate("INSERT INTO test_stat_snap VALUES (1, '1')", 1); + assertQuerySucceeds("ANALYZE test_stat_snap"); + assertUpdate("INSERT INTO test_stat_snap VALUES (2, '2')", 1); + assertQuerySucceeds("ANALYZE test_stat_snap"); + assertUpdate("INSERT INTO test_stat_snap VALUES (3, '3')", 1); + assertQuerySucceeds("ANALYZE test_stat_snap"); + assertEquals(getTableSnapshots("test_stat_snap").size(), 4); + + List snaps = getTableSnapshots("test_stat_snap"); + for (int i = 0; i < snaps.size(); i++) { + TableStatistics statistics = getTableStats("test_stat_snap", Optional.of(snaps.get(i))); + // assert either case as we don't have good control over the timing of when statistics files are written + ColumnStatistics col0Stats = columnStatsFor(statistics, "col0"); + ColumnStatistics col1Stats = columnStatsFor(statistics, "col1"); + System.out.printf("distinct @ %s count col0: %s%n", snaps.get(i), col0Stats.getDistinctValuesCount()); + final int idx = i; + assertEither( + () -> assertEquals(col0Stats.getDistinctValuesCount(), Estimate.of(idx)), + () -> assertEquals(col0Stats.getDistinctValuesCount(), Estimate.of(idx + 1))); + assertEither( + () -> assertEquals(col1Stats.getDistinctValuesCount(), Estimate.of(idx)), + () -> assertEquals(col1Stats.getDistinctValuesCount(), Estimate.of(idx + 1))); + } + assertUpdate("DROP TABLE test_stat_snap"); + } - assertTrue(value.matches("(?s:.*)CPU:.*, Input:.*, Output(?s:.*)"), format("Expected output to contain \"CPU:.*, Input:.*, Output\", but it is %s", value)); + @Test + public void testStatsByDistance() + { + assertUpdate("CREATE TABLE test_stat_dist (col0 int)"); + assertTrue(getQueryRunner().tableExists(getSession(), 
"test_stat_dist")); + assertTableColumnNames("test_stat_dist", "col0"); + assertEquals(getTableSnapshots("test_stat_dist").size(), 0); + + assertUpdate("INSERT INTO test_stat_dist VALUES 0", 1); + assertQuerySucceeds("ANALYZE test_stat_dist"); + assertUpdate("INSERT INTO test_stat_dist VALUES 1", 1); + assertUpdate("INSERT INTO test_stat_dist VALUES 2", 1); + assertUpdate("INSERT INTO test_stat_dist VALUES 3", 1); + assertUpdate("INSERT INTO test_stat_dist VALUES 4", 1); + assertUpdate("INSERT INTO test_stat_dist VALUES 5", 1); + assertQuerySucceeds("ANALYZE test_stat_dist"); + assertEquals(getTableSnapshots("test_stat_dist").size(), 6); + List snapshots = getTableSnapshots("test_stat_dist"); + // set a high weight so the weighting calculation is mostly done by record count + Session weightedSession = Session.builder(getSession()) + .setCatalogSessionProperty("iceberg", STATISTIC_SNAPSHOT_RECORD_DIFFERENCE_WEIGHT, "10000000") + .build(); + Function ndvs = (x) -> columnStatsFor(getTableStats("test_stat_dist", Optional.of(snapshots.get(x)), weightedSession), "col0") + .getDistinctValuesCount(); + assertEquals(ndvs.apply(0).getValue(), 1); + assertEquals(ndvs.apply(1).getValue(), 1); + assertEquals(ndvs.apply(2).getValue(), 1); + assertEquals(ndvs.apply(3).getValue(), 6); + assertEquals(ndvs.apply(4).getValue(), 6); + assertEquals(ndvs.apply(5).getValue(), 6); + assertUpdate("DROP TABLE test_stat_dist"); + } + + private static void assertEither(Runnable first, Runnable second) + { + try { + first.run(); + } + catch (AssertionError e) { + second.run(); + } + } + + private List getTableSnapshots(String tableName) + { + MaterializedResult result = getQueryRunner().execute(format("SELECT snapshot_id FROM \"%s$snapshots\" ORDER BY committed_at", tableName)); + return result.getOnlyColumn().map(Long.class::cast).collect(Collectors.toList()); + } + + private TableStatistics getTableStats(String name) + { + return getTableStats(name, Optional.empty()); + } + + private 
TableStatistics getTableStats(String name, Optional snapshot) + { + return getTableStats(name, snapshot, getSession()); + } + + private TableStatistics getTableStats(String name, Optional snapshot, Session session) + { + TransactionId transactionId = getQueryRunner().getTransactionManager().beginTransaction(false); + Session metadataSession = session.beginTransactionId( + transactionId, + getQueryRunner().getTransactionManager(), + new AllowAllAccessControl()); + Metadata metadata = getDistributedQueryRunner().getMetadata(); + MetadataResolver resolver = metadata.getMetadataResolver(metadataSession); + String tableName = snapshot.map(snap -> format("%s@%d", name, snap)).orElse(name); + String qualifiedName = format("%s.%s.%s", getSession().getCatalog().get(), getSession().getSchema().get(), tableName); + TableHandle handle = resolver.getTableHandle(QualifiedObjectName.valueOf(qualifiedName)).get(); + return metadata.getTableStatistics(metadataSession, + handle, + new ArrayList<>(resolver.getColumnHandles(handle).values()), + Constraint.alwaysTrue()); + } + + private static ColumnStatistics columnStatsFor(TableStatistics statistics, String name) + { + return statistics.getColumnStatistics().entrySet() + .stream().filter(entry -> ((IcebergColumnHandle) entry.getKey()).getName().equals(name)) + .map(Map.Entry::getValue) + .findFirst() + .get(); } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergQueryRunner.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergQueryRunner.java index 93f04259918e6..6b4c3f582a28e 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergQueryRunner.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergQueryRunner.java @@ -56,6 +56,17 @@ public static DistributedQueryRunner createIcebergQueryRunner(Map extraProperties, CatalogType catalogType, Map extraConnectorProperties) + throws Exception + { + return createIcebergQueryRunner( + extraProperties, + 
ImmutableMap.builder() + .putAll(extraConnectorProperties) + .put("iceberg.catalog.type", catalogType.name()) + .build()); + } + public static DistributedQueryRunner createIcebergQueryRunner(Map extraProperties, Map extraConnectorProperties) throws Exception { diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java index cf6dc6d3a7ffe..00906347ec522 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java @@ -43,6 +43,7 @@ public void testDefaults() .setCatalogCacheSize(10) .setHadoopConfigResources(null) .setHiveStatisticsMergeStrategy(HiveStatisticsMergeStrategy.NONE) + .setStatisticSnapshotRecordDifferenceWeight(0.0) .setMaxPartitionsPerWriter(100) .setMinimumAssignedSplitWeight(0.05) .setParquetDereferencePushdownEnabled(true) @@ -63,6 +64,7 @@ public void testExplicitPropertyMappings() .put("iceberg.minimum-assigned-split-weight", "0.01") .put("iceberg.enable-parquet-dereference-pushdown", "false") .put("iceberg.enable-merge-on-read-mode", "true") + .put("iceberg.statistic-snapshot-record-difference-weight", "1.0") .put("iceberg.hive-statistics-merge-strategy", "USE_NDV") .build(); @@ -75,6 +77,7 @@ public void testExplicitPropertyMappings() .setHadoopConfigResources("/etc/hadoop/conf/core-site.xml") .setMaxPartitionsPerWriter(222) .setMinimumAssignedSplitWeight(0.01) + .setStatisticSnapshotRecordDifferenceWeight(1.0) .setParquetDereferencePushdownEnabled(false) .setMergeOnReadModeEnabled(true) .setHiveStatisticsMergeStrategy(USE_NDV); diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergDistributedHive.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergDistributedHive.java index a64e10e6e0fbf..e573d663dfe8f 100644 --- 
a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergDistributedHive.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergDistributedHive.java @@ -14,6 +14,7 @@ package com.facebook.presto.iceberg.hive; import com.facebook.presto.iceberg.IcebergDistributedTestBase; +import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; import static com.facebook.presto.iceberg.CatalogType.HIVE; @@ -24,6 +25,18 @@ public class TestIcebergDistributedHive { public TestIcebergDistributedHive() { - super(HIVE); + super(HIVE, ImmutableMap.of("iceberg.hive-statistics-merge-strategy", "USE_NULLS_FRACTION_AND_NDV")); + } + + @Override + public void testNDVsAtSnapshot() + { + // ignore because HMS doesn't support statistics versioning + } + + @Override + public void testStatsByDistance() + { + // ignore because HMS doesn't support statistics versioning } } diff --git a/presto-main/pom.xml b/presto-main/pom.xml index 80d4d123d4f22..d9e1ebfd42044 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -357,6 +357,16 @@ jjwt + + org.apache.datasketches + datasketches-memory + + + + org.apache.datasketches + datasketches-java + + org.testng diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 09d525311713d..ddeec4874a267 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -177,6 +177,7 @@ import com.facebook.presto.operator.scalar.SplitToMultimapFunction; import com.facebook.presto.operator.scalar.StringFunctions; import com.facebook.presto.operator.scalar.TDigestFunctions; +import com.facebook.presto.operator.scalar.ThetaSketchFunctions; import 
com.facebook.presto.operator.scalar.TryFunction; import com.facebook.presto.operator.scalar.TypeOfFunction; import com.facebook.presto.operator.scalar.UrlFunctions; @@ -357,6 +358,7 @@ import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisySumGaussianClippingAggregation.NOISY_SUM_GAUSSIAN_CLIPPING_AGGREGATION; import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisySumGaussianClippingRandomSeedAggregation.NOISY_SUM_GAUSSIAN_CLIPPING_RANDOM_SEED_AGGREGATION; import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisySumGaussianRandomSeedAggregation.NOISY_SUM_GAUSSIAN_RANDOM_SEED_AGGREGATION; +import static com.facebook.presto.operator.aggregation.sketch.theta.ThetaSketchAggregationFunction.THETA_SKETCH; import static com.facebook.presto.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION; import static com.facebook.presto.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR; import static com.facebook.presto.operator.scalar.ArrayFlattenFunction.ARRAY_FLATTEN_FUNCTION; @@ -925,6 +927,8 @@ private List getBuildInFunctions(FeaturesConfig featuresC .scalars(TDigestOperators.class) .scalars(TDigestFunctions.class) .functions(TDIGEST_AGG, TDIGEST_AGG_WITH_WEIGHT, TDIGEST_AGG_WITH_WEIGHT_AND_COMPRESSION) + .function(THETA_SKETCH) + .scalars(ThetaSketchFunctions.class) .function(MergeTDigestFunction.MERGE) .sqlInvokedScalar(MapNormalizeFunction.class) .sqlInvokedScalars(ArraySqlFunctions.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/theta/ThetaSketchAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/theta/ThetaSketchAggregationFunction.java new file mode 100644 index 0000000000000..d0e434f6f6bf3 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/theta/ThetaSketchAggregationFunction.java @@ -0,0 +1,141 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.sketch.theta; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.bytecode.DynamicClassLoader; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.SqlAggregationFunction; +import com.facebook.presto.operator.aggregation.AccumulatorCompiler; +import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation; +import com.facebook.presto.operator.aggregation.state.StateCompiler; +import com.facebook.presto.spi.function.aggregation.Accumulator; +import com.facebook.presto.spi.function.aggregation.AggregationMetadata; +import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; + +import java.lang.invoke.MethodHandle; +import java.util.List; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; 
+import static com.facebook.presto.spi.function.Signature.typeVariable; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; +import static com.facebook.presto.util.Reflection.methodHandle; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class ThetaSketchAggregationFunction + extends SqlAggregationFunction +{ + private static final Logger log = Logger.get(ThetaSketchAggregationFunction.class); + public static final String NAME = "sketch_theta"; + + public static final ThetaSketchAggregationFunction THETA_SKETCH = new ThetaSketchAggregationFunction(); + + private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ThetaSketchAggregationFunction.class, "output", ThetaSketchAggregationState.class, BlockBuilder.class); + private static final MethodHandle INPUT_FUNCTION = methodHandle(ThetaSketchAggregationFunction.class, "input", Type.class, ThetaSketchAggregationState.class, Block.class, int.class); + private static final MethodHandle MERGE_FUNCTION = methodHandle(ThetaSketchAggregationFunction.class, "merge", ThetaSketchAggregationState.class, ThetaSketchAggregationState.class); + + public ThetaSketchAggregationFunction() + { + super(NAME, + ImmutableList.of(typeVariable("T")), + ImmutableList.of(), + parseTypeSignature(StandardTypes.VARBINARY), + ImmutableList.of(parseTypeSignature("T"))); + } + + @Override + public String getDescription() + { + return "calculates a theta sketch of the selected input column"; + } + + @Override + public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) + { + DynamicClassLoader 
classLoader = new DynamicClassLoader(ThetaSketchAggregationFunction.class.getClassLoader()); + Type type = boundVariables.getTypeVariable("T"); + List inputTypes = ImmutableList.of(type); + + AggregationMetadata metadata = new AggregationMetadata( + generateAggregationName(NAME, type.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), + createInputParameterMetadata(type), + INPUT_FUNCTION.bindTo(type), + MERGE_FUNCTION, + OUTPUT_FUNCTION, + ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor( + ThetaSketchAggregationState.class, + StateCompiler.generateStateSerializer(ThetaSketchAggregationState.class, classLoader), + StateCompiler.generateStateFactory(ThetaSketchAggregationState.class, classLoader))), + VARBINARY); + + Class accumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + Accumulator.class, + metadata, + classLoader); + Class groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + GroupedAccumulator.class, + metadata, + classLoader); + return new BuiltInAggregationFunctionImplementation(NAME, inputTypes, ImmutableList.of(BIGINT), VARBINARY, + true, false, metadata, accumulatorClass, groupedAccumulatorClass); + } + + private static List createInputParameterMetadata(Type type) + { + return ImmutableList.of(new AggregationMetadata.ParameterMetadata(STATE), new AggregationMetadata.ParameterMetadata(NULLABLE_BLOCK_INPUT_CHANNEL, type), new AggregationMetadata.ParameterMetadata(BLOCK_INDEX)); + } + + public static void input(Type type, ThetaSketchAggregationState state, Block block, int position) + { + if (block.isNull(position)) { + return; + } + if (type.getJavaType().equals(Long.class) || type.getJavaType() == long.class) { + state.getSketch().update(type.getLong(block, position)); + } + else if (type.getJavaType().equals(Double.class) || type.getJavaType() == double.class) { + state.getSketch().update(type.getDouble(block, position)); + } + else if 
(type.getJavaType().equals(String.class) || type.getJavaType().equals(Slice.class)) { + state.getSketch().update(type.getSlice(block, position).getBytes()); + } + else { + throw new RuntimeException("unsupported sketch column type: " + type + " (java type: " + type.getJavaType() + ")"); + } + } + + public static void merge(ThetaSketchAggregationState state, ThetaSketchAggregationState otherState) + { + state.getSketch().union(otherState.getSketch().getResult()); + } + + public static void output(ThetaSketchAggregationState state, BlockBuilder out) + { + Slice output = Slices.wrappedBuffer(state.getSketch().getResult().toByteArray()); + VARBINARY.writeSlice(out, output); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/theta/ThetaSketchAggregationState.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/theta/ThetaSketchAggregationState.java new file mode 100644 index 0000000000000..4c6179a9fa91e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/theta/ThetaSketchAggregationState.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Accumulator state for the theta sketch aggregation: a single DataSketches theta
 * {@link Union} that input values and other sketches are folded into.
 *
 * <p>Serialization and state creation are delegated to the classes named in the
 * {@link AccumulatorStateMetadata} annotation.
 */
@AccumulatorStateMetadata(stateSerializerClass = ThetaSketchStateSerializer.class, stateFactoryClass = ThetaSketchStateFactory.class)
public interface ThetaSketchAggregationState
{
    /**
     * @return the union backing this state
     */
    Union getSketch();

    /**
     * Replaces the union backing this state.
     *
     * @param value the union to store
     */
    void setSketch(Union value);
}
+ */ +package com.facebook.presto.operator.aggregation.sketch.theta; + +import com.facebook.presto.common.array.ObjectBigArray; +import com.facebook.presto.operator.aggregation.state.AbstractGroupedAccumulatorState; +import com.facebook.presto.spi.function.AccumulatorState; +import com.facebook.presto.spi.function.AccumulatorStateFactory; +import org.apache.datasketches.theta.Union; +import org.openjdk.jol.info.ClassLayout; + +import static java.util.Objects.requireNonNull; + +public class ThetaSketchStateFactory + implements AccumulatorStateFactory +{ + @Override + public SingleThetaSketchState createSingleState() + { + return new SingleThetaSketchState(); + } + + @Override + public Class getSingleStateClass() + { + return SingleThetaSketchState.class; + } + + @Override + public ThetaSketchAggregationState createGroupedState() + { + return new GroupedThetaSketchState(); + } + + @Override + public Class getGroupedStateClass() + { + return GroupedThetaSketchState.class; + } + + public static final class SingleThetaSketchState + implements ThetaSketchAggregationState, AccumulatorState + { + private Union sketch = Union.builder().buildUnion(); + + @Override + public Union getSketch() + { + return sketch; + } + + @Override + public void setSketch(Union sketch) + { + this.sketch = sketch; + } + + @Override + public long getEstimatedSize() + { + return sketch.getCurrentBytes(); + } + } + + public static final class GroupedThetaSketchState + extends AbstractGroupedAccumulatorState + implements ThetaSketchAggregationState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedThetaSketchState.class).instanceSize(); + private final ObjectBigArray sketches = new ObjectBigArray<>(); + + @Override + public Union getSketch() + { + if (sketches.get(getGroupId()) == null) { + setSketch(Union.builder().buildUnion()); + } + return sketches.get(getGroupId()); + } + + @Override + public void setSketch(Union sketch) + { + sketches.set(getGroupId(), 
requireNonNull(sketch, "sketch is null")); + } + + @Override + public long getEstimatedSize() + { + return INSTANCE_SIZE + sketches.sizeOf(); + } + + @Override + public void ensureCapacity(long size) + { + sketches.ensureCapacity(size); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/theta/ThetaSketchStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/theta/ThetaSketchStateSerializer.java new file mode 100644 index 0000000000000..c4e0ffa6fe558 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/theta/ThetaSketchStateSerializer.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator.aggregation.sketch.theta; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.datasketches.memory.WritableMemory; +import org.apache.datasketches.theta.SetOperation; +import org.apache.datasketches.theta.Union; + +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.google.common.base.Verify.verify; +import static org.apache.datasketches.common.Family.UNION; + +public class ThetaSketchStateSerializer + implements AccumulatorStateSerializer +{ + @Override + public Type getSerializedType() + { + return VARBINARY; + } + + @Override + public void serialize(ThetaSketchAggregationState state, BlockBuilder out) + { + Slice stateMemory = Slices.wrappedBuffer(state.getSketch().toByteArray()); + VARBINARY.writeSlice(out, stateMemory); + } + + @Override + public void deserialize(Block block, int index, ThetaSketchAggregationState state) + { + Slice data = VARBINARY.getSlice(block, index); + SetOperation op = Union.wrap(WritableMemory.writableWrap(data.getBytes())); + verify(op.getFamily() == UNION); + state.setSketch((Union) op); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ThetaSketchFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ThetaSketchFunctions.java new file mode 100644 index 0000000000000..c024222a3dd46 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ThetaSketchFunctions.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.theta.CompactSketch; + +import java.nio.ByteOrder; +import java.util.Optional; + +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; + +public class ThetaSketchFunctions +{ + private ThetaSketchFunctions() + { + } + + @ScalarFunction(value = "sketch_theta_estimate") + @Description("Get the estimate of unique values from a theta sketch") + @SqlType("double") + public static double thetaSketchEstimate(@SqlType("varbinary") Slice input) + { + CompactSketch sketch = CompactSketch.wrap(Memory.wrap(input.toByteBuffer(), ByteOrder.nativeOrder())); + return sketch.getEstimate(); + } + + private static final RowType SUMMARY_TYPE = RowType.from(ImmutableList.of( + new RowType.Field(Optional.of("estimate"), DOUBLE), + new RowType.Field(Optional.of("theta"), DOUBLE), + new RowType.Field(Optional.of("upper_bound_std"), DOUBLE), + new RowType.Field(Optional.of("lower_bound_std"), DOUBLE), + new RowType.Field(Optional.of("retained_entries"), INTEGER))); + + 
@ScalarFunction(value = "sketch_theta_summary") + @Description("parses a brief summary from a theta sketch") + @SqlType("row(estimate double, theta double, upper_bound_std double, lower_bound_std double, retained_entries int)") + public static Block thetaSketchSummary(@SqlType("varbinary") Slice input) + { + CompactSketch sketch = CompactSketch.wrap(Memory.wrap(input.toByteBuffer(), ByteOrder.nativeOrder())); + BlockBuilder output = SUMMARY_TYPE.createBlockBuilder(null, 1); + BlockBuilder row = output.beginBlockEntry(); + DOUBLE.writeDouble(row, sketch.getEstimate()); + DOUBLE.writeDouble(row, sketch.getTheta()); + DOUBLE.writeDouble(row, sketch.getUpperBound(1)); + DOUBLE.writeDouble(row, sketch.getLowerBound(1)); + INTEGER.writeLong(row, sketch.getRetainedEntries()); + output.closeEntry(); + return output.build().getBlock(0); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java index ebb63e48f37d8..0f17eaf78d2d3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java @@ -14,15 +14,13 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.common.type.Type; -import com.facebook.presto.operator.aggregation.MaxDataSizeForStats; -import com.facebook.presto.operator.aggregation.SumDataSizeForStats; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.relation.CallExpression; -import com.facebook.presto.spi.relation.RowExpression; import 
com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; import com.facebook.presto.spi.statistics.ColumnStatisticType; @@ -41,7 +39,7 @@ import java.util.Optional; import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.statistics.TableStatisticType.ROW_COUNT; import static com.google.common.base.Verify.verify; @@ -99,7 +97,7 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta ColumnStatisticType statisticType = columnStatisticMetadata.getStatisticType(); VariableReferenceExpression inputVariable = columnToVariableMap.get(columnName); verify(inputVariable != null, "inputVariable is null"); - ColumnStatisticsAggregation aggregation = createColumnAggregation(statisticType, inputVariable); + ColumnStatisticsAggregation aggregation = createColumnAggregation(columnStatisticMetadata, inputVariable); VariableReferenceExpression variable = variableAllocator.newVariable(statisticType + ":" + columnName, aggregation.getOutputType()); aggregations.put(variable, aggregation.getAggregation()); descriptor.addColumnStatistic(columnStatisticMetadata, variable); @@ -109,38 +107,18 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta return new TableStatisticAggregation(aggregation, descriptor.build()); } - private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticType statisticType, VariableReferenceExpression input) + private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input) { - switch (statisticType) { - case MIN_VALUE: - return createAggregation("min", input, input.getType(), input.getType()); - 
case MAX_VALUE: - return createAggregation("max", input, input.getType(), input.getType()); - case NUMBER_OF_DISTINCT_VALUES: - return createAggregation("approx_distinct", input, input.getType(), BIGINT); - case NUMBER_OF_NON_NULL_VALUES: - return createAggregation("count", input, input.getType(), BIGINT); - case NUMBER_OF_TRUE_VALUES: - return createAggregation("count_if", input, BOOLEAN, BIGINT); - case TOTAL_SIZE_IN_BYTES: - return createAggregation(SumDataSizeForStats.NAME, input, input.getType(), BIGINT); - case MAX_VALUE_SIZE_IN_BYTES: - return createAggregation(MaxDataSizeForStats.NAME, input, input.getType(), BIGINT); - default: - throw new IllegalArgumentException("Unsupported statistic type: " + statisticType); - } - } - - private ColumnStatisticsAggregation createAggregation(String functionName, RowExpression input, Type inputType, Type outputType) - { - FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(functionName, TypeSignatureProvider.fromTypes(ImmutableList.of(inputType))); - Type resolvedType = functionAndTypeResolver.getType(getOnlyElement(functionAndTypeResolver.getFunctionMetadata(functionHandle).getArgumentTypes())); - verify(resolvedType.equals(inputType), "resolved function input type does not match the input type: %s != %s", resolvedType, inputType); + FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(columnStatisticMetadata.getFunctionName(), TypeSignatureProvider.fromTypes(ImmutableList.of(input.getType()))); + FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(functionHandle); + Type inputType = functionAndTypeResolver.getType(getOnlyElement(functionMeta.getArgumentTypes())); + Type outputType = functionAndTypeResolver.getType(functionMeta.getReturnType()); + verify(inputType.equals(input.getType()) || input.getType().equals(UNKNOWN), "resolved function input type does not match the input type: %s != %s", inputType, input.getType()); return new ColumnStatisticsAggregation( 
new AggregationNode.Aggregation( new CallExpression( input.getSourceLocation(), - functionName, + columnStatisticMetadata.getFunctionName(), functionHandle, outputType, ImmutableList.of(input)), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java index 2325320820ad9..95878fb4310f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java @@ -173,7 +173,7 @@ public void serialize(ColumnStatisticMetadata value, JsonGenerator gen, Serializ @VisibleForTesting static String serialize(ColumnStatisticMetadata value) { - return value.getStatisticType().name() + ":" + value.getColumnName(); + return value.getStatisticType().name() + ":" + value.getFunctionName() + ":" + value.getColumnName(); } } @@ -189,11 +189,9 @@ public ColumnStatisticMetadata deserializeKey(String key, DeserializationContext @VisibleForTesting static ColumnStatisticMetadata deserialize(String value) { - int separatorIndex = value.indexOf(':'); - checkArgument(separatorIndex >= 0, "separator not found: %s", value); - String statisticType = value.substring(0, separatorIndex); - String column = value.substring(separatorIndex + 1); - return new ColumnStatisticMetadata(column, ColumnStatisticType.valueOf(statisticType)); + String[] values = value.split(":", 3); + checkArgument(values.length == 3, "separator(s) not found: %s", value); + return new ColumnStatisticMetadata(values[2], ColumnStatisticType.valueOf(values[0]), values[1]); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java index a303f0946f28d..60f0b01b662aa 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java @@ -101,14 +101,14 @@ public class ShowStatsRewrite @Override public Statement rewrite(Session session, - Metadata metadata, - SqlParser parser, - Optional queryExplainer, - Statement node, - List parameters, - Map, Expression> parameterLookup, - AccessControl accessControl, - WarningCollector warningCollector) + Metadata metadata, + SqlParser parser, + Optional queryExplainer, + Statement node, + List parameters, + Map, Expression> parameterLookup, + AccessControl accessControl, + WarningCollector warningCollector) { return (Statement) new Visitor(metadata, session, parameters, queryExplainer, warningCollector).process(node, null); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestTableFinishOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestTableFinishOperator.java index bcfae9c88a424..ff5b789c8607b 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestTableFinishOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestTableFinishOperator.java @@ -101,7 +101,7 @@ public void testStatisticsAggregation() { TestingTableFinisher tableFinisher = new TestingTableFinisher(); TestingPageSinkCommitter pageSinkCommitter = new TestingPageSinkCommitter(); - ColumnStatisticMetadata statisticMetadata = new ColumnStatisticMetadata("column", MAX_VALUE); + ColumnStatisticMetadata statisticMetadata = MAX_VALUE.getColumnStatisticMetadata("column"); StatisticAggregationsDescriptor descriptor = new StatisticAggregationsDescriptor<>( ImmutableMap.of(), ImmutableMap.of(), @@ -182,7 +182,7 @@ public void testTableWriteCommit() { TestingTableFinisher tableFinisher = new TestingTableFinisher(); TestingPageSinkCommitter pageSinkCommitter = new TestingPageSinkCommitter(); - ColumnStatisticMetadata statisticMetadata = new 
ColumnStatisticMetadata("column", MAX_VALUE); + ColumnStatisticMetadata statisticMetadata = MAX_VALUE.getColumnStatisticMetadata("column"); StatisticAggregationsDescriptor descriptor = new StatisticAggregationsDescriptor<>( ImmutableMap.of(), ImmutableMap.of(), diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestThetaSketchAggregationFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestThetaSketchAggregationFunction.java new file mode 100644 index 0000000000000..15ca91f3fc1df --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestThetaSketchAggregationFunction.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.SqlVarbinary; +import com.facebook.presto.operator.aggregation.sketch.theta.ThetaSketchAggregationFunction; +import com.google.common.collect.ImmutableList; +import org.apache.datasketches.theta.Union; + +import java.util.List; + +import static com.facebook.presto.common.type.DoubleType.DOUBLE; + +public class TestThetaSketchAggregationFunction + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, length); + for (int i = start; i < start + length; i++) { + DOUBLE.writeDouble(blockBuilder, i); + } + return new Block[] {blockBuilder.build()}; + } + + @Override + protected String getFunctionName() + { + return ThetaSketchAggregationFunction.NAME; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of("double"); + } + + @Override + public Object getExpectedValue(int start, int length) + { + Block values = getSequenceBlocks(start, length)[0]; + Union union = Union.builder().buildUnion(); + for (int i = 0; i < length; i++) { + union.update(DOUBLE.getDouble(values, i)); + } + return new SqlVarbinary(union.getResult().toByteArray()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestThetaSketchFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestThetaSketchFunctions.java new file mode 100644 index 0000000000000..0b296ac80447d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestThetaSketchFunctions.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.type.SqlVarbinary; +import org.apache.datasketches.theta.CompactSketch; +import org.apache.datasketches.theta.Union; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static java.lang.String.format; + +public class TestThetaSketchFunctions + extends AbstractTestFunctions +{ + @Test + public void testNullSketch() + { + functionAssertions.assertFunction("sketch_theta_estimate(CAST(NULL as VARBINARY))", DOUBLE, null); + } + + @Test + public void testEstimateEmptySketch() + { + Union union = Union.builder().buildUnion(); + functionAssertions.assertFunction( + format("sketch_theta_estimate(CAST(X'%s' as VARBINARY))", + toVarbinarySql(union.getResult().toByteArray())), + DOUBLE, + 0.0); + } + + @Test + public void testEstimateSingleValue() + { + Union union = Union.builder().buildUnion(); + union.update(1); + functionAssertions.assertFunction( + format("sketch_theta_estimate(CAST(X'%s' as VARBINARY))", + toVarbinarySql(union.getResult().toByteArray())), + DOUBLE, + 1.0); + } + + @Test + public void testEstimateManyValues() + { + Union union = Union.builder().buildUnion(); + int size = 100; + for (int i = 0; i < size; i++) { + union.update(i); + } + functionAssertions.assertFunction( + format("sketch_theta_estimate(CAST(X'%s' as VARBINARY))", + toVarbinarySql(union.getResult().toByteArray())), + DOUBLE, + (double) size); + } + + @Test + 
public void testSummaryNull() + { + functionAssertions.assertFunction("sketch_theta_summary(CAST(NULL as VARBINARY)).estimate", + DOUBLE, + null); + } + + @Test + public void testSummarySingle() + { + Union union = Union.builder().buildUnion(); + union.update(1); + CompactSketch compactSketch = union.getResult(); + summaryMatches(compactSketch, union.getResult().toByteArray()); + } + + @Test + public void testSummaryMany() + { + Union union = Union.builder().buildUnion(); + int size = 100; + for (int i = 0; i < size; i++) { + union.update(i); + } + summaryMatches(union.getResult(), union.getResult().toByteArray()); + } + + private void summaryMatches(CompactSketch expected, byte[] input) + { + functionAssertions.assertFunction( + format("sketch_theta_summary(CAST(X'%s' as VARBINARY)).estimate", + toVarbinarySql(input)), + DOUBLE, + expected.getEstimate()); + functionAssertions.assertFunction( + format("sketch_theta_summary(CAST(X'%s' as VARBINARY)).theta", + toVarbinarySql(input)), + DOUBLE, + expected.getTheta()); + functionAssertions.assertFunction( + format("sketch_theta_summary(CAST(X'%s' as VARBINARY)).upper_bound_std", + toVarbinarySql(input)), + DOUBLE, + expected.getUpperBound(1)); + functionAssertions.assertFunction( + format("sketch_theta_summary(CAST(X'%s' as VARBINARY)).lower_bound_std", + toVarbinarySql(input)), + DOUBLE, + expected.getLowerBound(1)); + functionAssertions.assertFunction( + format("sketch_theta_summary(CAST(X'%s' as VARBINARY)).retained_entries", + toVarbinarySql(input)), + INTEGER, + expected.getRetainedEntries()); + } + + private static String toVarbinarySql(byte[] data) + { + return new SqlVarbinary(data).toString().replaceAll("\\s+", " "); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java index b68df26152f3f..fe4081cb64252 100644 --- 
a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java @@ -31,7 +31,7 @@ public void testColumnStatisticMetadataKeySerializationRoundTrip() { for (String column : COLUMNS) { for (ColumnStatisticType type : ColumnStatisticType.values()) { - ColumnStatisticMetadata expected = new ColumnStatisticMetadata(column, type); + ColumnStatisticMetadata expected = type.getColumnStatisticMetadata(column); assertEquals(deserialize(serialize(expected)), expected); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java index fad391db44333..b2d882888a5f8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java @@ -28,7 +28,6 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; import com.facebook.presto.spi.statistics.ColumnStatisticType; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.testing.TestingHandleResolver; @@ -92,7 +91,7 @@ private static StatisticAggregationsDescriptor crea VariableAllocator variableAllocator = new VariableAllocator(); for (String column : COLUMNS) { for (ColumnStatisticType type : ColumnStatisticType.values()) { - builder.addColumnStatistic(new ColumnStatisticMetadata(column, type), testVariable(variableAllocator)); + builder.addColumnStatistic(type.getColumnStatisticMetadata(column), testVariable(variableAllocator)); } builder.addGrouping(column, testVariable(variableAllocator)); } diff --git 
a/presto-pinot-toolkit/pom.xml b/presto-pinot-toolkit/pom.xml index 41c683c6c1722..851e46ef59b59 100644 --- a/presto-pinot-toolkit/pom.xml +++ b/presto-pinot-toolkit/pom.xml @@ -415,7 +415,6 @@ org.apache.datasketches datasketches-java - 1.2.0-incubating test diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java index f473209a19bfe..ff32b0167ed62 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java @@ -25,13 +25,17 @@ public class ColumnStatisticMetadata private final String columnName; private final ColumnStatisticType statisticType; + private final String functionName; + @JsonCreator public ColumnStatisticMetadata( @JsonProperty("columnName") String columnName, - @JsonProperty("statisticType") ColumnStatisticType statisticType) + @JsonProperty("statisticType") ColumnStatisticType statisticType, + @JsonProperty("functionName") String functionName) { this.columnName = requireNonNull(columnName, "columnName is null"); this.statisticType = requireNonNull(statisticType, "statisticType is null"); + this.functionName = requireNonNull(functionName, "functionName is null"); } @JsonProperty @@ -46,6 +50,12 @@ public ColumnStatisticType getStatisticType() return statisticType; } + @JsonProperty + public String getFunctionName() + { + return functionName; + } + @Override public boolean equals(Object o) { @@ -57,13 +67,14 @@ public boolean equals(Object o) } ColumnStatisticMetadata that = (ColumnStatisticMetadata) o; return Objects.equals(columnName, that.columnName) && - statisticType == that.statisticType; + statisticType == that.statisticType && + Objects.equals(functionName, that.functionName); } @Override public int hashCode() { - return Objects.hash(columnName, statisticType); + return 
Objects.hash(columnName, statisticType, functionName); } @Override @@ -72,6 +83,7 @@ public String toString() return "ColumnStatisticMetadata{" + "columnName='" + columnName + '\'' + ", statisticType=" + statisticType + + ", functionName=" + functionName + '}'; } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java index 70301944ae59b..793c5acec9606 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java @@ -15,11 +15,27 @@ public enum ColumnStatisticType { - MIN_VALUE, - MAX_VALUE, - NUMBER_OF_DISTINCT_VALUES, - NUMBER_OF_NON_NULL_VALUES, - NUMBER_OF_TRUE_VALUES, - MAX_VALUE_SIZE_IN_BYTES, - TOTAL_SIZE_IN_BYTES, + MAX_VALUE("max"), + MAX_VALUE_SIZE_IN_BYTES("max_data_size_for_stats"), + MIN_VALUE("min"), + NUMBER_OF_DISTINCT_VALUES("approx_distinct"), + NUMBER_OF_NON_NULL_VALUES("count"), + NUMBER_OF_TRUE_VALUES("count_if"), + TOTAL_SIZE_IN_BYTES("sum_data_size_for_stats"); + private final String functionName; + + ColumnStatisticType(String functionName) + { + this.functionName = functionName; + } + + public ColumnStatisticMetadata getColumnStatisticMetadata(String columnName) + { + return new ColumnStatisticMetadata(columnName, this, this.functionName); + } + + public ColumnStatisticMetadata getColumnStatisticMetadataWithCustomFunction(String columnName, String functionName) + { + return new ColumnStatisticMetadata(columnName, this, functionName); + } }