Skip to content

Commit

Permalink
Doc of Java updated. algo -> algoStr instead.
Browse files Browse the repository at this point in the history
  • Loading branch information
Peishen-Jia committed Jan 21, 2015
1 parent d5c8a2e commit 7c1e6ee
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions docs/mllib-ensembles.md
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ JavaRDD<LabeledPoint> testData = splits[1];

// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification);
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification());
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.
boostingStrategy.getTreeStrategy().setNumClassesForClassification(2);
boostingStrategy.getTreeStrategy().setMaxDepth(5);
Expand Down Expand Up @@ -614,7 +614,7 @@ JavaRDD<LabeledPoint> testData = splits[1];

// Train a GradientBoostedTrees model.
// The defaultParams for Regression use SquaredError by default.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression);
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression());
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.
boostingStrategy.getTreeStrategy().setMaxDepth(5);
// Empty categoricalFeaturesInfo indicates all features are continuous.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,22 @@ object BoostingStrategy {

/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
* @param algoStr Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: String): BoostingStrategy = {
val treeStrategy = Strategy.defaultStrategy(algo)
def defaultParams(algoStr: String): BoostingStrategy = {
val treeStrategy = Strategy.defaultStrategy(algoStr)
treeStrategy.maxDepth = 3
algo match {
algoStr match {
case "Classification" =>
treeStrategy.numClasses = 2
new BoostingStrategy(treeStrategy, LogLoss)
case "Regression" =>
new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
throw new IllegalArgumentException(s"$algoStr is not supported by boosting.")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ object Strategy {

/**
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo "Classification" or "Regression"
* @param algoStr "Classification" or "Regression"
*/
def defaultStrategy(algo: String): Strategy = algo match {
def defaultStrategy(algoStr: String): Strategy = algoStr match {
case "Classification" =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
Expand Down

0 comments on commit 7c1e6ee

Please sign in to comment.