From af7cb7962ff9f5041981ea5e4fe2465eceb6f0e5 Mon Sep 17 00:00:00 2001 From: Qiping Li Date: Thu, 9 Oct 2014 19:47:09 +0800 Subject: [PATCH] Choose splits for continuous features in DecisionTree more adaptively --- .../spark/mllib/tree/DecisionTree.scala | 72 +++++++++++++++++-- .../tree/impl/DecisionTreeMetadata.scala | 10 +++ 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 03eeaa707715b..efee37858bbd3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.SparkContext._ +import scala.collection.mutable.ArrayBuffer /** @@ -912,16 +913,19 @@ object DecisionTree extends Serializable with Logging { val numSplits = metadata.numSplits(featureIndex) val numBins = metadata.numBins(featureIndex) if (metadata.isContinuous(featureIndex)) { - val numSamples = sampledInput.length + + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted + val featureSplits = findSplits(featureSamples, metadata.numSplits(featureIndex)) + metadata.setNumBinForFeature(featureIndex, metadata.numSplits(featureIndex)) + val numSplits = metadata.numSplits(featureIndex) + val numBins = metadata.numBins(featureIndex) + logDebug("numSplits= " + numSplits) + splits(featureIndex) = new Array[Split](numSplits) bins(featureIndex) = new Array[Bin](numBins) - val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) - logDebug("stride = " + stride) + for (splitIndex <- 0 until numSplits) { - val sampleIndex = splitIndex * stride.toInt - // Set threshold halfway in between 
2 samples. - val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 + val threshold = featureSplits(splitIndex) splits(featureIndex)(splitIndex) = new Split(featureIndex, threshold, Continuous, List()) } @@ -1011,4 +1015,58 @@ object DecisionTree extends Serializable with Logging { categories } + /** + * Find splits for a continuous feature. + * @param featureSamples feature values for each sample, sorted in ascending order + * @param numSplits the maximum number of splits to find + * @return an array of split thresholds (all distinct values if there are at most numSplits of them) + */ + private def findSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = { + /* + * Get count for each distinct value + */ + def getValueCount(arr: Array[Double]): Array[(Double, Int)] = { + val valueCount = new ArrayBuffer[(Double, Int)] + var index = 1 + var currentValue = arr(0) + var currentCount = 1 + while (index < arr.length) { + if (currentValue != arr(index)) { + valueCount.append((currentValue, currentCount)) + currentCount = 1 + currentValue = arr(index) + } else { + currentCount += 1 + } + index += 1 + } + valueCount.append((currentValue, currentCount)) + + valueCount.toArray + } + + val valueCount = getValueCount(featureSamples) + if (valueCount.length <= numSplits) { + return valueCount.map(_._1) + } + + val stride: Double = featureSamples.length.toDouble / (numSplits + 1) + logDebug("stride = " + stride) + + val splits = new ArrayBuffer[Double] + var index = 1 + var currentCount = valueCount(0)._2 + var expectedCount = stride + while (index < valueCount.length) { + if (math.abs(currentCount - expectedCount) < + math.abs(currentCount + valueCount(index)._2 - expectedCount)) { + splits.append(valueCount(index-1)._1) + expectedCount += stride + } + currentCount += valueCount(index)._2 + index += 1 + } + + splits.toArray + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 212dce25236e0..75c974be1446b 100644 ---
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -75,6 +75,16 @@ private[tree] class DecisionTreeMetadata( numBins(featureIndex) - 1 } + + /** + * Set the number of bins for the given continuous feature (continuous features only). + */ + def setNumBinForFeature(featureIndex: Int, numBin: Int) { + require(isContinuous(featureIndex), + s"Can only set number of bin for continuous feature.") + numBins(featureIndex) = numBin + } + /** * Indicates if feature subsampling is being used. */