diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index f7c22dddb8cc0..1930b03a18cb6 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -31,6 +31,10 @@ *
  • {@link Integer}
  • *
  • {@link Long}
  • *
  • {@link String}
  • + *
  • {@link Float}
  • + *
  • {@link Double}
  • + *
  • {@link java.math.BigDecimal}
  • + *
  • {@link Boolean}
  • * * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters: *
      diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 045fec33a282a..0d010a9759526 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -211,10 +211,6 @@ private int hash(long item, int count) { return ((int) hash) % width; } - private static int[] getHashBuckets(String key, int hashCount, int max) { - return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max); - } - private static int[] getHashBuckets(byte[] b, int hashCount, int max) { int[] result = new int[hashCount]; int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0); @@ -244,7 +240,7 @@ private long estimateCountForLongItem(long item) { return res; } - private long estimateCountForStringItem(String item) { + private long estimateCountForBinaryItem(byte[] item) { long res = Long.MAX_VALUE; int[] buckets = getHashBuckets(item, depth, width); for (int i = 0; i < depth; ++i) { diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala index 174eb01986c4f..6115710045eec 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util.sketch import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.charset.StandardCharsets import scala.reflect.ClassTag import scala.util.Random @@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite } def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + def getProbeItem(item: T): Any = item match { + // Use a string to represent the content of an array of bytes + case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8) + case i => identity(i) + } + test(s"accuracy - $typeName") { // Uses fixed seed to ensure reproducible test execution val r = new Random(31) @@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val exactFreq = { val sampledItems = sampledItemIndices.map(allItems) - sampledItems.groupBy(identity).mapValues(_.length.toLong) + sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong) } val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val probCorrect = { val numErrors = allItems.map { item => - val count = exactFreq.getOrElse(item, 0L) + val count = exactFreq.getOrElse(getProbeItem(item), 0L) val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems if (ratio > epsOfTotalCount) 1 else 0 }.sum diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala index 10479630f3f99..676cfdb448648 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala @@ -190,7 +190,6 @@ class CountMinSketchAggSuite extends SparkFunSuite { agg.initialize(buffer) // Empty aggregation buffer assert(isEqual(agg.eval(buffer), emptyCms)) - // Empty input row agg.update(buffer, InternalRow(null)) assert(isEqual(agg.eval(buffer), emptyCms))