-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-18429] [SQL] implement a new Aggregate for CountMinSketch #15877
Conversation
cc @rxin |
cc @liancheng |
Test build #68604 has finished for PR 15877 at commit
|
retest this please |
Test build #68610 has finished for PR 15877 at commit
|
buffer.mergeInPlace(input) | ||
} | ||
|
||
override def eval(buffer: CountMinSketch): Any = new GenericArrayData(serialize(buffer)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this an array of bytes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is better just to return the byte array and to change the datatype into a BinaryType
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's better, thanks!
} | ||
|
||
override def checkInputDataTypes(): TypeCheckResult = { | ||
val defaultCheck = super.checkInputDataTypes() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need to check this (the super class does not implement this).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ExpectsInputTypes.checkInputDataTypes() checks validity of input types, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is fair.
} | ||
|
||
override def createAggregationBuffer(): CountMinSketch = { | ||
val eps: Double = epsExpression.eval().asInstanceOf[Double] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we cache this in lazy vals? I am not sure about the performance implications.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, i'll change them to lazy vals
// ignore empty rows | ||
if (value != null) { | ||
// UTF8String is a spark sql type, while CountMinSketch accepts String type | ||
buffer.add(if (value.isInstanceOf[UTF8String]) value.toString else value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How bad would it be to add support for UTF8 string to CMS? Or to pass the UTF8 byte array to CMS?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we should pass the byte array to CMS.
/** | ||
* This function returns a count-min sketch of a column with the given esp, confidence and seed. | ||
* A count-min sketch is a probabilistic data structure used for summarizing streams of data in | ||
* sub-linear space, which is useful for equality predicates and join size estimation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe something on the return type? A developer should know how to work with these bytes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, I wrote this in usage, I'll add it here too, thanks.
copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
||
override def inputTypes: Seq[AbstractDataType] = { | ||
// currently `CountMinSketch` supports integral and string types |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we expand this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rxin suggested that for unsupported types, we hash it before count min sketch, i.e. CountMinSketchAgg(hash(col)).
agg.merge(mergeBuffer, group1Buffer) | ||
agg.merge(mergeBuffer, group2Buffer) | ||
checkResult(agg.eval(mergeBuffer), allItems, exactFreq) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might also be a good place to test merging in a different order, and the merging of an empty partition.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I'll also test these.
data: Array[T], | ||
exactFreq: Map[T, Long]): Unit = { | ||
result match { | ||
case arrayData: ArrayData => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add case _ => fail("unexpected return type")
to have a nicer error when something goed wrong there
!seedExpression.foldable) { | ||
TypeCheckFailure( | ||
"The eps, confidence or seed provided must be a literal or constant foldable") | ||
} else if (epsExpression.eval() == null || confidenceExpression.eval() == null || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we also check for negative eps and confidence values?
confidenceExpression = Literal(confidence), | ||
seedExpression = Literal(seed)) | ||
val err = intercept[IllegalArgumentException] { | ||
invalidAgg.createAggregationBuffer() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my comment in the CMS agg. This is too late to throw such an error. I'd rather have driver side errors then executor side errors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we should have driver side errors, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks pretty good! I have left a few minor comments. Also consider to register this aggregate in the FunctionRegistry and to add it to functions.scala.
yes please register a count_min_sketch and alias cmsketch in FunctionRegistry. |
case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes) | ||
case ByteType => buffer.addLong(value.asInstanceOf[Byte]) | ||
case ShortType => buffer.addLong(value.asInstanceOf[Short]) | ||
case IntegerType => buffer.addLong(value.asInstanceOf[Int]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add DateType?
case ByteType => buffer.addLong(value.asInstanceOf[Byte]) | ||
case ShortType => buffer.addLong(value.asInstanceOf[Short]) | ||
case IntegerType => buffer.addLong(value.asInstanceOf[Int]) | ||
case LongType => buffer.addLong(value.asInstanceOf[Long]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add TimestampType?
val value = child.eval(input) | ||
// ignore empty rows | ||
if (value != null) { | ||
child.dataType match { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A general question: what is faster a pattern match at runtime or to use a virtual function here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
virtual function dispatch is usually a lot faster than pattern match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
although i don't know if it matters much here given we are going to run it through many hash functions.
Test build #68664 has finished for PR 15877 at commit
|
Test build #68666 has finished for PR 15877 at commit
|
@@ -261,6 +261,8 @@ object FunctionRegistry { | |||
expression[VarianceSamp]("var_samp"), | |||
expression[CollectList]("collect_list"), | |||
expression[CollectSet]("collect_set"), | |||
expression[CountMinSketchAgg]("count_min_sketch"), | |||
expression[CountMinSketchAgg]("cmsketch"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually i take my word back. let's add only count_min_sketch. I don't think it's worth having an alias given this is sketch is difficult to consume (returning some binary)
* @group agg_funcs | ||
* @since 2.2.0 | ||
*/ | ||
def count_min_sketch(e: Column, eps: Double, confidence: Double, seed: Int): Column = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's not add these for now.
Test build #68699 has finished for PR 15877 at commit
|
Test build #68703 has finished for PR 15877 at commit
|
Hi @hvanhovell @rxin, I've updated this pr, does it look good to you now? |
val value = child.eval(input) | ||
// Ignore empty rows | ||
if (value != null) { | ||
child.dataType match { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets not do a pattern match for every update. We should use an update function instead, for example:
private[this] val doUpdate: (CountMinSketch, Any) => Unit = child.dataType match {
case StringType => (cms, value) => cms.addBinary(value.asInstanceOf[UTF8String].getBytes)
case ByteType => (cms, value) => cms.addLong(value..asInstanceOf[Byte])
...
}
override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
val value = child.eval(input)
if (value != null) {
doUpdate(buffer, value)
}
}
// Currently `CountMinSketch` supports integral (date/timestamp is represented as int/long | ||
// internally) and string types. | ||
Seq(TypeCollection(IntegralType, StringType, DateType, TimestampType), | ||
DoubleType, DoubleType, IntegerType) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also add FloatType
(use Float.floatToIntBits), DoubleType
(use Double.doubleToLongBits), BooleanType
and BinaryType
? We could also add support for Decimal, but that would be a bit harder to get right.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rxin @hvanhovell If we really want to support all these types, is it better to move this conversion and pattern match logics into CountMinSketch
? That is, make cms support these types itself. Then, when users do queries e.g. on float type, they don't need to do conversions like cms.estimateCount(Float.floatToIntBits(value))
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes if we want to add support for those I think it'd make sense to do it in count-min sketch itself too.
Test build #68985 has finished for PR 15877 at commit
|
Test build #68994 has finished for PR 15877 at commit
|
cc @rxin @hvanhovell |
@@ -152,6 +153,16 @@ public void add(Object item) { | |||
public void add(Object item, long count) { | |||
if (item instanceof String) { | |||
addString((String) item, count); | |||
} else if (item instanceof BigDecimal) { | |||
addString(((BigDecimal) item).toString(), count); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I use string to represent decimal because there is a one-to-one mapping between BigDecimal and String.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this true?
"1.0" and "1.00" is the same value but not the same string representation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I didn't describe it accurately. It should be "There is a one-to-one mapping between the distinguishable values and the result of this conversion." (from java doc of BigDecimal)
Test build #69049 has finished for PR 15877 at commit
|
Test build #69047 has finished for PR 15877 at commit
|
Thanks - I'm going to merge this in master. I will submit a follow-up PR to simplify this a little bit, and remove the handling of float/double/decimal types and require explicit user action on how to turn that into long. |
Hey guys - after looking at the pr more, I'm afraid we have gone overboard with testing here. Most of the test cases written are just repeating each other and doing exactly the same thing. For testing something like this I'd probably just have some simple end-to-end test and then be done with it, because most of the complicated logics are isolated in the actual CountMinSketch implementation itself and already has good test coverage. |
assert(buffer.equals(agg.deserialize(agg.serialize(buffer)))) | ||
} | ||
|
||
def testHighLevelInterface[T: ClassTag]( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wzhfy can you comment on why we need to test both the high level interface and the low level interface?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just followed the style in ApproximatePercentileSuite
which is also a TypedImperativeAggregate
. I thought they are used to test different levels of operations for TypedImperativeAggregate
, e.g. update(buffer: InternalRow, input: InternalRow)
and def update(buffer: T, input: InternalRow)
.
## What changes were proposed in this pull request? This PR implements a new Aggregate to generate count min sketch, which is a wrapper of CountMinSketch. ## How was this patch tested? add test cases Author: wangzhenhua <[email protected]> Closes apache#15877 from wzhfy/cms.
## What changes were proposed in this pull request? This PR implements a new Aggregate to generate count min sketch, which is a wrapper of CountMinSketch. ## How was this patch tested? add test cases Author: wangzhenhua <[email protected]> Closes apache#15877 from wzhfy/cms.
## What changes were proposed in this pull request? This PR implements a new Aggregate to generate count min sketch, which is a wrapper of CountMinSketch. ## How was this patch tested? add test cases Author: wangzhenhua <[email protected]> Closes apache#15877 from wzhfy/cms.
What changes were proposed in this pull request?
This PR implements a new Aggregate to generate count min sketch, which is a wrapper of CountMinSketch.
How was this patch tested?
add test cases