diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 4a77d4adcd865..53d6482f8057c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree

 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.mutable.ArrayBuilder

 import org.apache.spark.Logging
 import org.apache.spark.annotation.{Experimental, Since}
@@ -643,8 +642,8 @@ object DecisionTree extends Serializable with Logging {
     val nodeToBestSplits =
       partitionAggregates.reduceByKey((a, b) => a.merge(b))
         .map { case (nodeIndex, aggStats) =>
-          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
-            Some(nodeToFeatures(nodeIndex))
+          val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+            nodeToFeatures(nodeIndex)
           }

           // find best split for each node
@@ -976,8 +975,8 @@ object DecisionTree extends Serializable with Logging {
     val numFeatures = metadata.numFeatures

     // Sample the input only if there are continuous features.
-    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
-    val sampledInput = if (hasContinuousFeatures) {
+    val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+    val sampledInput = if (continuousFeatures.nonEmpty) {
       // Calculate the number of samples for approximate quantile calculation.
       val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
       val fraction = if (requiredSamples < metadata.numExamples) {
@@ -986,81 +985,14 @@ object DecisionTree extends Serializable with Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
     } else {
-      new Array[LabeledPoint](0)
+      input.sparkContext.emptyRDD[LabeledPoint]
     }

     metadata.quantileStrategy match {
       case Sort =>
-        val splits = new Array[Array[Split]](numFeatures)
-        val bins = new Array[Array[Bin]](numFeatures)
-
-        // Find all splits.
-        // Iterate over all features.
-        var featureIndex = 0
-        while (featureIndex < numFeatures) {
-          if (metadata.isContinuous(featureIndex)) {
-            val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
-            val featureSplits = findSplitsForContinuousFeature(featureSamples,
-              metadata, featureIndex)
-
-            val numSplits = featureSplits.length
-            val numBins = numSplits + 1
-            logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
-            splits(featureIndex) = new Array[Split](numSplits)
-            bins(featureIndex) = new Array[Bin](numBins)
-
-            var splitIndex = 0
-            while (splitIndex < numSplits) {
-              val threshold = featureSplits(splitIndex)
-              splits(featureIndex)(splitIndex) =
-                new Split(featureIndex, threshold, Continuous, List())
-              splitIndex += 1
-            }
-            bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
-              splits(featureIndex)(0), Continuous, Double.MinValue)
-
-            splitIndex = 1
-            while (splitIndex < numSplits) {
-              bins(featureIndex)(splitIndex) =
-                new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
-                  Continuous, Double.MinValue)
-              splitIndex += 1
-            }
-            bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
-              new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
-          } else {
-            val numSplits = metadata.numSplits(featureIndex)
-            val numBins = metadata.numBins(featureIndex)
-            // Categorical feature
-            val featureArity = metadata.featureArity(featureIndex)
-            if (metadata.isUnordered(featureIndex)) {
-              // Unordered features
-              // 2^(maxFeatureValue - 1) - 1 combinations
-              splits(featureIndex) = new Array[Split](numSplits)
-              var splitIndex = 0
-              while (splitIndex < numSplits) {
-                val categories: List[Double] =
-                  extractMultiClassCategories(splitIndex + 1, featureArity)
-                splits(featureIndex)(splitIndex) =
-                  new Split(featureIndex, Double.MinValue, Categorical, categories)
-                splitIndex += 1
-              }
-            } else {
-              // Ordered features
-              // Bins correspond to feature values, so we do not need to compute splits or bins
-              // beforehand. Splits are constructed as needed during training.
-              splits(featureIndex) = new Array[Split](0)
-            }
-            // For ordered features, bins correspond to feature values.
-            // For unordered categorical features, there is no need to construct the bins.
-            // since there is a one-to-one correspondence between the splits and the bins.
-            bins(featureIndex) = new Array[Bin](0)
-          }
-          featureIndex += 1
-        }
-        (splits, bins)
+        findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
       case MinMax =>
         throw new UnsupportedOperationException("minmax not supported yet.")
       case ApproxHist =>
@@ -1068,6 +1000,82 @@ object DecisionTree extends Serializable with Logging {
     }
   }

+  private def findSplitsBinsBySorting(
+      input: RDD[LabeledPoint],
+      metadata: DecisionTreeMetadata,
+      continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = {
+    def findSplits(
+        featureIndex: Int,
+        featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
+      val splits = {
+        val featureSplits = findSplitsForContinuousFeature(
+          featureSamples.toArray,
+          metadata,
+          featureIndex)
+        logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
+
+        featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
+      }
+
+      val bins = {
+        val lowSplit = new DummyLowSplit(featureIndex, Continuous)
+        val highSplit = new DummyHighSplit(featureIndex, Continuous)
+
+        // Tack the dummy splits onto either side of the computed splits.
+        val allSplits = lowSplit +: splits.toSeq :+ highSplit
+
+        // Slide across the split points pairwise to allocate the bins.
+        allSplits.sliding(2).map {
+          case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
+        }.toArray
+      }
+
+      (featureIndex, (splits, bins))
+    }
+
+    val continuousSplits = {
+      // Reduce the parallelism for split computations when there are fewer
+      // continuous features than input partitions. This prevents tasks from
+      // being launched that would do no work.
+      val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+      input
+        .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
+        .groupByKey(numPartitions)
+        .map { case (k, v) => findSplits(k, v) }
+        .collectAsMap()
+    }
+
+    val numFeatures = metadata.numFeatures
+    val (splits, bins) = Range(0, numFeatures).unzip {
+      case i if metadata.isContinuous(i) =>
+        val (split, bin) = continuousSplits(i)
+        metadata.setNumSplits(i, split.length)
+        (split, bin)
+
+      case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+        // Unordered features
+        // 2^(maxFeatureValue - 1) - 1 combinations
+        val featureArity = metadata.featureArity(i)
+        val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
+          val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+          new Split(i, Double.MinValue, Categorical, categories)
+        }
+
+        // For unordered categorical features, there is no need to construct the bins,
+        // since there is a one-to-one correspondence between the splits and the bins.
+        (split.toArray, Array.empty[Bin])
+
+      case i if metadata.isCategorical(i) =>
+        // Ordered features
+        // Bins correspond to feature values, so we do not need to compute splits or bins
+        // beforehand. Splits are constructed as needed during training.
+        (Array.empty[Split], Array.empty[Bin])
+    }
+
+    (splits.toArray, bins.toArray)
+  }
+
   /**
    * Nested method to extract list of eligible categories given an index. It extracts the
    * position of ones in a binary representation of the input. If binary
@@ -1131,7 +1139,7 @@ object DecisionTree extends Serializable with Logging {
       logDebug("stride = " + stride)

       // iterate `valueCount` to find splits
-      val splitsBuilder = ArrayBuilder.make[Double]
+      val splitsBuilder = Array.newBuilder[Double]
       var index = 1
       // currentCount: sum of counts of values that have been visited
       var currentCount = valueCounts(0)._2
@@ -1163,8 +1171,8 @@ object DecisionTree extends Serializable with Logging {
     assert(splits.length > 0,
       s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
         " Please remove this feature and then try again.")
-    // set number of splits accordingly
-    metadata.setNumSplits(featureIndex, splits.length)
+
+    // The split metadata must be updated on the driver.

     splits
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
index 0abed5411143d..1c611976a9308 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -108,21 +108,21 @@ private[spark] class NodeIdCache(

     prevNodeIdsForInstances = nodeIdsForInstances
     nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
-      dataPoint => {
+      case (point, node) => {
         var treeId = 0
         while (treeId < nodeIdUpdaters.length) {
-          val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
+          val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null)
           if (nodeIdUpdater != null) {
             val newNodeIndex = nodeIdUpdater.updateNodeIndex(
-              binnedFeatures = dataPoint._1.datum.binnedFeatures,
+              binnedFeatures = point.datum.binnedFeatures,
               bins = bins)
-            dataPoint._2(treeId) = newNodeIndex
+            node(treeId) = newNodeIndex
           }

           treeId += 1
         }

-        dataPoint._2
+        node
       }
     }

@@ -138,7 +138,7 @@ private[spark] class NodeIdCache(
     while (checkpointQueue.size > 1 && canDelete) {
       // We can delete the oldest checkpoint iff
       // the next checkpoint actually exists in the file system.
-      if (checkpointQueue.get(1).get.getCheckpointFile != None) {
+      if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
         val old = checkpointQueue.dequeue()

         // Since the old checkpoint is not deleted by Spark,
@@ -159,11 +159,11 @@ private[spark] class NodeIdCache(
    * Call this after training is finished to delete any remaining checkpoints.
   */
   def deleteAllCheckpoints(): Unit = {
-    while (checkpointQueue.size > 0) {
+    while (checkpointQueue.nonEmpty) {
       val old = checkpointQueue.dequeue()
-      if (old.getCheckpointFile != None) {
+      for (checkpointFile <- old.getCheckpointFile) {
         val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
-        fs.delete(new Path(old.getCheckpointFile.get), true)
+        fs.delete(new Path(checkpointFile), true)
       }
     }
     if (prevNodeIdsForInstances != null) {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 356d957f15909..1a4299db4eab2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -135,8 +135,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
       val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
       val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
       assert(splits.length === 3)
-      assert(fakeMetadata.numSplits(0) === 3)
-      assert(fakeMetadata.numBins(0) === 4)
       // check returned splits are distinct
       assert(splits.distinct.length === splits.length)
     }
@@ -151,8 +149,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
       val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
       val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
       assert(splits.length === 2)
-      assert(fakeMetadata.numSplits(0) === 2)
-      assert(fakeMetadata.numBins(0) === 3)
       assert(splits(0) === 2.0)
       assert(splits(1) === 3.0)
     }
@@ -167,8 +163,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
       val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
       val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
       assert(splits.length === 1)
-      assert(fakeMetadata.numSplits(0) === 1)
-      assert(fakeMetadata.numBins(0) === 2)
       assert(splits(0) === 1.0)
     }
   }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index 334bf3790fc7a..3d3f80063f904 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -69,8 +69,8 @@ object EnsembleTestHelper {
       required: Double,
       metricName: String = "mse") {
     val predictions = input.map(x => model.predict(x.features))
-    val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
-      label - prediction
+    val errors = predictions.zip(input).map { case (prediction, point) =>
+      point.label - prediction
     }
     val metric = metricName match {
       case "mse" =>
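
Reviewer note (not part of the patch): the new findSplitsBinsBySorting builds the bins for a
continuous feature by padding the computed thresholds with a dummy low/high split and sliding
over the result pairwise, so N thresholds always yield N + 1 bins. Below is a minimal,
self-contained sketch of that idea; Split and Bin here are simplified stand-ins for the MLlib
classes, not the real model types.

object BinSketch {
  // Simplified stand-ins for org.apache.spark.mllib.tree.model.{Split, Bin}.
  final case class Split(featureIndex: Int, threshold: Double)
  final case class Bin(lowSplit: Split, highSplit: Split)

  def binsFromThresholds(featureIndex: Int, thresholds: Array[Double]): Array[Bin] = {
    val splits = thresholds.map(t => Split(featureIndex, t))
    // Dummy sentinels play the role of DummyLowSplit / DummyHighSplit.
    val lowSplit = Split(featureIndex, Double.MinValue)
    val highSplit = Split(featureIndex, Double.MaxValue)
    val allSplits = lowSplit +: splits.toSeq :+ highSplit
    // sliding(2) yields (low, t1), (t1, t2), ..., (tN, high): one bin per adjacent pair.
    allSplits.sliding(2).map { case Seq(left, right) => Bin(left, right) }.toArray
  }

  def main(args: Array[String]): Unit = {
    // Three thresholds produce four bins.
    binsFromThresholds(0, Array(1.5, 3.0, 7.25)).foreach(println)
  }
}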
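
Reviewer note (not part of the patch): the split computation itself is distributed by keying each
sampled continuous feature value by its feature index and grouping, with parallelism capped at the
number of continuous features so that no empty tasks are scheduled. A minimal sketch of that
shape follows; computeThresholds is a hypothetical stand-in for findSplitsForContinuousFeature and
is passed in rather than taken from MLlib.

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

object SplitGroupingSketch {
  def splitsPerFeature(
      input: RDD[LabeledPoint],
      continuousFeatures: Seq[Int],
      computeThresholds: (Int, Array[Double]) => Array[Double]): Map[Int, Array[Double]] = {
    // Cap the number of partitions so every task has at least one feature to process.
    val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
    input
      // One (featureIndex, value) pair per continuous feature per point.
      .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
      // All sampled values of a given feature land in a single group.
      .groupByKey(numPartitions)
      .map { case (idx, values) => (idx, computeThresholds(idx, values.toArray)) }
      .collectAsMap()
      .toMap
  }
}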