Skip to content

Commit

Permalink
Small cleanups after original tree API PR
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Apr 17, 2015
1 parent a83571a commit bb9f610
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -70,7 +70,7 @@ object DecisionTreeExample {
val parser = new OptionParser[Params]("DecisionTreeExample") {
head("DecisionTreeExample: an example decision tree app.")
opt[String]("algo")
- .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
.action((x, c) => c.copy(algo = x))
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
Expand Down Expand Up @@ -221,18 +221,23 @@ object DecisionTreeExample {
// (1) For classification, re-index classes.
val labelColName = if (algo == "classification") "indexedLabel" else "label"
if (algo == "classification") {
- val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
stages += labelIndexer
}
// (2) Identify categorical features using VectorIndexer.
// Features with more than maxCategories values will be treated as continuous.
- val featuresIndexer = new VectorIndexer().setInputCol("features")
- .setOutputCol("indexedFeatures").setMaxCategories(10)
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
stages += featuresIndexer
// (3) Learn DecisionTree
val dt = algo match {
case "classification" =>
- new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+ new DecisionTreeClassifier()
+ .setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
Expand All @@ -241,7 +246,8 @@ object DecisionTreeExample {
.setCacheNodeIds(params.cacheNodeIds)
.setCheckpointInterval(params.checkpointInterval)
case "regression" =>
- new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+ new DecisionTreeRegressor()
+ .setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params {
def getImpurity: String = getOrDefault(impurity)

/** Convert new impurity to old impurity. */
- protected def getOldImpurity: OldImpurity = {
+ private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
case "variance" => OldVariance
case _ =>
Expand Down
7 changes: 4 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,7 +38,7 @@ sealed trait Split extends Serializable {
private[tree] def toOld: OldSplit
}

- private[ml] object Split {
+ private[tree] object Split {

def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
oldSplit.featureType match {
Expand All @@ -58,7 +58,7 @@ private[ml] object Split {
* left. Otherwise, it goes right.
* @param numCategories Number of categories for this feature.
*/
- final class CategoricalSplit(
+ final class CategoricalSplit private[ml] (
override val featureIndex: Int,
leftCategories: Array[Double],
private val numCategories: Int)
Expand Down Expand Up @@ -130,7 +130,8 @@ final class CategoricalSplit(
* @param threshold If the feature value is <= this threshold, then the split goes left.
* Otherwise, it goes right.
*/
- final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+ final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
+ extends Split {

override private[ml] def shouldGoLeft(features: Vector): Boolean = {
features(featureIndex) <= threshold
Expand Down

0 comments on commit bb9f610

Please sign in to comment.