diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 1a4bd2982cb11..7e1d774bf7c5b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -68,6 +68,8 @@ public final class UnsafeFixedWidthAggregationMap {
    */
   private long[] groupingKeyConversionScratchSpace = new long[1024 / 8];
 
+  private final boolean enablePerfMetrics;
+
   /**
    * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
    * false otherwise.
    */
@@ -102,19 +104,22 @@ public static boolean supportsAggregationBufferSchema(StructType schema) {
    * @param groupingKeySchema the schema of the grouping key, used for row conversion.
    * @param allocator the memory allocator used to allocate our Unsafe memory structures.
    * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
+   * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
    */
   public UnsafeFixedWidthAggregationMap(
       Row emptyAggregationBuffer,
       StructType aggregationBufferSchema,
       StructType groupingKeySchema,
       MemoryAllocator allocator,
-      int initialCapacity) {
+      int initialCapacity,
+      boolean enablePerfMetrics) {
     this.emptyAggregationBuffer =
       convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
     this.aggregationBufferSchema = aggregationBufferSchema;
     this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
-    this.map = new BytesToBytesMap(allocator, initialCapacity);
+    this.map = new BytesToBytesMap(allocator, initialCapacity, enablePerfMetrics);
+    this.enablePerfMetrics = enablePerfMetrics;
   }
 
   /**
@@ -232,4 +237,13 @@ public void free() {
     map.free();
   }
 
+  public void printPerfMetrics() {
+    if (!enablePerfMetrics) {
+      throw new IllegalStateException("Perf metrics not enabled");
+    }
+    System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
+    System.out.println("Time spent resizing (ms): " + map.getTimeSpentResizingMs());
+    System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+  }
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 82f6ca142e922..822b23b40e9fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -290,7 +290,8 @@ case class GeneratedAggregate(
           aggregationBufferSchema,
           groupKeySchema,
           MemoryAllocator.UNSAFE,
-          1024 * 16
+          1024 * 16,
+          false
         )
 
         while (iter.hasNext) {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index f7857db126d88..63afbea6e9060 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -139,17 +139,38 @@ public final class BytesToBytesMap {
 
   private final Location loc;
 
+  private final boolean enablePerfMetrics;
 
-  public BytesToBytesMap(MemoryAllocator allocator, int initialCapacity, double loadFactor) {
+  private long timeSpentResizingMs = 0;
+
+  private int numResizes = 0;
+
+  private long numProbes = 0;
+
+  private long numKeyLookups = 0;
+
+  public BytesToBytesMap(
+      MemoryAllocator allocator,
+      int initialCapacity,
+      double loadFactor,
+      boolean enablePerfMetrics) {
     this.inHeap = allocator instanceof HeapMemoryAllocator;
     this.allocator = allocator;
     this.loadFactor = loadFactor;
     this.loc = new Location();
+    this.enablePerfMetrics = enablePerfMetrics;
     allocate(initialCapacity);
   }
 
   public BytesToBytesMap(MemoryAllocator allocator, int initialCapacity) {
-    this(allocator, initialCapacity, 0.70);
+    this(allocator, initialCapacity, 0.70, false);
+  }
+
+  public BytesToBytesMap(
+      MemoryAllocator allocator,
+      int initialCapacity,
+      boolean enablePerfMetrics) {
+    this(allocator, initialCapacity, 0.70, enablePerfMetrics);
   }
 
   @Override
@@ -205,10 +226,16 @@ public Location lookup(
       Object keyBaseObject,
       long keyBaseOffset,
       int keyRowLengthBytes) {
+    if (enablePerfMetrics) {
+      numKeyLookups++;
+    }
     final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes);
     int pos = hashcode & mask;
     int step = 1;
     while (true) {
+      if (enablePerfMetrics) {
+        numProbes++;
+      }
       if (!bitset.isSet(pos)) {
         // This is a new key.
         return loc.with(pos, hashcode, false);
@@ -484,10 +511,36 @@ public long getTotalMemoryConsumption() {
       longArray.memoryBlock().size());
   }
 
+  /**
+   * Returns the total amount of time spent resizing this map (in milliseconds).
+   */
+  public long getTimeSpentResizingMs() {
+    if (!enablePerfMetrics) {
+      throw new IllegalStateException();
+    }
+    return timeSpentResizingMs;
+  }
+
+
+  /**
+   * Returns the average number of probes per key lookup.
+   */
+  public double getAverageProbesPerLookup() {
+    if (!enablePerfMetrics) {
+      throw new IllegalStateException();
+    }
+    return (1.0 * numProbes) / numKeyLookups;
+  }
+
   /**
    * Grows the size of the hash table and re-hash everything.
    */
   private void growAndRehash() {
+    long resizeStartTime = -1;
+    if (enablePerfMetrics) {
+      numResizes++;
+      resizeStartTime = System.currentTimeMillis();
+    }
     // Store references to the old data structures to be used when we re-hash
     final LongArray oldLongArray = longArray;
     final BitSet oldBitSet = bitset;
@@ -526,6 +579,10 @@ private void growAndRehash() {
     // Deallocate the old data structures.
     allocator.free(oldLongArray.memoryBlock());
     allocator.free(oldBitSet.memoryBlock());
+    if (enablePerfMetrics) {
+      System.out.println("Resizing took " + (System.currentTimeMillis() - resizeStartTime) + " ms");
+      timeSpentResizingMs += System.currentTimeMillis() - resizeStartTime;
+    }
   }
 
   /** Returns the next number greater or equal num that is power of 2. */
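For reference, a minimal sketch of how the new flag could be exercised directly against BytesToBytesMap. It is not part of the patch: the PerfMetricsExample class name and the MemoryAllocator import path are illustrative assumptions; only the constructor overload and the metrics getters shown in the diff above are relied on.

import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryAllocator;  // assumed package for MemoryAllocator

// Illustrative driver (not part of this patch): builds a map with perf metrics
// enabled and reads back the counters introduced in this diff.
public class PerfMetricsExample {
  public static void main(String[] args) {
    // Third constructor argument is the new enablePerfMetrics flag.
    BytesToBytesMap map = new BytesToBytesMap(MemoryAllocator.UNSAFE, 1024 * 16, true);
    try {
      // ... perform lookups / inserts here; each lookup() increments numKeyLookups
      // and every probe of the open-addressing chain increments numProbes, so the
      // average below is only meaningful after at least one lookup.
      System.out.println("Total memory (bytes): " + map.getTotalMemoryConsumption());
      System.out.println("Time resizing (ms):   " + map.getTimeSpentResizingMs());
      System.out.println("Probes per lookup:    " + map.getAverageProbesPerLookup());
    } finally {
      map.free();
    }
  }
}

Callers that go through UnsafeFixedWidthAggregationMap get the same metrics by passing true as the new final constructor argument and calling printPerfMetrics() after processing.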