Skip to content

Commit

Permalink
[SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.…
Browse files Browse the repository at this point in the history
…Classification or Algo.Regression

JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5317
When setting the BoostingStrategy.defaultParams("Classification"), It's more straightforward to set it with the Enumeration Algo.Classification, just like BoostingStragety.defaultParams(Algo.Classification).
I overload the method BoostingStragety.defaultParams().

Author: Basin <[email protected]>

Closes apache#4103 from Peishen-Jia/stragetyAlgo and squashes the following commits:

87bab1c [Basin] Docs and Code documentations updated.
3b72875 [Basin] defaultParams(algoStr: String) call defaultParams(algo: Algo).
7c1e6ee [Basin] Doc of Java updated. algo -> algoStr instead.
d5c8a2e [Basin] Merge branch 'stragetyAlgo' of github.com:Peishen-Jia/spark into stragetyAlgo
65f96ce [Basin] mllib-ensembles doc modified.
e04a5aa [Basin] boostingstrategy.defaultParam string algo to enumeration.
68cf544 [Basin] mllib-ensembles doc modified.
a4aea51 [Basin] boostingstrategy.defaultParam string algo to enumeration.
  • Loading branch information
Peishen-Jia authored and mengxr committed Jan 22, 2015
1 parent ca7910d commit fcb3e18
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,31 @@ case class BoostingStrategy(
@Experimental
object BoostingStrategy {

/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported: "Classification" or "Regression"
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: String): BoostingStrategy = {
defaultParams(Algo.fromString(algo))
}

/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: String): BoostingStrategy = {
val treeStrategy = Strategy.defaultStrategy(algo)
treeStrategy.maxDepth = 3
def defaultParams(algo: Algo): BoostingStrategy = {
val treeStragtegy = Strategy.defaultStategy(algo)
treeStragtegy.maxDepth = 3
algo match {
case "Classification" =>
treeStrategy.numClasses = 2
new BoostingStrategy(treeStrategy, LogLoss)
case "Regression" =>
new BoostingStrategy(treeStrategy, SquaredError)
case Algo.Classification =>
treeStragtegy.numClasses = 2
new BoostingStrategy(treeStragtegy, LogLoss)
case Algo.Regression =>
new BoostingStrategy(treeStragtegy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,19 @@ object Strategy {
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo "Classification" or "Regression"
*/
def defaultStrategy(algo: String): Strategy = algo match {
case "Classification" =>
def defaultStrategy(algo: String): Strategy = {
defaultStategy(Algo.fromString(algo))
}

/**
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo Algo.Classification or Algo.Regression
*/
def defaultStategy(algo: Algo): Strategy = algo match {
case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
case "Regression" =>
case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
Expand Down

0 comments on commit fcb3e18

Please sign in to comment.