diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 1881f85d0ff56..5c30c0a2f8cd7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -227,8 +227,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) fractionByKey: Map[K, Double], seed: Long = Utils.random.nextLong, exact: Boolean = true): RDD[(K, V)]= { + require(fractionByKey.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.") if (withReplacement) { - require(fractionByKey.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.") val counts = if (exact) Some(this.countByKey()) else None val samplingFunc = StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index 19c00384ca60f..0a4a3a7fc85ff 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -54,7 +54,8 @@ private[spark] object SamplingUtils { } else { val delta = 1e-4 val gamma = - math.log(delta) / total - math.min(1, math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))) + math.min(1, + math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))) } } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 5868440065877..4ac99a9dc6824 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { n: Long) = { val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) - val fractionByKey = (_:String) => samplingRate + val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate) val sample = stratifiedData.sampleByKey(false, fractionByKey, seed, exact) val sampleCounts = sample.countByKey() val takeSample = sample.collect() @@ -124,7 +124,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { n: Long) = { val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) - val fractionByKey = (_:String) => samplingRate + val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate) val sample = stratifiedData.sampleByKey(true, fractionByKey, seed, exact) val sampleCounts = sample.countByKey() val takeSample = sample.collect()