diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 499fa2dff7be6..23cf6bce6dcdc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -143,9 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.start("extractNodeInfo") val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val predict = nodeSplitStats._3 + val predict = nodeSplitStats._3.predict val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) + val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node timer.stop("extractNodeInfo") @@ -735,14 +735,13 @@ object DecisionTree extends Serializable with Logging { topImpurity: Double, level: Int, metadata: DecisionTreeMetadata): InformationGainStats = { - val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count // If left child or right child doesn't satisfy minimum instances per node, // then this split is invalid, return invalid information gain stats if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { + (rightCount < metadata.minInstancesPerNode)) { return InformationGainStats.invalidInformationGainStats }