Skip to content

Commit

Permalink
[SPARK-18429][SQL] implement a new Aggregate for CountMinSketch
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR implements a new aggregate function that generates a count-min sketch; it is a thin wrapper around the existing CountMinSketch class.

## How was this patch tested?

Added test cases.

Author: wangzhenhua <[email protected]>

Closes apache#15877 from wzhfy/cms.
  • Loading branch information
wzhfy authored and Robert Kruszewski committed Dec 15, 2016
1 parent b06bcc3 commit ad5a3d1
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
* <li>{@link Integer}</li>
* <li>{@link Long}</li>
* <li>{@link String}</li>
* <li>{@link Float}</li>
* <li>{@link Double}</li>
* <li>{@link java.math.BigDecimal}</li>
* <li>{@link Boolean}</li>
* </ul>
* A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
* <ol>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,6 @@ private int hash(long item, int count) {
return ((int) hash) % width;
}

// Convenience overload: hashes a String key by delegating to the byte[]
// overload on the key's UTF-8 representation (via the project's Utils helper).
private static int[] getHashBuckets(String key, int hashCount, int max) {
  byte[] utf8Bytes = Utils.getBytesFromUTF8String(key);
  return getHashBuckets(utf8Bytes, hashCount, max);
}

private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
int[] result = new int[hashCount];
int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
Expand Down Expand Up @@ -244,7 +240,7 @@ private long estimateCountForLongItem(long item) {
return res;
}

private long estimateCountForStringItem(String item) {
private long estimateCountForBinaryItem(byte[] item) {
long res = Long.MAX_VALUE;
int[] buckets = getHashBuckets(item, depth, width);
for (int i = 0; i < depth; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.util.sketch

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets

import scala.reflect.ClassTag
import scala.util.Random
Expand All @@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
}

def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
// Maps a sampled item to the key used for exact-frequency bookkeeping.
// Byte arrays compare by reference in the JVM, so grouping/looking them up
// directly would never coalesce equal contents; represent them by their
// UTF-8 string content instead. All other item types already have
// value-based equality and pass through unchanged.
def getProbeItem(item: T): Any = item match {
  // Use a string to represent the content of an array of bytes
  case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
  // Fix: `identity(i)` was a redundant wrapper; return the item directly.
  case i => i
}

test(s"accuracy - $typeName") {
// Uses fixed seed to ensure reproducible test execution
val r = new Random(31)
Expand All @@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

val exactFreq = {
val sampledItems = sampledItemIndices.map(allItems)
sampledItems.groupBy(identity).mapValues(_.length.toLong)
sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
}

val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
Expand All @@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

val probCorrect = {
val numErrors = allItems.map { item =>
val count = exactFreq.getOrElse(item, 0L)
val count = exactFreq.getOrElse(getProbeItem(item), 0L)
val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
if (ratio > epsOfTotalCount) 1 else 0
}.sum
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ class CountMinSketchAggSuite extends SparkFunSuite {
agg.initialize(buffer)
// Empty aggregation buffer
assert(isEqual(agg.eval(buffer), emptyCms))

// Empty input row
agg.update(buffer, InternalRow(null))
assert(isEqual(agg.eval(buffer), emptyCms))
Expand Down

0 comments on commit ad5a3d1

Please sign in to comment.