diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index e55db0ab0deae..e4f77e890b67c 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -162,47 +162,64 @@ val labelAndPreds = data.map { point =>
}
val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count
println("Training Error = " + trainErr)
+println("Learned classification tree model:\n" + model)
{% endhighlight %}
{% highlight java %}
+import java.util.HashMap;
import scala.Tuple2;
+import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-
-JavaRDD data = ... // data set
-
-// Train a DecisionTree model.
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+
+SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+String datapath = "data/mllib/sample_libsvm_data.txt";
+JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
+// Compute the number of classes from the data.
+Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+  @Override public Double call(LabeledPoint p) {
+    return p.label();
+  }
+}).countByValue().size();
+
+// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
-Integer numClasses = ... // number of classes
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "gini";
Integer maxDepth = 5;
Integer maxBins = 100;
+// Train a DecisionTree model for classification.
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// Evaluate model on training instances and compute training error
-JavaPairRDD<Double, Double> predictionAndLabel =
+JavaPairRDD<Double, Double> predictionAndLabel =
   data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
     @Override public Tuple2<Double, Double> call(LabeledPoint p) {
       return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
     }
   });
-Double trainErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+Double trainErr =
+  1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
     @Override public Boolean call(Tuple2<Double, Double> pl) {
       return !pl._1().equals(pl._2());
     }
   }).count() / data.count();
-System.out.print("Training error: " + trainErr);
-System.out.print("Learned model:\n" + model);
+System.out.println("Training error: " + trainErr);
+System.out.println("Learned classification tree model:\n" + model);
{% endhighlight %}
@@ -225,6 +242,8 @@ predictions = model.predict(data.map(lambda x: x.features))
labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count())
print('Training Error = ' + str(trainErr))
+print('Learned classification tree model:')
+print(model)
{% endhighlight %}
Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
@@ -268,47 +287,63 @@ val labelsAndPredictions = data.map { point =>
}
val trainMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Training Mean Squared Error = " + trainMSE)
+println("Learned regression tree model:\n" + model)
{% endhighlight %}
{% highlight java %}
+import java.util.HashMap;
+import scala.Tuple2;
+import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-import scala.Tuple2;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
-JavaRDD data = ... // data set
+SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+String datapath = "data/mllib/sample_libsvm_data.txt";
+JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
-// Train a DecisionTree model.
+// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "variance";
Integer maxDepth = 5;
Integer maxBins = 100;
+// Train a DecisionTree model.
final DecisionTreeModel model = DecisionTree.trainRegressor(data,
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// Evaluate model on training instances and compute training error
-JavaPairRDD<Double, Double> predictionAndLabel =
+JavaPairRDD<Double, Double> predictionAndLabel =
   data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
     @Override public Tuple2<Double, Double> call(LabeledPoint p) {
       return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
     }
   });
-Double trainMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+Double trainMSE =
+  predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
     @Override public Double call(Tuple2<Double, Double> pl) {
-    Double diff = pl._1() - pl._2();
+      Double diff = pl._1() - pl._2();
       return diff * diff;
     }
-  }).sum() / data.count();
-System.out.print("Training Mean Squared Error: " + trainMSE);
-System.out.print("Learned model:\n" + model);
+  }).reduce(new Function2<Double, Double, Double>() {
+ @Override public Double call(Double a, Double b) {
+ return a + b;
+ }
+ }) / data.count();
+System.out.println("Training Mean Squared Error: " + trainMSE);
+System.out.println("Learned regression tree model:\n" + model);
{% endhighlight %}
@@ -331,6 +366,8 @@ predictions = model.predict(data.map(lambda x: x.features))
labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count())
print('Training Mean Squared Error = ' + str(trainMSE))
+print('Learned regression tree model:')
+print(model)
{% endhighlight %}
Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
index a70d5f40b5252..ee79946ec6f3b 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
@@ -19,11 +19,10 @@
import java.util.HashMap;
-import scala.reflect.ClassTag;
import scala.Tuple2;
import org.apache.spark.api.java.function.Function2;
- import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
@@ -34,22 +33,23 @@
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;
-
/**
* Classification and regression using decision trees.
*/
public final class JavaDecisionTree {
public static void main(String[] args) {
- if (args.length != 1) {
+ String datapath = "data/mllib/sample_libsvm_data.txt";
+ if (args.length == 1) {
+ datapath = args[0];
+ } else if (args.length > 1) {
       System.err.println("Usage: JavaDecisionTree <libsvm format data file>");
System.exit(1);
}
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
- String datapath = args[0];
- JavaRDD data = JavaRDD.fromRDD(MLUtils.loadLibSVMFile(sc.sc(), datapath));
+ JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
// Compute the number of classes from the data.
Integer numClasses = data.map(new Function() {
@@ -57,7 +57,9 @@ public static void main(String[] args) {
return p.label();
}
}).countByValue().size();
- // Empty categoricalFeaturesInfo indicates all features are continuous.
+
+ // Set parameters.
+ // Empty categoricalFeaturesInfo indicates all features are continuous.
     HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "gini";
Integer maxDepth = 5;
@@ -80,12 +82,11 @@ public static void main(String[] args) {
return !pl._1().equals(pl._2());
}
}).count() / data.count();
- System.out.print("Training error: " + trainErr);
- System.out.print("Learned classification tree model:\n" + model);
+ System.out.println("Training error: " + trainErr);
+ System.out.println("Learned classification tree model:\n" + model);
// Train a DecisionTree model for regression.
impurity = "variance";
-
final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data,
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
@@ -107,8 +108,8 @@ public static void main(String[] args) {
return a + b;
}
}) / data.count();
- System.out.print("Training Mean Squared Error: " + trainMSE);
- System.out.print("Learned regression tree model:\n" + regressionModel);
+ System.out.println("Training Mean Squared Error: " + trainMSE);
+ System.out.println("Learned regression tree model:\n" + regressionModel);
sc.stop();
}