diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index dbde4a8db0011..be1ac1fc57049 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -827,8 +827,15 @@ abstract class RDD[T: ClassTag](
     jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
   }
 
-  def treeReduce(f: (T, T) => T, level: Int): T = {
-    require(level >= 1, s"Level must be greater than 1 but got $level.")
+  /**
+   * :: DeveloperApi ::
+   * Reduces the elements of this RDD in a tree pattern.
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  @DeveloperApi
+  def treeReduce(f: (T, T) => T, depth: Int): T = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     val cleanF = sc.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
       if (iter.hasNext) {
@@ -849,7 +856,7 @@ abstract class RDD[T: ClassTag](
         None
       }
     }
-    local.treeAggregate(Option.empty[T])(op, op, level)
+    local.treeAggregate(Option.empty[T])(op, op, depth)
       .getOrElse(throw new UnsupportedOperationException("empty collection"))
   }
 
@@ -888,12 +895,18 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
+  /**
+   * :: DeveloperApi ::
+   * Aggregates the elements of this RDD in a tree pattern.
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
   @DeveloperApi
   def treeAggregate[U: ClassTag](zeroValue: U)(
       seqOp: (U, T) => U,
       combOp: (U, U) => U,
-      level: Int): U = {
-    require(level >= 1, s"Level must be greater than 1 but got $level.")
+      depth: Int): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     if (this.partitions.size == 0) {
       return Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
     }
@@ -902,7 +915,7 @@ abstract class RDD[T: ClassTag](
     val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
     var local = this.mapPartitions(it => Iterator(aggregatePartition(it)))
     var numPartitions = local.partitions.size
-    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / level)).toInt, 2)
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
     while (numPartitions > scale + numPartitions / scale) {
       numPartitions /= scale
       local = local.mapPartitionsWithIndex { (i, iter) =>
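
Not part of the diff: a minimal usage sketch of the renamed API, assuming the signatures exactly as added above (explicit `depth` argument, no default) and a local `SparkContext`. The object name `TreeAggregateExample` and all values are illustrative, not from the patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object TreeAggregateExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("tree-aggregate-example").setMaster("local[4]")
    val sc = new SparkContext(conf)

    // Many partitions make the tree pattern worthwhile: with depth = 2,
    // scale = ceil(100^(1/2)) = 10, so the 100 partial results are first
    // combined into ~10 intermediate partitions before the final reduce.
    val rdd = sc.parallelize(1 to 1000, 100)

    // treeReduce: same contract as reduce, but partial results are combined
    // in a multi-level tree instead of all at once on the driver.
    val sum = rdd.treeReduce(_ + _, depth = 2)

    // treeAggregate: same contract as aggregate, using the new depth parameter.
    val (total, count) = rdd.treeAggregate((0, 0))(
      (acc, x) => (acc._1 + x, acc._2 + 1),
      (a, b) => (a._1 + b._1, a._2 + b._2),
      depth = 2)

    println(s"sum = $sum, mean = ${total.toDouble / count}")
    sc.stop()
  }
}
```

The `depth` rename reflects what the parameter actually controls: the suggested number of levels in the combine tree, which caps how many partition results any single task (or the driver) has to merge at once.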