Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chouqin committed Oct 13, 2014
1 parent 2a8267a commit af6dc97
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -923,20 +923,25 @@ object DecisionTree extends Serializable with Logging {
splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)

for (splitIndex <- 0 until numSplits) {
val threshold = featureSplits(splitIndex)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, threshold, Continuous, List())
}
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
for (splitIndex <- 1 until numSplits) {
bins(featureIndex)(splitIndex) =
new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
Continuous, Double.MinValue)
if (numSplits == 0) {
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
new DummyLowSplit(featureIndex, Continuous), Continuous, Double.MinValue)
} else {
for (splitIndex <- 0 until numSplits) {
val threshold = featureSplits(splitIndex)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, threshold, Continuous, List())
}
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
for (splitIndex <- 1 until numSplits) {
bins(featureIndex)(splitIndex) =
new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
Continuous, Double.MinValue)
}
bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
}
bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
} else {
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
Expand Down Expand Up @@ -1050,8 +1055,7 @@ object DecisionTree extends Serializable with Logging {
}
index += 1
}
// last value is not put into valueCount
// because we should not use it as a split threshold

valueCount.append((currentValue, currentCount))

valueCount.toArray
Expand Down

0 comments on commit af6dc97

Please sign in to comment.