Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chouqin committed Oct 13, 2014
1 parent f69f47f commit 092efcb
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1070,37 +1070,37 @@ object DecisionTree extends Serializable with Logging {
// just return all possible splits
val possibleSplits = valueCount.length
if (possibleSplits <= numSplits) {
return valueCount.map(_._1)
}

// stride between splits
val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
logDebug("stride = " + stride)

// iterate `valueCount` to find splits
val splits = new ArrayBuffer[Double]
var index = 1
// currentCount: sum of counts of values that have been visited
var currentCount = valueCount(0)._2
// expectedCount: expected value for `currentCount`.
// If `currentCount` is closest value to `expectedCount`,
// then current value is a split threshold.
// After finding a split threshold, `expectedCount` is added by stride.
var expectedCount = stride
while (index < valueCount.length) {
// If adding count of current value to currentCount
// makes currentCount less close to expectedCount,
// previous value is a split threshold.
if (math.abs(currentCount - expectedCount) <
math.abs(currentCount + valueCount(index)._2 - expectedCount)) {
splits.append(valueCount(index-1)._1)
expectedCount += stride
valueCount.map(_._1)
} else {
// stride between splits
val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
logDebug("stride = " + stride)

// iterate `valueCount` to find splits
val splits = new ArrayBuffer[Double]
var index = 1
// currentCount: sum of counts of values that have been visited
var currentCount = valueCount(0)._2
// expectedCount: the target value for `currentCount`.
// When `currentCount` is as close as it will get to `expectedCount`,
// the current value is a split threshold.
// After a split threshold is found, `expectedCount` is increased by `stride`.
var expectedCount = stride
while (index < valueCount.length) {
// If adding the count of the current value to `currentCount`
// would move `currentCount` farther from `expectedCount`,
// the previous value is a split threshold.
if (math.abs(currentCount - expectedCount) <
math.abs(currentCount + valueCount(index)._2 - expectedCount)) {
splits.append(valueCount(index-1)._1)
expectedCount += stride
}
currentCount += valueCount(index)._2
index += 1
}
currentCount += valueCount(index)._2
index += 1
}

splits.toArray
splits.toArray
}
}

assert(splits.length > 0)
Expand Down

0 comments on commit 092efcb

Please sign in to comment.