Skip to content

Commit

Permalink
todo for multiclass support
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 12, 2014
1 parent 455bea9 commit 46f909c
Showing 1 changed file with 9 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ object DecisionTree extends Serializable with Logging {
// Update the left or right count for one bin.
val aggShift = 2 * numBins * numFeatures * nodeIndex
val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
// TODO: Multiclass modification here
label match {
case 0.0 => agg(aggIndex) = agg(aggIndex) + 1
case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
Expand Down Expand Up @@ -679,6 +680,7 @@ object DecisionTree extends Serializable with Logging {
topImpurity: Double): InformationGainStats = {
strategy.algo match {
case Classification =>
// TODO: Multiclass modification here
val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
val leftCount = left0Count + left1Count
Expand Down Expand Up @@ -779,6 +781,7 @@ object DecisionTree extends Serializable with Logging {
binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
strategy.algo match {
case Classification =>
// TODO: Multiclass modification here
// Initialize left and right split aggregates.
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
Expand Down Expand Up @@ -904,6 +907,8 @@ object DecisionTree extends Serializable with Logging {
binData: Array[Double],
nodeImpurity: Double): (Split, InformationGainStats) = {

// TODO: Multiclass modification here

logDebug("node impurity = " + nodeImpurity)

// Extract left right node aggregates.
Expand Down Expand Up @@ -948,6 +953,7 @@ object DecisionTree extends Serializable with Logging {
def getBinDataForNode(node: Int): Array[Double] = {
strategy.algo match {
case Classification =>
// TODO: Multiclass modification here
val shift = 2 * node * numBins * numFeatures
val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
binsForNode
Expand Down Expand Up @@ -997,6 +1003,8 @@ object DecisionTree extends Serializable with Logging {
val numBins = if (maxBins <= count) maxBins else count.toInt
logDebug("numBins = " + numBins)

// TODO: Multiclass modification here

/*
* TODO: Add a require statement ensuring the number of bins is always greater than the
* number of categories.
* It's a limitation of the current implementation but a reasonable trade-off since features
Expand Down Expand Up @@ -1041,6 +1049,7 @@ object DecisionTree extends Serializable with Logging {
splits(featureIndex)(index) = split
}
} else {
// TODO: Multiclass modification here
val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
require(maxFeatureValue < numBins, "number of categories should be less than number " +
"of bins")
Expand Down

0 comments on commit 46f909c

Please sign in to comment.