Commit ca4a13f

support more types

wzhfy committed Nov 22, 2016
1 parent 3e86075 commit ca4a13f
Showing 6 changed files with 141 additions and 38 deletions.
CountMinSketch.java
@@ -30,6 +30,10 @@
* <li>{@link Integer}</li>
* <li>{@link Long}</li>
* <li>{@link String}</li>
* <li>{@link Float}</li>
* <li>{@link Double}</li>
* <li>{@link java.math.BigDecimal}</li>
* <li>{@link Boolean}</li>
* </ul>
* A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
* <ol>
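To make the new contract concrete, here is a minimal usage sketch in Scala. It is illustrative only and not part of the commit; it assumes nothing beyond the public CountMinSketch factory shown elsewhere in this diff.

import java.nio.charset.StandardCharsets
import org.apache.spark.util.sketch.CountMinSketch

// relative error eps = 0.01, confidence = 0.95, seed = 42
val cms = CountMinSketch.create(0.01, 0.95, 42)

// All newly supported types go through the same add/estimateCount pair.
cms.add(1.5f)                                        // Float
cms.add(2.5d)                                        // Double
cms.add(new java.math.BigDecimal("3.14"))            // BigDecimal
cms.add(true)                                        // Boolean
cms.add("hello".getBytes(StandardCharsets.UTF_8))    // Array[Byte]

// Count-Min estimates never undercount, so this is at least 1.
assert(cms.estimateCount(true) >= 1L)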
CountMinSketchImpl.java
@@ -25,6 +25,7 @@
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Random;

@@ -152,6 +153,16 @@ public void add(Object item) {
public void add(Object item, long count) {
if (item instanceof String) {
addString((String) item, count);
} else if (item instanceof BigDecimal) {
addString(((BigDecimal) item).toString(), count);
} else if (item instanceof byte[]) {
addBinary((byte[]) item, count);
} else if (item instanceof Float) {
addLong(Float.floatToIntBits((Float) item), count);
} else if (item instanceof Double) {
addLong(Double.doubleToLongBits((Double) item), count);
} else if (item instanceof Boolean) {
addLong(((Boolean) item) ? 1L : 0L, count);
} else {
addLong(Utils.integralToLong(item), count);
}
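A note on the encodings above: Float and Double reuse the long path by hashing their exact IEEE-754 bit patterns, Boolean collapses to the keys 1L and 0L, and BigDecimal goes through its string form — so values that are numerically equal but differ in scale (3.14 versus 3.140) are counted separately. A plain-Scala illustration (not part of the commit):

// 1.5f and 2.5d shown with their standard IEEE-754 bit patterns.
assert(java.lang.Float.floatToIntBits(1.5f) == 0x3FC00000)
assert(java.lang.Double.doubleToLongBits(2.5d) == 0x4004000000000000L)
// BigDecimal keys preserve scale, so "3.14" and "3.140" hash differently.
assert(new java.math.BigDecimal("3.140").toString == "3.140")
// Booleans map onto two long keys.
def boolKey(b: Boolean): Long = if (b) 1L else 0L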
@@ -216,10 +227,6 @@ private int hash(long item, int count) {
return ((int) hash) % width;
}

- private static int[] getHashBuckets(String key, int hashCount, int max) {
-   return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
- }

private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
int[] result = new int[hashCount];
int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
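The binary path derives all depth bucket indices from just two Murmur3 hashes by double hashing: bucket_i = |(hash1 + i * hash2) mod width|. A rough Scala sketch of the scheme — the helper name and the Murmur3 variant here are assumptions for illustration, not Spark's exact code:

import scala.util.hashing.MurmurHash3

def hashBuckets(bytes: Array[Byte], hashCount: Int, max: Int): Array[Int] = {
  val hash1 = MurmurHash3.bytesHash(bytes, 0)      // first basis hash, seed 0
  val hash2 = MurmurHash3.bytesHash(bytes, hash1)  // second basis hash, seeded by the first
  Array.tabulate(hashCount)(i => math.abs((hash1 + i * hash2) % max))
}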
@@ -233,7 +240,18 @@ private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
@Override
public long estimateCount(Object item) {
if (item instanceof String) {
- return estimateCountForStringItem((String) item);
+ return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item));
} else if (item instanceof BigDecimal) {
return estimateCountForBinaryItem(
Utils.getBytesFromUTF8String(((BigDecimal) item).toString()));
} else if (item instanceof byte[]) {
return estimateCountForBinaryItem((byte[]) item);
} else if (item instanceof Float) {
return estimateCountForLongItem(Float.floatToIntBits((Float) item));
} else if (item instanceof Double) {
return estimateCountForLongItem(Double.doubleToLongBits((Double) item));
} else if (item instanceof Boolean) {
return estimateCountForLongItem(((Boolean) item) ? 1L : 0L);
} else {
return estimateCountForLongItem(Utils.integralToLong(item));
}
@@ -247,7 +265,7 @@ private long estimateCountForLongItem(long item) {
return res;
}

- private long estimateCountForStringItem(String item) {
+ private long estimateCountForBinaryItem(byte[] item) {
long res = Long.MAX_VALUE;
int[] buckets = getHashBuckets(item, depth, width);
for (int i = 0; i < depth; ++i) {
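Both estimate helpers reduce to the same rule: a Count-Min estimate is the minimum counter across the depth rows, which upper-bounds the true frequency because every increment for an item lands in all of its buckets. A compact restatement under an assumed depth-by-width table layout (hypothetical helper, not the commit's code):

def estimate(table: Array[Array[Long]], buckets: Array[Int]): Long =
  table.indices.map(i => table(i)(buckets(i))).min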
CountMinSketchSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util.sketch

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets

import scala.reflect.ClassTag
import scala.util.Random
@@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
}

def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
def getProbeItem(item: T): Any = item match {
// Use a string to represent the content of an array of bytes
case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
case other => other
}

test(s"accuracy - $typeName") {
// Uses fixed seed to ensure reproducible test execution
val r = new Random(31)
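getProbeItem exists because JVM arrays compare by reference: groupBy(identity) would scatter equal byte arrays across distinct keys, so the suite keys frequencies by the UTF-8 string content instead. A two-line demonstration in plain Scala (not from the commit):

import java.nio.charset.StandardCharsets

val a = Array[Byte](1, 2, 3)
val b = Array[Byte](1, 2, 3)
assert(a != b)  // arrays use reference equality
assert(new String(a, StandardCharsets.UTF_8) == new String(b, StandardCharsets.UTF_8))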
@@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

val exactFreq = {
val sampledItems = sampledItemIndices.map(allItems)
- sampledItems.groupBy(identity).mapValues(_.length.toLong)
+ sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
}

val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

val probCorrect = {
val numErrors = allItems.map { item =>
- val count = exactFreq.getOrElse(item, 0L)
+ val count = exactFreq.getOrElse(getProbeItem(item), 0L)
val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
if (ratio > epsOfTotalCount) 1 else 0
}.sum
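This counts the items whose estimate exceeds the true count by more than eps * N; the Count-Min contract bounds the probability of that event at 1 - confidence per item, which is what the suite then asserts. For orientation, the standard sizing that yields these guarantees — hedged, since Spark's exact constants may differ — looks like:

// width controls the per-row error bound, depth the failure probability.
def sizeFor(eps: Double, confidence: Double): (Int, Int) = {
  val width = math.ceil(2.0 / eps).toInt
  val depth = math.ceil(math.log(1.0 / (1.0 - confidence)) / math.log(2.0)).toInt
  (depth, width)
}
// e.g. sizeFor(0.01, 0.9) == (4, 200)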
@@ -135,6 +142,18 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }

testItemType[Float]("Float") { _.nextFloat() }

testItemType[Double]("Double") { _.nextDouble() }

testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) }

testItemType[Boolean]("Boolean") { _.nextBoolean() }

testItemType[Array[Byte]]("Binary") { r =>
Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20)))
}

test("incompatible merge") {
intercept[IncompatibleMergeException] {
CountMinSketch.create(10, 10, 1).mergeInPlace(null)
CountMinSketchAgg.scala
@@ -96,15 +96,13 @@ case class CountMinSketchAgg(
// Ignore empty rows
if (value != null) {
child.dataType match {
// `Decimal` and `UTF8String` are internal Spark SQL types; convert them into
// types that `CountMinSketch` accepts.
case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
// For string type, get the bytes of the `UTF8String` directly and call `addBinary`
// instead of `addString` to avoid an unnecessary conversion.
case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
case ByteType => buffer.addLong(value.asInstanceOf[Byte])
case ShortType => buffer.addLong(value.asInstanceOf[Short])
case IntegerType => buffer.addLong(value.asInstanceOf[Int])
case LongType => buffer.addLong(value.asInstanceOf[Long])
case DateType => buffer.addLong(value.asInstanceOf[Int])
case TimestampType => buffer.addLong(value.asInstanceOf[Long])
case _ => buffer.add(value)
}
}
}
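Spark SQL hands the aggregate its internal representations — UTF8String for strings, Decimal for decimals, days-since-epoch Int for dates, and microseconds Long for timestamps — which is why update() converts before feeding the sketch. A small sketch of the two non-trivial conversions (assumes spark-catalyst and spark-unsafe on the classpath):

import java.nio.charset.StandardCharsets
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.UTF8String

val s = UTF8String.fromString("abc")
assert(new String(s.getBytes, StandardCharsets.UTF_8) == "abc")  // what addBinary receives

val dec = Decimal("3.14")
assert(dec.toJavaBigDecimal == new java.math.BigDecimal("3.14"))  // what add receives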
CountMinSketchAggSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets

import scala.reflect.ClassTag
import scala.util.Random
@@ -26,7 +27,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Cast, GenericInternalRow, Literal}
- import org.apache.spark.sql.types._
+ import org.apache.spark.sql.types.{DecimalType, _}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch

@@ -55,7 +56,7 @@ class CountMinSketchAggSuite extends SparkFunSuite {
dataType: DataType,
sampledItemIndices: Array[Int],
allItems: Array[T],
- exactFreq: Map[T, Long]): Any = {
+ exactFreq: Map[Any, Long]): Any = {
test(s"high level interface, update, merge, eval... - $dataType") {
val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
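For readers unfamiliar with typed imperative aggregates, the suite drives the aggregate through its buffer lifecycle. A compressed sketch of that flow — argument values are illustrative, and it assumes the suite's imports plus the TypedImperativeAggregate API of this era:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal}
import org.apache.spark.sql.types.IntegerType

val agg = new CountMinSketchAgg(BoundReference(0, IntegerType, nullable = true),
  Literal(0.01), Literal(0.95), Literal(42))
val buffer = agg.createAggregationBuffer()  // an empty CountMinSketch
agg.update(buffer, InternalRow(7))          // fold one input row into the sketch
val bytes = agg.serialize(buffer)           // Array[Byte], the same form eval() returns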
@@ -97,7 +98,7 @@ class CountMinSketchAggSuite extends SparkFunSuite {
dataType: DataType,
sampledItemIndices: Array[Int],
allItems: Array[T],
- exactFreq: Map[T, Long]): Any = {
+ exactFreq: Map[Any, Long]): Any = {
test(s"low level interface, update, merge, eval... - ${dataType.typeName}") {
val inputAggregationBufferOffset = 1
val mutableAggregationBufferOffset = 2
@@ -132,15 +133,19 @@
private def checkResult[T: ClassTag](
result: Any,
data: Array[T],
- exactFreq: Map[T, Long]): Unit = {
+ exactFreq: Map[Any, Long]): Unit = {
result match {
case bytesData: Array[Byte] =>
val in = new ByteArrayInputStream(bytesData)
val cms = CountMinSketch.readFrom(in)
val probCorrect = {
val numErrors = data.map { i =>
- val item = if (i.isInstanceOf[UTF8String]) i.toString else i
- val count = exactFreq.getOrElse(i, 0L)
+ val count = exactFreq.getOrElse(getProbeItem(i), 0L)
+ val item = i match {
+   case dec: Decimal => dec.toJavaBigDecimal
+   case str: UTF8String => str.getBytes
+   case _ => i
+ }
val ratio = (cms.estimateCount(item) - count).toDouble / data.length
if (ratio > epsOfTotalCount) 1 else 0
}.sum
@@ -156,6 +161,12 @@
}
}

private def getProbeItem[T: ClassTag](item: T): Any = item match {
// Use a string to represent the content of an array of bytes
case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
case other => other
}

def testItemType[T: ClassTag](dataType: DataType)(itemGenerator: Random => T): Unit = {
// Uses fixed seed to ensure reproducible test execution
val r = new Random(31)
@@ -168,7 +179,7 @@

val exactFreq = {
val sampledItems = sampledItemIndices.map(allItems)
- sampledItems.groupBy(identity).mapValues(_.length.toLong)
+ sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
}

testLowLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
@@ -185,6 +196,18 @@

testItemType[UTF8String](StringType) { r => UTF8String.fromString(r.nextString(r.nextInt(20))) }

testItemType[Float](FloatType) { _.nextFloat() }

testItemType[Double](DoubleType) { _.nextDouble() }

testItemType[Decimal](new DecimalType()) { r => Decimal(r.nextDouble()) }

testItemType[Boolean](BooleanType) { _.nextBoolean() }

testItemType[Array[Byte]](BinaryType) { r =>
r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
}


test("fails analysis if eps, confidence or seed provided is not a literal or constant foldable") {
val wrongEps = new CountMinSketchAgg(
CountMinSketchAggQuerySuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql

import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}

import scala.reflect.ClassTag
@@ -26,7 +27,8 @@ import scala.util.Random
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.test.SharedSQLContext
- import org.apache.spark.sql.types.{StringType, _}
+ import org.apache.spark.sql.types.{Decimal, StringType, _}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch

class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
@@ -64,6 +66,17 @@ class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
val (allTimestamps, sampledTSIndices, exactTSFreq) = generateTestData[Timestamp] { r =>
DateTimeUtils.toJavaTimestamp(r.nextLong() % (endTS - startTS) + startTS)
}
val (allFloats, sampledFloatIndices, exactFloatFreq) =
generateTestData[Float] { _.nextFloat() }
val (allDoubles, sampledDoubleIndices, exactDoubleFreq) =
generateTestData[Double] { _.nextDouble() }
val (allDecimals, sampledDecimalIndices, exactDecimalFreq) =
generateTestData[Decimal] { r => Decimal(r.nextDouble()) }
val (allBooleans, sampledBooleanIndices, exactBooleanFreq) =
generateTestData[Boolean] { _.nextBoolean() }
val (allBinaries, sampledBinaryIndices, exactBinaryFreq) = generateTestData[Array[Byte]] { r =>
r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
}

val data = (0 until numSamples).map { i =>
Row(allBytes(sampledByteIndices(i)),
@@ -72,7 +85,12 @@
allLongs(sampledLongIndices(i)),
allStrings(sampledStringIndices(i)),
allDates(sampledDateIndices(i)),
- allTimestamps(sampledTSIndices(i)))
+ allTimestamps(sampledTSIndices(i)),
+ allFloats(sampledFloatIndices(i)),
+ allDoubles(sampledDoubleIndices(i)),
+ allDecimals(sampledDecimalIndices(i)),
+ allBooleans(sampledBooleanIndices(i)),
+ allBinaries(sampledBinaryIndices(i)))
}

val schema = StructType(Seq(
@@ -82,18 +100,21 @@
StructField("c4", LongType),
StructField("c5", StringType),
StructField("c6", DateType),
StructField("c7", TimestampType)))
StructField("c7", TimestampType),
StructField("c8", FloatType),
StructField("c9", DoubleType),
StructField("c10", new DecimalType()),
StructField("c11", BooleanType),
StructField("c12", BinaryType)))

def cmsSql(col: String): String = s"count_min_sketch($col, $eps, $confidence, $seed)"

val query =
s"""
|SELECT
- |  count_min_sketch(c1, $eps, $confidence, $seed),
- |  count_min_sketch(c2, $eps, $confidence, $seed),
- |  count_min_sketch(c3, $eps, $confidence, $seed),
- |  count_min_sketch(c4, $eps, $confidence, $seed),
- |  count_min_sketch(c5, $eps, $confidence, $seed),
- |  count_min_sketch(c6, $eps, $confidence, $seed),
- |  count_min_sketch(c7, $eps, $confidence, $seed)
+ |  ${cmsSql("c1")}, ${cmsSql("c2")}, ${cmsSql("c3")}, ${cmsSql("c4")}, ${cmsSql("c5")},
+ |  ${cmsSql("c6")}, ${cmsSql("c7")}, ${cmsSql("c8")}, ${cmsSql("c9")}, ${cmsSql("c10")},
+ |  ${cmsSql("c11")}, ${cmsSql("c12")}
|FROM $table
""".stripMargin

@@ -114,11 +135,20 @@
case DateType =>
checkResult(cms,
allDates.map(DateTimeUtils.fromJavaDate),
- exactDateFreq.map(e => (DateTimeUtils.fromJavaDate(e._1), e._2)))
+ exactDateFreq.map { e =>
+   (DateTimeUtils.fromJavaDate(e._1.asInstanceOf[Date]), e._2)
+ })
case TimestampType =>
checkResult(cms,
allTimestamps.map(DateTimeUtils.fromJavaTimestamp),
- exactTSFreq.map(e => (DateTimeUtils.fromJavaTimestamp(e._1), e._2)))
+ exactTSFreq.map { e =>
+   (DateTimeUtils.fromJavaTimestamp(e._1.asInstanceOf[Timestamp]), e._2)
+ })
case FloatType => checkResult(cms, allFloats, exactFloatFreq)
case DoubleType => checkResult(cms, allDoubles, exactDoubleFreq)
case DecimalType() => checkResult(cms, allDecimals, exactDecimalFreq)
case BooleanType => checkResult(cms, allBooleans, exactBooleanFreq)
case BinaryType => checkResult(cms, allBinaries, exactBinaryFreq)
}
}
}
@@ -127,11 +157,16 @@
private def checkResult[T: ClassTag](
cms: CountMinSketch,
data: Array[T],
- exactFreq: Map[T, Long]): Unit = {
+ exactFreq: Map[Any, Long]): Unit = {
val probCorrect = {
val numErrors = data.map { i =>
- val count = exactFreq.getOrElse(i, 0L)
- val ratio = (cms.estimateCount(i) - count).toDouble / data.length
+ val count = exactFreq.getOrElse(getProbeItem(i), 0L)
+ val item = i match {
+   case dec: Decimal => dec.toJavaBigDecimal
+   case str: UTF8String => str.getBytes
+   case _ => i
+ }
+ val ratio = (cms.estimateCount(item) - count).toDouble / data.length
if (ratio > eps) 1 else 0
}.sum

@@ -144,13 +179,19 @@
)
}

private def getProbeItem[T: ClassTag](item: T): Any = item match {
// Use a string to represent the content of an array of bytes
case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
case other => other
}

private def generateTestData[T: ClassTag](
- itemGenerator: Random => T): (Array[T], Array[Int], Map[T, Long]) = {
+ itemGenerator: Random => T): (Array[T], Array[Int], Map[Any, Long]) = {
val allItems = Array.fill(numAllItems)(itemGenerator(r))
val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
val exactFreq = {
val sampledItems = sampledItemIndices.map(allItems)
- sampledItems.groupBy(identity).mapValues(_.length.toLong)
+ sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
}
(allItems, sampledItemIndices, exactFreq)
}
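generateTestData fixes the ground truth once per type: numAllItems candidate values, numSamples draws with replacement, and an exact frequency map keyed by probe item. An illustrative call (reusing the suite's Random and size constants):

val (allLongs, sampledIndices, exactLongFreq) = generateTestData[Long] { r => r.nextLong() }
assert(exactLongFreq.values.sum == numSamples)  // every draw is accounted for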