Skip to content

Commit

Permalink
fix inputTypes
Browse files (browse the repository at this point in the history)
  • Loading branch information
wzhfy committed Nov 22, 2016
1 parent ca4a13f commit b009ff8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ case class CountMinSketchAgg(
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def inputTypes: Seq[AbstractDataType] = {
- // Currently `CountMinSketch` supports integral (date/timestamp is represented as int/long
- // internally) and string types.
- Seq(TypeCollection(IntegralType, StringType, DateType, TimestampType),
+ Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType),
DoubleType, DoubleType, IntegerType)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
private val numAllItems = 500000
private val numSamples = numAllItems / 10

- private val eps = 0.0001
- private val confidence = 0.95
+ private val eps = 0.0001D
+ private val confidence = 0.95D
private val seed = 11

val startDate = DateTimeUtils.fromJavaDate(Date.valueOf("1900-01-01"))
Expand Down Expand Up @@ -107,23 +107,14 @@ class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
StructField("c11", BooleanType),
StructField("c12", BinaryType)))

- def cmsSql(col: String): String = s"count_min_sketch($col, $eps, $confidence, $seed)"
-
- val query =
- s"""
- |SELECT
- | ${cmsSql("c1")}, ${cmsSql("c2")}, ${cmsSql("c3")}, ${cmsSql("c4")}, ${cmsSql("c5")},
- | ${cmsSql("c6")}, ${cmsSql("c7")}, ${cmsSql("c8")}, ${cmsSql("c9")}, ${cmsSql("c10")},
- | ${cmsSql("c11")}, ${cmsSql("c12")}
- |FROM $table
- """.stripMargin

withTempView(table) {
val rdd: RDD[Row] = spark.sparkContext.parallelize(data)
spark.createDataFrame(rdd, schema).createOrReplaceTempView(table)
- val result = sql(query).queryExecution.toRdd.collect().head
+ val cmsSql = schema.fieldNames.map(col => s"count_min_sketch($col, $eps, $confidence, $seed)")
+ .mkString(", ")
+ val result = sql(s"SELECT $cmsSql FROM $table").head()
schema.indices.foreach { i =>
- val binaryData = result.getBinary(i)
+ val binaryData = result.getAs[Array[Byte]](i)
val in = new ByteArrayInputStream(binaryData)
val cms = CountMinSketch.readFrom(in)
schema.fields(i).dataType match {
Expand Down

0 comments on commit b009ff8

Please sign in to comment.