diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 952f03f10e538..a5a4e61049ccf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -60,7 +60,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -85,11 +85,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array - val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1) + val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) // nodes at a level is 2^(level-1). level is zero indexed. - val maxLevelForSingleGroup = scala.math.max( - (scala.math.log(maxNumberOfNodesPerGroup) / scala.math.log(2)).floor.toInt - 1, 0) + val maxLevelForSingleGroup = math.max( + (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt - 1, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) /* @@ -120,7 +120,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo filters) logDebug("final best split = " + nodeSplitStats._1) } - require(scala.math.pow(2, level) == splitsStatsForLevel.length) + require(math.pow(2, level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -153,7 +153,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val nodeIndex = math.pow(2, level).toInt - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) @@ -174,7 +174,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var i = 0 while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i + val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity @@ -300,7 +300,7 @@ object DecisionTree extends Serializable with Logging { maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { - val numGroups = scala.math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt logDebug("numGroups = " + numGroups) var groupIndex = 0 var bestSplits = new Array[(Split, InformationGainStats)](0) @@ -366,7 +366,7 @@ object DecisionTree extends Serializable with Logging { */ // common calculations for multiple nested methods - val numNodes = scala.math.pow(2, level).toInt / numGroups + val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. val numFeatures = input.first().features.size @@ -382,7 +382,7 @@ object DecisionTree extends Serializable with Logging { if (level == 0) { List[Filter]() } else { - val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + groupShift + val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift filters(nodeFilterIndex) } } @@ -951,7 +951,7 @@ object DecisionTree extends Serializable with Logging { // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + groupShift + val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex)