diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 694865b8595e7..fb0f7e9994c7c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
@@ -36,7 +38,6 @@ import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
 import org.apache.spark.SparkContext._
-import scala.collection.mutable.ArrayBuffer
 
 /**
@@ -910,8 +911,6 @@ object DecisionTree extends Serializable with Logging {
     // Iterate over all features.
     var featureIndex = 0
     while (featureIndex < numFeatures) {
-      val numSplits = metadata.numSplits(featureIndex)
-      val numBins = metadata.numBins(featureIndex)
       if (metadata.isContinuous(featureIndex)) {
         val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
         val featureSplits = findSplitsForContinuousFeature(featureSamples,
@@ -919,25 +918,32 @@
         val numSplits = featureSplits.length
         val numBins = numSplits + 1
-        logDebug("numSplits= " + numSplits)
+        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
         splits(featureIndex) = new Array[Split](numSplits)
         bins(featureIndex) = new Array[Bin](numBins)
-        for (splitIndex <- 0 until numSplits) {
+        var splitIndex = 0
+        while (splitIndex < numSplits) {
           val threshold = featureSplits(splitIndex)
           splits(featureIndex)(splitIndex) =
             new Split(featureIndex, threshold, Continuous, List())
+          splitIndex += 1
         }
         bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
           splits(featureIndex)(0), Continuous, Double.MinValue)
-        for (splitIndex <- 1 until numSplits) {
+
+        splitIndex = 1
+        while (splitIndex < numSplits) {
           bins(featureIndex)(splitIndex) =
             new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
               Continuous, Double.MinValue)
+          splitIndex += 1
         }
         bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
           new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
       } else {
+        val numSplits = metadata.numSplits(featureIndex)
+        val numBins = metadata.numBins(featureIndex)
         // Categorical feature
         val featureArity = metadata.featureArity(featureIndex)
         if (metadata.isUnordered(featureIndex)) {
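The two hunks above rebuild the splits and bins arrays for a continuous feature: numSplits thresholds always yield numSplits + 1 bins, with dummy splits closing off the lowest and highest bins. The following self-contained sketch illustrates that layout; Split, Bin, and buildBins here are simplified, hypothetical stand-ins, not the real org.apache.spark.mllib.tree.model classes.

// Minimal sketch of the continuous-feature bin layout built in the hunks above.
// Split and Bin are simplified stand-ins for the MLlib model classes.
object ContinuousBinLayoutSketch {
  final case class Split(featureIndex: Int, threshold: Double)
  final case class Bin(lowSplit: Split, highSplit: Split)

  def buildBins(featureIndex: Int, thresholds: Array[Double]): (Array[Split], Array[Bin]) = {
    require(thresholds.nonEmpty, "need at least one split threshold")
    val numSplits = thresholds.length
    val splits = thresholds.map(t => Split(featureIndex, t))
    val bins = new Array[Bin](numSplits + 1)
    // The lowest bin is bounded below by a dummy split at Double.MinValue.
    bins(0) = Bin(Split(featureIndex, Double.MinValue), splits(0))
    var splitIndex = 1
    while (splitIndex < numSplits) {
      bins(splitIndex) = Bin(splits(splitIndex - 1), splits(splitIndex))
      splitIndex += 1
    }
    // The highest bin is bounded above by a dummy split at Double.MaxValue.
    bins(numSplits) = Bin(splits(numSplits - 1), Split(featureIndex, Double.MaxValue))
    (splits, bins)
  }

  def main(args: Array[String]): Unit = {
    val (splits, bins) = buildBins(featureIndex = 0, thresholds = Array(1.0, 2.0, 3.0))
    assert(splits.length == 3 && bins.length == 4) // numBins == numSplits + 1
    bins.foreach(println)
  }
}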
@@ -1018,9 +1024,11 @@
    * Find splits for a continuous feature
-   * NOTE: Returned number of splits is set based on `featureSamples` and
-   * may be different with `numSplits`.
-   * MetaData's number of splits will be set accordingly.
+   * NOTE: Returned number of splits is set based on `featureSamples` and
+   * could be different from the specified `numSplits`.
+   * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
    * @param featureSamples feature values of each sample
    * @param metadata decision tree metadata
+   *                 NOTE: `metadata.numBins` will be changed accordingly
+   *                       if there are not enough splits to be found
    * @param featureIndex feature index to find splits
    * @return array of splits
    */
@@ -1029,48 +1037,22 @@
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
-      s"findSplitsForContinuousFeature can only be used " +
-      s"to find splits for a continuous feature.")
-
-    /**
-     * Get count for each distinct value
-     */
-    def getValueCount(arr: Array[Double]): Array[(Double, Int)] = {
-      val valueCount = new ArrayBuffer[(Double, Int)]
-      var index = 1
-      var currentValue = arr(0)
-      var currentCount = 1
-      while (index < arr.length) {
-        if (currentValue != arr(index)) {
-          valueCount.append((currentValue, currentCount))
-          currentCount = 1
-          currentValue = arr(index)
-        } else {
-          currentCount += 1
-        }
-        index += 1
-      }
-
-      valueCount.append((currentValue, currentCount))
-
-      valueCount.toArray
-    }
-
+      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
     val splits = {
       val numSplits = metadata.numSplits(featureIndex)
-      // sort feature samples first
-      val sortedFeatureSamples = featureSamples.sorted
-      // get count for each distinct value
-      val valueCount = getValueCount(sortedFeatureSamples)
+      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+        m + ((x, m.getOrElse(x, 0) + 1))
+      }
+      // sort distinct values
+      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
 
-      // if possible splits is not enough or just enough,
-      // just return all possible splits
-      val possibleSplits = valueCount.length
+      // if possible splits is not enough or just enough, just return all possible splits
+      val possibleSplits = valueCounts.length
       if (possibleSplits <= numSplits) {
-        valueCount.map(_._1)
+        valueCounts.map(_._1)
       } else {
         // stride between splits
         val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
@@ -1080,22 +1062,24 @@
         val splits = new ArrayBuffer[Double]
         var index = 1
         // currentCount: sum of counts of values that have been visited
-        var currentCount = valueCount(0)._2
-        // expectedCount: expected value for `currentCount`.
-        // If `currentCount` is closest value to `expectedCount`,
+        var currentCount = valueCounts(0)._2
+        // targetCount: target value for `currentCount`.
+        // If `currentCount` is closest value to `targetCount`,
         // then current value is a split threshold.
-        // After finding a split threshold, `expectedCount` is added by stride.
-        var expectedCount = stride
-        while (index < valueCount.length) {
+        // After finding a split threshold, `targetCount` is added by stride.
+        var targetCount = stride
+        while (index < valueCounts.length) {
+          val previousCount = currentCount
+          currentCount += valueCounts(index)._2
+          val previousGap = math.abs(previousCount - targetCount)
+          val currentGap = math.abs(currentCount - targetCount)
           // If adding count of current value to currentCount
-          // makes currentCount less close to expectedCount,
+          // makes the gap between currentCount and targetCount larger,
           // previous value is a split threshold.
-          if (math.abs(currentCount - expectedCount) <
-            math.abs(currentCount + valueCount(index)._2 - expectedCount)) {
-            splits.append(valueCount(index - 1)._1)
-            expectedCount += stride
+          if (previousGap < currentGap) {
+            splits.append(valueCounts(index - 1)._1)
+            targetCount += stride
           }
-          currentCount += valueCount(index)._2
           index += 1
         }
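The rewritten findSplitsForContinuousFeature chooses thresholds so that each bin receives roughly featureSamples.length / (numSplits + 1) samples, walking the sorted distinct values and emitting the previous value whenever the running count passes a target. Below is a minimal, runnable sketch of that stride-based selection outside Spark; findSplits and its numSplits parameter are simplified stand-ins for the patched method and metadata.numSplits(featureIndex).

import scala.collection.mutable.ArrayBuffer

object FindSplitsSketch {
  def findSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
    // Count occurrences of each distinct value, then sort the distinct values.
    val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
      m + ((x, m.getOrElse(x, 0) + 1))
    }
    val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray

    if (valueCounts.length <= numSplits) {
      // Not enough distinct values: every distinct value becomes a split.
      valueCounts.map(_._1)
    } else {
      val stride = featureSamples.length.toDouble / (numSplits + 1)
      val splits = ArrayBuffer.empty[Double]
      var index = 1
      var currentCount = valueCounts(0)._2
      var targetCount = stride
      while (index < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(index)._2
        // If including the current value moves the cumulative count further
        // from the target, the previous value is the split threshold.
        if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
          splits.append(valueCounts(index - 1)._1)
          targetCount += stride
        }
        index += 1
      }
      splits.toArray
    }
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the "most samples close to the minimum" test added below:
    // twelve 2.0s dominate, so 2.0 and 3.0 are chosen as thresholds.
    val samples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
    println(findSplits(samples, numSplits = 2).mkString(", ")) // 2.0, 3.0
  }
}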
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 9e4e3ba6cef62..47a567ace8bed 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -125,12 +125,44 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       Array(5), Gini, QuantileStrategy.Sort,
       0, 0, 0.0, 0, 0
     )
-    val featureSamples = Array(1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0)
+    val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
     val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
     assert(splits.length === 3)
     assert(fakeMetadata.numSplits(0) === 3)
     assert(fakeMetadata.numBins(0) === 4)
    }
+
+    // find splits when most samples close to the minimum
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 2)
+      assert(fakeMetadata.numSplits(0) === 2)
+      assert(fakeMetadata.numBins(0) === 3)
+      assert(splits(0) === 2.0)
+      assert(splits(1) === 3.0)
+    }
+
+    // find splits when most samples close to the maximum
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 1)
+      assert(fakeMetadata.numSplits(0) === 1)
+      assert(fakeMetadata.numBins(0) === 2)
+      assert(splits(0) === 1.0)
+    }
+
   }
 
   test("Multiclass classification with unordered categorical features:" +
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 7e88e01ace00d..6b13765b98f41 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -94,8 +94,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     val numTrees = 1
 
     val strategy = new Strategy(algo = Regression, impurity = Variance,
-      maxDepth = 2, maxBins = 10,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+      categoricalFeaturesInfo = categoricalFeaturesInfo)
     val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
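For context, this is roughly how the reformatted RandomForestSuite test drives the public API end to end. It is a hedged sketch assuming the Spark 1.x MLlib entry points exercised by this patch; the data set, master URL, and object name are illustrative only, and model.predict is the standard Vector-based prediction call rather than anything this patch adds.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Algo.Regression
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Variance

object RandomForestRegressorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("rf-sketch").setMaster("local[2]"))
    // Tiny synthetic regression data set, for illustration only.
    val rdd = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 1.0))))
    // Same named parameters as the reformatted Strategy call in the test above.
    val strategy = new Strategy(algo = Regression, impurity = Variance,
      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
      categoricalFeaturesInfo = Map.empty[Int, Int])
    val model = RandomForest.trainRegressor(rdd, strategy, numTrees = 1,
      featureSubsetStrategy = "auto", seed = 123)
    println(model.predict(Vectors.dense(1.0, 1.0)))
    sc.stop()
  }
}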