diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index e55db0ab0deae..e4f77e890b67c 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -162,47 +162,64 @@ val labelAndPreds = data.map { point =>
 }
 val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count
 println("Training Error = " + trainErr)
+println("Learned classification tree model:\n" + model)
 {% endhighlight %}
 </div>
 
 <div data-lang="java" markdown="1">
 {% highlight java %}
+import java.util.HashMap;
 import scala.Tuple2;
+import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.mllib.tree.DecisionTree;
 import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-
-JavaRDD<LabeledPoint> data = ... // data set
-
-// Train a DecisionTree model.
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+
+SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+String datapath = "data/mllib/sample_libsvm_data.txt";
+JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
+// Compute the number of classes from the data.
+Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+  @Override public Double call(LabeledPoint p) {
+    return p.label();
+  }
+}).countByValue().size();
+
+// Set parameters.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
-Integer numClasses = ... // number of classes
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "gini";
 Integer maxDepth = 5;
 Integer maxBins = 100;
 
+// Train a DecisionTree model for classification.
 final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
   categoricalFeaturesInfo, impurity, maxDepth, maxBins);
 
 // Evaluate model on training instances and compute training error
-JavaPairRDD predictionAndLabel =
+JavaPairRDD<Double, Double> predictionAndLabel =
   data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
     @Override public Tuple2<Double, Double> call(LabeledPoint p) {
       return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
     }
   });
-Double trainErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+Double trainErr =
+  1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
    @Override public Boolean call(Tuple2<Double, Double> pl) {
      return !pl._1().equals(pl._2());
    }
  }).count() / data.count();
-System.out.print("Training error: " + trainErr);
-System.out.print("Learned model:\n" + model);
+System.out.println("Training error: " + trainErr);
+System.out.println("Learned classification tree model:\n" + model);
 {% endhighlight %}
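Reviewer note: the Scala tab in this hunk presupposes that `data` and `model` were built earlier on the page. For readers landing directly here, a minimal self-contained sketch of the same classification pipeline, assuming a spark-shell with `sc` in scope and the same data file and parameters as the Java tab above:

```scala
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load LIBSVM-format data; each record becomes a LabeledPoint.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

// Infer the number of classes from the distinct labels, mirroring the Java
// countByValue() trick; assumes labels are exactly 0.0 .. numClasses - 1,
// which is what trainClassifier expects.
val numClasses = data.map(_.label).countByValue().size

// Empty categoricalFeaturesInfo: all features are treated as continuous.
val model = DecisionTree.trainClassifier(data, numClasses, Map[Int, Int](),
  "gini", 5, 100)
```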
@@ -225,6 +242,8 @@ predictions = model.predict(data.map(lambda x: x.features))
 labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
 trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count())
 print('Training Error = ' + str(trainErr))
+print('Learned classification tree model:')
+print(model)
 {% endhighlight %}
 </div>
 </div>
 
 Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
@@ -268,47 +287,63 @@ val labelsAndPredictions = data.map { point =>
 }
 val trainMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
 println("Training Mean Squared Error = " + trainMSE)
+println("Learned regression tree model:\n" + model)
 {% endhighlight %}
 </div>
 
 <div data-lang="java" markdown="1">
 {% highlight java %}
+import java.util.HashMap;
+import scala.Tuple2;
+import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.mllib.tree.DecisionTree;
 import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-import scala.Tuple2;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
 
-JavaRDD<LabeledPoint> data = ... // data set
-
-// Train a DecisionTree model.
+SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+String datapath = "data/mllib/sample_libsvm_data.txt";
+JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
+
+// Set parameters.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "variance";
 Integer maxDepth = 5;
 Integer maxBins = 100;
 
+// Train a DecisionTree model.
 final DecisionTreeModel model = DecisionTree.trainRegressor(data,
   categoricalFeaturesInfo, impurity, maxDepth, maxBins);
 
 // Evaluate model on training instances and compute training error
-JavaPairRDD predictionAndLabel =
+JavaPairRDD<Double, Double> predictionAndLabel =
   data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
     @Override public Tuple2<Double, Double> call(LabeledPoint p) {
       return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
    }
  });
-Double trainMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+Double trainMSE =
+  predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
    @Override public Double call(Tuple2<Double, Double> pl) {
-    Double diff = pl._1() - pl._2();
+      Double diff = pl._1() - pl._2();
      return diff * diff;
    }
-  }).sum() / data.count();
-System.out.print("Training Mean Squared Error: " + trainMSE);
-System.out.print("Learned model:\n" + model);
+  }).reduce(new Function2<Double, Double, Double>() {
+    @Override public Double call(Double a, Double b) {
+      return a + b;
+    }
+  }) / data.count();
+System.out.println("Training Mean Squared Error: " + trainMSE);
+System.out.println("Learned regression tree model:\n" + model);
 {% endhighlight %}
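A note on the `.sum()` to `.reduce(...)` change above: `map` on a `JavaRDD` yields a plain `JavaRDD<Double>`, which has no `sum()`; that method lives on `JavaDoubleRDD` (reachable via `mapToDouble`), so the doc now sums with an explicit `reduce`. The Scala API has no such wrinkle, because `RDD[Double]` picks up `DoubleRDDFunctions` implicitly. A sketch of the equivalent MSE computation, assuming `labelsAndPredictions` is the `RDD[(Double, Double)]` of (label, prediction) pairs from the Scala tab:

```scala
// mean() comes from DoubleRDDFunctions on RDD[Double], so no manual
// reduce-and-divide is needed in Scala.
val trainMSE = labelsAndPredictions.map { case (v, p) => math.pow(v - p, 2) }.mean()
```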
@@ -331,6 +366,8 @@ predictions = model.predict(data.map(lambda x: x.features))
 labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
 trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count())
 print('Training Mean Squared Error = ' + str(trainMSE))
+print('Learned regression tree model:')
+print(model)
 {% endhighlight %}
 </div>
 </div>
 
 Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
index a70d5f40b5252..ee79946ec6f3b 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
@@ -19,11 +19,10 @@
 
 import java.util.HashMap;
 
-import scala.reflect.ClassTag;
 import scala.Tuple2;
 
 import org.apache.spark.api.java.function.Function2;
 
- import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -34,22 +33,23 @@
 import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.SparkConf;
 
-
 /**
  * Classification and regression using decision trees.
  */
 public final class JavaDecisionTree {
 
   public static void main(String[] args) {
-    if (args.length != 1) {
+    String datapath = "data/mllib/sample_libsvm_data.txt";
+    if (args.length == 1) {
+      datapath = args[0];
+    } else if (args.length > 1) {
       System.err.println("Usage: JavaDecisionTree <libsvm format data file>");
       System.exit(1);
     }
 
     SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
     JavaSparkContext sc = new JavaSparkContext(sparkConf);
-    String datapath = args[0];
-    JavaRDD<LabeledPoint> data = JavaRDD.fromRDD(MLUtils.loadLibSVMFile(sc.sc(), datapath));
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
 
     // Compute the number of classes from the data.
     Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
@@ -57,7 +57,9 @@ public static void main(String[] args) {
         return p.label();
       }
     }).countByValue().size();
-    // Empty categoricalFeaturesInfo indicates all features are continuous.
+
+    // Set parameters.
+    // Empty categoricalFeaturesInfo indicates all features are continuous.
     HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
     String impurity = "gini";
     Integer maxDepth = 5;
@@ -80,12 +82,11 @@ public static void main(String[] args) {
         return !pl._1().equals(pl._2());
       }
     }).count() / data.count();
-    System.out.print("Training error: " + trainErr);
-    System.out.print("Learned classification tree model:\n" + model);
+    System.out.println("Training error: " + trainErr);
+    System.out.println("Learned classification tree model:\n" + model);
 
     // Train a DecisionTree model for regression.
     impurity = "variance";
-
     final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data,
         categoricalFeaturesInfo, impurity, maxDepth, maxBins);
 
@@ -107,8 +108,8 @@ public static void main(String[] args) {
         return a + b;
       }
     }) / data.count();
-    System.out.print("Training Mean Squared Error: " + trainMSE);
-    System.out.print("Learned regression tree model:\n" + regressionModel);
+    System.out.println("Training Mean Squared Error: " + trainMSE);
+    System.out.println("Learned regression tree model:\n" + regressionModel);
 
     sc.stop();
   }
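Both "Note:" lines in the doc recommend batch prediction over record-at-a-time `predict()` calls; the Python tabs already do this, while the Scala and Java tabs predict point by point inside a `map`. A sketch of the batched pattern for the regression case, assuming the same `data` and `model` as the Scala tab above:

```scala
// One predict() call over the whole RDD of feature vectors, then zip the
// predictions back against the labels. zip is safe here because both RDDs
// derive from `data` via map, so partitioning and ordering match.
val predictions = model.predict(data.map(_.features))
val labelsAndPredictions = data.map(_.label).zip(predictions)
val trainMSE = labelsAndPredictions.map { case (v, p) => (v - p) * (v - p) }.mean()
```

With the default `datapath` added in JavaDecisionTree.java, the example should also now run with no arguments, e.g. via `./bin/run-example mllib.JavaDecisionTree` from a Spark checkout (the default path is resolved relative to the working directory).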