Commit 9dd1b6b: Updated decision tree doc.
jkbradley committed Aug 20, 2014 (1 parent: d802369)
Showing 1 changed file with 43 additions and 25 deletions: docs/mllib-decision-tree.md
@@ -12,7 +12,7 @@ and their ensembles are popular methods for the machine learning tasks of
classification and regression. Decision trees are widely used since they are easy to interpret,
handle categorical features, extend to the multiclass classification setting, do not require
feature scaling, and are able to capture nonlinearities and feature interactions. Tree ensemble
algorithms such as random forests and boosting are among the top performers for classification and
regression tasks.

MLlib supports decision trees for binary and multiclass classification and for regression,
@@ -94,13 +94,13 @@ Section 9.2.4 in *Elements of Statistical Machine Learning* for
details). For example, for a binary classification problem with one categorical feature with three
categories A, B and C whose corresponding proportions of label 1 are 0.2, 0.6 and 0.4, the categorical
features are ordered as A, C, B. The two split candidates are A \| C, B
and A , C \| B where \| denotes the split.

In multiclass classification, all `$2^{M-1}-1$` possible splits are used whenever possible.
When `$2^{M-1}-1$` is greater than the `maxBins` parameter, we use a (heuristic) method
similar to the method used for binary classification and regression.
The `$M$` categorical feature values are ordered by impurity,
and the resulting `$M-1$` split candidates are considered.
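
To make the ordering concrete, here is a minimal Scala sketch (illustrative only, not MLlib code) that reproduces the worked example above: categories are sorted by their proportion of label 1, and the `$M-1$` prefixes of that ordering become the split candidates.

```scala
// Proportion of label 1 for each category, from the example above.
val proportions = Map("A" -> 0.2, "B" -> 0.6, "C" -> 0.4)

// Order the categories by that proportion: A (0.2), C (0.4), B (0.6).
val ordered = proportions.toSeq.sortBy(_._2).map(_._1)

// The M - 1 split candidates are the prefixes of the ordered sequence.
val splits = (1 until ordered.size).map(i => (ordered.take(i), ordered.drop(i)))
splits.foreach { case (left, right) =>
  println(left.mkString(", ") + " | " + right.mkString(", "))
}
// Prints:
//   A | C, B
//   A, C | B
```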

### Stopping rule

@@ -109,6 +109,8 @@ The recursive tree construction is stopped at a node when one of the two conditions is met:
1. The node depth is equal to the `maxDepth` training parameter.
2. No split candidate leads to an information gain at the node.
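
In schematic form, a node stops splitting as soon as either condition holds. The Scala sketch below is illustrative only (the names, including `bestGain`, are hypothetical, not MLlib internals):

```scala
// Schematic check of the two stopping conditions; `bestGain` stands for the
// information gain of the best split candidate found at the node.
def shouldStop(nodeDepth: Int, maxDepth: Int, bestGain: Double): Boolean =
  nodeDepth == maxDepth ||  // 1. maximum depth reached
  bestGain <= 0.0           // 2. no split candidate yields an information gain
```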

## Implementation details

### Max memory requirements

For faster processing, the decision tree algorithm performs simultaneous histogram computations for
@@ -120,11 +122,24 @@ be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each
subsequent level are split into smaller tasks.

Note that, if you have a large amount of memory, increasing `maxMemoryInMB` can lead to faster
training by requiring fewer passes over the data.
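
For example, one way to raise the limit is through the `Strategy` configuration class. The sketch below is illustrative; the exact parameter names (such as `numClassesForClassification` and `maxMemoryInMB`) may vary across Spark versions.

```scala
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Sketch: raise the memory budget for level-wise histogram aggregation.
// `data` is assumed to be an RDD[LabeledPoint] as in the examples below.
val strategy = new Strategy(
  algo = Algo.Classification,
  impurity = Gini,
  maxDepth = 5,
  numClassesForClassification = 2,
  maxMemoryInMB = 512)  // larger than the 128 MB default
val model = DecisionTree.train(data, strategy)
```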

### Binning feature values

Increasing `maxBins` allows the algorithm to consider more split candidates and make fine-grained
split decisions. However, it also increases computation and communication.

Note that the `maxBins` parameter must be at least the maximum number of categories `$M$` for
any categorical feature.
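
For instance, a hypothetical dataset whose feature 0 is categorical with 10 values must be trained with `maxBins` of at least 10. A sketch, using the same `trainClassifier` call style as the examples below:

```scala
import org.apache.spark.mllib.tree.DecisionTree

// Feature 0 is categorical with 10 categories, so maxBins must be at least 10.
// `data` is assumed to be an RDD[LabeledPoint] as in the examples below.
val numClasses = 2
val categoricalFeaturesInfo = Map(0 -> 10)
val impurity = "gini"
val maxDepth = 5
val maxBins = 32  // 32 >= 10, so this setting is valid

val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)
```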

### Scaling

Computation scales approximately linearly in the number of training instances,
in the number of features, and in the `maxBins` parameter.
Communication scales approximately linearly in the number of features and in `maxBins`.

The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input.
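
For example, training points can be built from either vector type (a sketch; both forms are accepted by the training methods in the examples below):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Dense: all feature values stored explicitly.
val densePoint = LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0))

// Sparse: vector of size 3 with nonzero entries at indices 0 and 2.
val sparsePoint = LabeledPoint(1.0, Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0)))
```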

## Examples

@@ -143,8 +158,9 @@ maximum tree depth of 5. The training error is calculated to measure the algorithm accuracy.
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -187,17 +203,14 @@ import org.apache.spark.SparkConf;
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Integer numClasses = 2;
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "gini";
Integer maxDepth = 5;
@@ -231,8 +244,9 @@ from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
# Cache the data since we will use it again to compute training error.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -271,8 +285,9 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate goodness of fit.
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -311,6 +326,8 @@ import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

Expand Down Expand Up @@ -357,8 +374,9 @@ from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
# Cache the data since we will use it again to compute training error.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
