-
Notifications
You must be signed in to change notification settings - Fork 28.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Modified the takeSample method in RDD to use the ScaSRS sampling technique to improve performance. Added a private method that computes sampling rate > sample_size/total to ensure sufficient sample size with success rate >= 0.9999. Added a unit test for the private method to validate choice of sampling rate. Author: Doris Xin <[email protected]> Author: dorx <[email protected]> Author: Xiangrui Meng <[email protected]> Closes #916 from dorx/takeSample and squashes the following commits: 5b061ae [Doris Xin] merge master 444e750 [Doris Xin] edge cases 3de882b [dorx] Merge pull request #2 from mengxr/SPARK-1939 82dde31 [Xiangrui Meng] update pyspark's takeSample 48d954d [Doris Xin] remove unused imports from RDDSuite fb1452f [Doris Xin] allowing num to be greater than count in all cases 1481b01 [Doris Xin] washing test tubes and making coffee dc699f3 [Doris Xin] give back imports removed by accident in rdd.py 64e445b [Doris Xin] logwarnning as soon as it enters the while loop 55518ed [Doris Xin] added TODO for logging in rdd.py eff89e2 [Doris Xin] addressed reviewer comments. ecab508 [Doris Xin] "fixed checkstyle violation 0a9b3e3 [Doris Xin] "reviewer comment addressed" f80f270 [Doris Xin] Merge branch 'master' into takeSample ae3ad04 [Doris Xin] fixed edge cases to prevent overflow 065ebcd [Doris Xin] Merge branch 'master' into takeSample 9bdd36e [Doris Xin] Check sample size and move computeFraction e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample 7cab53a [Doris Xin] fixed import bug in rdd.py ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD 1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
- Loading branch information
Showing
8 changed files
with
263 additions
and
100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.util.random | ||
|
||
private[spark] object SamplingUtils { | ||
|
||
/** | ||
* Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of | ||
* the time. | ||
* | ||
* How the sampling rate is determined: | ||
* Let p = num / total, where num is the sample size and total is the total number of | ||
* datapoints in the RDD. We're trying to compute q > p such that | ||
* - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), | ||
* where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total), | ||
* i.e. the failure rate of not having a sufficiently large sample < 0.0001. | ||
* Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for | ||
* num > 12, but we need a slightly larger q (9 empirically determined). | ||
* - when sampling without replacement, we're drawing each datapoint with prob_i | ||
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success | ||
* rate, where success rate is defined the same as in sampling with replacement. | ||
* | ||
* @param sampleSizeLowerBound sample size | ||
* @param total size of RDD | ||
* @param withReplacement whether sampling with replacement | ||
* @return a sampling rate that guarantees sufficient sample size with 99.99% success rate | ||
*/ | ||
def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, | ||
withReplacement: Boolean): Double = { | ||
val fraction = sampleSizeLowerBound.toDouble / total | ||
if (withReplacement) { | ||
val numStDev = if (sampleSizeLowerBound < 12) 9 else 5 | ||
fraction + numStDev * math.sqrt(fraction / total) | ||
} else { | ||
val delta = 1e-4 | ||
val gamma = - math.log(delta) / total | ||
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.util.random | ||
|
||
import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} | ||
import org.scalatest.FunSuite | ||
|
||
class SamplingUtilsSuite extends FunSuite { | ||
|
||
test("computeFraction") { | ||
// test that the computed fraction guarantees enough data points | ||
// in the sample with a failure rate <= 0.0001 | ||
val n = 100000 | ||
|
||
for (s <- 1 to 15) { | ||
val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) | ||
val poisson = new PoissonDistribution(frac * n) | ||
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") | ||
} | ||
for (s <- List(20, 100, 1000)) { | ||
val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) | ||
val poisson = new PoissonDistribution(frac * n) | ||
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") | ||
} | ||
for (s <- List(1, 10, 100, 1000)) { | ||
val frac = SamplingUtils.computeFractionForSampleSize(s, n, false) | ||
val binomial = new BinomialDistribution(n, frac) | ||
assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
1de1d70
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The added commons-math3 into core/pom.xml should give it a version, otherwise, compiling in my machine failed.
./pom.xml
<math3.version>3.3</math3.version>
core/pom.xml
org.apache.commons
commons-math3
${math3.version}
test
1de1d70
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Qiuzhuang Thanks! It should be fixed now.