diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 9fd49d5c9bc9a..bc6fac30cb75d 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -506,7 +506,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Classification use LogLoss by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification()); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setNumClassesForClassification(2); boostingStrategy.getTreeStrategy().setMaxDepth(5); @@ -614,7 +614,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Regression use SquaredError by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression()); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setMaxDepth(5); // Empty categoricalFeaturesInfo indicates all features are continuous. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 49ceb83a54154..0203aff2b03ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -70,22 +70,22 @@ object BoostingStrategy { /** * Returns default configuration for the boosting algorithm - * @param algo Learning goal. Supported: + * @param algoStr Learning goal. Supported: * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) + def defaultParams(algoStr: String): BoostingStrategy = { + val treeStrategy = Strategy.defaultStrategy(algoStr) treeStrategy.maxDepth = 3 - algo match { + algoStr match { case "Classification" => treeStrategy.numClasses = 2 new BoostingStrategy(treeStrategy, LogLoss) case "Regression" => new BoostingStrategy(treeStrategy, SquaredError) case _ => - throw new IllegalArgumentException(s"$algo is not supported by boosting.") + throw new IllegalArgumentException(s"$algoStr is not supported by boosting.") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index bef066783082c..6345c041e7aed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -171,9 +171,9 @@ object Strategy { /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] - * @param algo "Classification" or "Regression" + * @param algoStr "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { + def defaultStrategy(algoStr: String): Strategy = algoStr match { case "Classification" => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2)