diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index f7c22dddb8cc0..1930b03a18cb6 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -31,6 +31,10 @@
* {@link Integer}
* {@link Long}
* {@link String}
+ * {@link Float}
+ * {@link Double}
+ * {@link java.math.BigDecimal}
+ * {@link Boolean}
*
* A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
*
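Note (not part of the patch): a minimal usage sketch of the public API that the Javadoc above documents, using CountMinSketch.create(eps, confidence, seed), add, and estimateCount, all of which exist today. The commented-out Double line is only an assumption based on the newly added type list.

import org.apache.spark.util.sketch.CountMinSketch

// eps (relative error), confidence, and a fixed seed, mirroring the test suite below.
val sketch = CountMinSketch.create(0.001, 0.99, 42)

(1 to 1000).foreach(_ => sketch.add("foo"))
(1 to 10).foreach(_ => sketch.add(123L))

// Count-min never under-counts: each estimate is at least the true frequency,
// and within eps * totalCount of it with the configured confidence.
assert(sketch.estimateCount("foo") >= 1000L)
assert(sketch.estimateCount(123L) >= 10L)

// Assumption based on the Javadoc change above: Float/Double/BigDecimal/Boolean
// items would go through the same generic add(Object)/estimateCount(Object) path.
// sketch.add(3.14d)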
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 045fec33a282a..0d010a9759526 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -211,10 +211,6 @@ private int hash(long item, int count) {
return ((int) hash) % width;
}

- private static int[] getHashBuckets(String key, int hashCount, int max) {
- return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
- }
-
private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
int[] result = new int[hashCount];
int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
@@ -244,7 +240,7 @@ private long estimateCountForLongItem(long item) {
return res;
}

- private long estimateCountForStringItem(String item) {
+ private long estimateCountForBinaryItem(byte[] item) {
long res = Long.MAX_VALUE;
int[] buckets = getHashBuckets(item, depth, width);
for (int i = 0; i < depth; ++i) {
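Note (not part of the patch): the rename makes String and byte[] items share one byte[]-based code path: getHashBuckets picks one bucket per row from the Murmur3 hashes, and the estimate is the minimum counter across the depth rows. Below is a standalone Scala mirror of that min-over-rows step; the loop body is assumed to match estimateCountForLongItem above, and the table and buckets values are hypothetical.

// Standalone illustration only; `table` and `buckets` are made-up inputs.
def minOverRows(table: Array[Array[Long]], buckets: Array[Int]): Long = {
  var res = Long.MaxValue
  for (i <- table.indices) {
    res = math.min(res, table(i)(buckets(i)))
  }
  res
}

// A 2 x 4 counter table where the item hashed to bucket 1 in row 0 and bucket 3 in row 1.
val table = Array(Array(0L, 7L, 0L, 0L), Array(0L, 0L, 0L, 5L))
assert(minOverRows(table, Array(1, 3)) == 5L)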
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
index 174eb01986c4f..6115710045eec 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util.sketch

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.charset.StandardCharsets

import scala.reflect.ClassTag
import scala.util.Random
@@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
}

def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+ def getProbeItem(item: T): Any = item match {
+ // Use a string to represent the content of an array of bytes
+ case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
+ case i => identity(i)
+ }
+
test(s"accuracy - $typeName") {
// Uses fixed seed to ensure reproducible test execution
val r = new Random(31)
@@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
val exactFreq = {
val sampledItems = sampledItemIndices.map(allItems)
- sampledItems.groupBy(identity).mapValues(_.length.toLong)
+ sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
}
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
val probCorrect = {
val numErrors = allItems.map { item =>
- val count = exactFreq.getOrElse(item, 0L)
+ val count = exactFreq.getOrElse(getProbeItem(item), 0L)
val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
if (ratio > epsOfTotalCount) 1 else 0
}.sum
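Note (not part of the patch): getProbeItem is needed because the exact frequencies live in a plain Scala Map, where Array[Byte] keys compare by reference, whereas the sketch hashes the bytes by value; converting the bytes to a String gives the map value-based keys, so the exact and estimated counts refer to the same item. A standalone illustration using only the JDK:

import java.nio.charset.StandardCharsets

val a = Array[Byte](104, 105)  // the UTF-8 bytes of "hi"
val b = Array[Byte](104, 105)

// Arrays compare by reference, so identical contents would still be distinct groupBy keys.
assert(a != b)
// The String representation compares by value, collapsing them into one key.
assert(new String(a, StandardCharsets.UTF_8) == new String(b, StandardCharsets.UTF_8))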
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
index 10479630f3f99..676cfdb448648 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
@@ -190,7 +190,6 @@ class CountMinSketchAggSuite extends SparkFunSuite {
agg.initialize(buffer)
// Empty aggregation buffer
assert(isEqual(agg.eval(buffer), emptyCms))
-
// Empty input row
agg.update(buffer, InternalRow(null))
assert(isEqual(agg.eval(buffer), emptyCms))
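Note (not part of the patch): a rough end-to-end sketch of the SQL surface backed by CountMinSketchAgg, assuming a SparkSession named `spark` and the count_min_sketch(col, eps, confidence, seed) SQL aggregate, which returns the sketch serialized as binary; CountMinSketch.readFrom deserializes it for querying.

import org.apache.spark.util.sketch.CountMinSketch

val serialized = spark
  .sql("SELECT count_min_sketch(id, 0.001d, 0.99d, 42) FROM range(1000)")
  .head()
  .getAs[Array[Byte]](0)

val cms = CountMinSketch.readFrom(serialized)
// Every id in 0..999 occurs exactly once; count-min never under-counts.
assert(cms.estimateCount(500L) >= 1L)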