Skip to content

Commit

Permalink
boostingstrategy.defaultParam string algo to enumeration.
Browse files Browse the repository at this point in the history
  • Loading branch information
Peishen-Jia committed Jan 19, 2015
1 parent 3453d57 commit e04a5aa
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,25 @@ object BoostingStrategy {
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
}

/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: Algo): BoostingStrategy = {
val treeStragtegy = Strategy.defaultStategy(algo)
treeStragtegy.maxDepth = 3
algo match {
case Algo.Classification =>
treeStragtegy.numClasses = 2
new BoostingStrategy(treeStragtegy, LogLoss)
case Algo.Regression =>
new BoostingStrategy(treeStragtegy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,17 @@ object Strategy {
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}

/**
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo Algo.Classification or Algo.Regression
*/
def defaultStategy(algo: Algo): Strategy = algo match {
case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
}

0 comments on commit e04a5aa

Please sign in to comment.