Skip to content

Commit

Permalink
adjust code based on comments and add more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
chouqin committed Oct 19, 2014
1 parent 9857039 commit ffc920f
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 57 deletions.
94 changes: 40 additions & 54 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer


import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
Expand All @@ -36,7 +38,6 @@ import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.SparkContext._
import scala.collection.mutable.ArrayBuffer


/**
Expand Down Expand Up @@ -910,34 +911,39 @@ object DecisionTree extends Serializable with Logging {
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
val numSplits = metadata.numSplits(featureIndex)
val numBins = metadata.numBins(featureIndex)
if (metadata.isContinuous(featureIndex)) {
val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
val featureSplits = findSplitsForContinuousFeature(featureSamples,
metadata, featureIndex)

val numSplits = featureSplits.length
val numBins = numSplits + 1
logDebug("numSplits= " + numSplits)
logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)

for (splitIndex <- 0 until numSplits) {
var splitIndex = 0
while (splitIndex < numSplits) {
val threshold = featureSplits(splitIndex)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, threshold, Continuous, List())
splitIndex += 1
}
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
for (splitIndex <- 1 until numSplits) {

splitIndex = 1
while (splitIndex < numSplits) {
bins(featureIndex)(splitIndex) =
new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
Continuous, Double.MinValue)
splitIndex += 1
}
bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
} else {
val numSplits = metadata.numSplits(featureIndex)
val numBins = metadata.numBins(featureIndex)
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
Expand Down Expand Up @@ -1018,9 +1024,13 @@ object DecisionTree extends Serializable with Logging {
* Find splits for a continuous feature
* NOTE: Returned number of splits is set based on `featureSamples` and
* may be different with `numSplits`.
* MetaData's number of splits will be set accordingly.
* NOTE: Returned number of splits is set based on `featureSamples` and
* could be different from the specified `numSplits`.
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
* @param featureSamples feature values of each sample
* @param metadata decision tree metadata
* NOTE: `metadata.numbins` will be changed accordingly
* if there are not enough splits to be found
* @param featureIndex feature index to find splits
* @return array of splits
*/
Expand All @@ -1029,48 +1039,22 @@ object DecisionTree extends Serializable with Logging {
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
s"findSplitsForContinuousFeature can only be used " +
s"to find splits for a continuous feature.")

/**
* Get count for each distinct value
*/
def getValueCount(arr: Array[Double]): Array[(Double, Int)] = {
val valueCount = new ArrayBuffer[(Double, Int)]
var index = 1
var currentValue = arr(0)
var currentCount = 1
while (index < arr.length) {
if (currentValue != arr(index)) {
valueCount.append((currentValue, currentCount))
currentCount = 1
currentValue = arr(index)
} else {
currentCount += 1
}
index += 1
}

valueCount.append((currentValue, currentCount))

valueCount.toArray
}

"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

val splits = {
val numSplits = metadata.numSplits(featureIndex)

// sort feature samples first
val sortedFeatureSamples = featureSamples.sorted

// get count for each distinct value
val valueCount = getValueCount(sortedFeatureSamples)
val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
m + ((x, m.getOrElse(x, 0) + 1))
}
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray

// if possible splits is not enough or just enough,
// just return all possible splits
val possibleSplits = valueCount.length
// if possible splits is not enough or just enough, just return all possible splits
val possibleSplits = valueCounts.length
if (possibleSplits <= numSplits) {
valueCount.map(_._1)
valueCounts.map(_._1)
} else {
// stride between splits
val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
Expand All @@ -1080,22 +1064,24 @@ object DecisionTree extends Serializable with Logging {
val splits = new ArrayBuffer[Double]
var index = 1
// currentCount: sum of counts of values that have been visited
var currentCount = valueCount(0)._2
// expectedCount: expected value for `currentCount`.
// If `currentCount` is closest value to `expectedCount`,
var currentCount = valueCounts(0)._2
// targetCount: target value for `currentCount`.
// If `currentCount` is closest value to `targetCount`,
// then current value is a split threshold.
// After finding a split threshold, `expectedCount` is added by stride.
var expectedCount = stride
while (index < valueCount.length) {
// After finding a split threshold, `targetCount` is added by stride.
var targetCount = stride
while (index < valueCounts.length) {
val previousCount = currentCount
currentCount += valueCounts(index)._2
val previousGap = math.abs(previousCount - targetCount)
val currentGap = math.abs(currentCount - targetCount)
// If adding count of current value to currentCount
// makes currentCount less close to expectedCount,
// makes the gap between currentCount and targetCount smaller,
// previous value is a split threshold.
if (math.abs(currentCount - expectedCount) <
math.abs(currentCount + valueCount(index)._2 - expectedCount)) {
splits.append(valueCount(index-1)._1)
expectedCount += stride
if (previousGap < currentGap) {
splits.append(valueCounts(index - 1)._1)
targetCount += stride
}
currentCount += valueCount(index)._2
index += 1
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,45 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Array(5), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 3)
assert(fakeMetadata.numSplits(0) === 3)
assert(fakeMetadata.numBins(0) === 4)
}

// find splits when most samples close to the minimum
{
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 2)
assert(fakeMetadata.numSplits(0) === 2)
assert(fakeMetadata.numBins(0) === 3)
assert(splits(0) === 2.0)
assert(splits(1) === 3.0)
}


// find splits when most samples close to the maximum
{
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 1)
assert(fakeMetadata.numSplits(0) === 1)
assert(fakeMetadata.numBins(0) === 2)
assert(splits(0) === 1.0)
}

}

test("Multiclass classification with unordered categorical features:" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
val numTrees = 1

val strategy = new Strategy(algo = Regression, impurity = Variance,
maxDepth = 2, maxBins = 10,
numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
categoricalFeaturesInfo = categoricalFeaturesInfo)

val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
Expand Down

0 comments on commit ffc920f

Please sign in to comment.