Skip to content

Commit

Permalink
[SPARK-13784][ML] Persistence for RandomForestClassifier, RandomFores…
Browse files Browse the repository at this point in the history
…tRegressor

## What changes were proposed in this pull request?

**Main change**: Added save/load for RandomForestClassifier, RandomForestRegressor (implementation details below)

Modified numTrees method (*deprecation*)
* Goal: Use default implementations of unit tests which assume Estimators and Models share the same set of Params.
* What this PR does: Moves method numTrees outside of trait TreeEnsembleModel.  Adds it to GBT and RF Models.  Deprecates it in RF Models in favor of new method getNumTrees.  In Spark 2.1, we can have RF Models include Param numTrees.

Minor items
* Fixes bugs in GBTClassificationModel, GBTRegressionModel fromOld methods where they assign the wrong old UID.

**Implementation details**
* Split DecisionTreeModelReadWrite.loadTreeNodes into 2 methods in order to reuse some code for ensembles.
* Added EnsembleModelReadWrite object with save/load implementations usable for RFs and GBTs
  * These store all trees' nodes in a single DataFrame, and all trees' metadata in a second DataFrame.
* Split trait RandomForestParams into parts in order to add more Estimator Params to RF models
* Split DefaultParamsWriter.saveMetadata into two methods to allow ensembles to store sub-models' metadata in a single DataFrame.  Same for DefaultParamsReader.loadMetadata

## How was this patch tested?

Adds standard unit tests for RF save/load

Author: Joseph K. Bradley <[email protected]>
Author: GayathriMurali <[email protected]>

Closes #12118 from jkbradley/GayathriMurali-SPARK-13784.
  • Loading branch information
jkbradley committed Apr 4, 2016
1 parent 7454253 commit 89f3bef
Show file tree
Hide file tree
Showing 9 changed files with 424 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ final class GBTClassificationModel private[ml](
extends PredictionModel[Vector, GBTClassificationModel]
with TreeEnsembleModel with Serializable {

require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.")
require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")

Expand Down Expand Up @@ -227,6 +227,9 @@ final class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}

/** Number of trees in ensemble */
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
Expand Down Expand Up @@ -272,6 +275,6 @@ private[ml] object GBTClassificationModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

package org.apache.spark.ml.classification

import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
Expand All @@ -43,7 +47,7 @@ import org.apache.spark.sql.functions._
final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
with RandomForestClassifierParams with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))
Expand Down Expand Up @@ -120,7 +124,7 @@ final class RandomForestClassifier @Since("1.4.0") (

@Since("1.4.0")
@Experimental
object RandomForestClassifier {
object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
/** Accessor for supported impurity settings: entropy, gini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
Expand All @@ -129,15 +133,19 @@ object RandomForestClassifier {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies

@Since("2.0.0")
override def load(path: String): RandomForestClassifier = super.load(path)
}

/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
*
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
* Warning: These have null parents.
*/
@Since("1.4.0")
@Experimental
Expand All @@ -147,12 +155,14 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable
with Serializable {

require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")

/**
* Construct a random forest classification model, with all trees weighted equally.
*
* @param trees Component trees
*/
private[ml] def this(
Expand All @@ -165,7 +175,7 @@ final class RandomForestClassificationModel private[ml] (
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]

// Note: We may add support for weights (based on tree performance) later on.
private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
Expand Down Expand Up @@ -208,6 +218,15 @@ final class RandomForestClassificationModel private[ml] (
}
}

/**
* Number of trees in ensemble
*
* @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
*/
// TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
Expand All @@ -216,7 +235,7 @@ final class RandomForestClassificationModel private[ml] (

@Since("1.4.0")
override def toString: String = {
s"RandomForestClassificationModel (uid=$uid) with $numTrees trees"
s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees"
}

/**
Expand All @@ -236,12 +255,69 @@ final class RandomForestClassificationModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
}

@Since("2.0.0")
override def write: MLWriter =
new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
}

private[ml] object RandomForestClassificationModel {
@Since("2.0.0")
object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] {

@Since("2.0.0")
override def read: MLReader[RandomForestClassificationModel] =
new RandomForestClassificationModelReader

@Since("2.0.0")
override def load(path: String): RandomForestClassificationModel = super.load(path)

private[RandomForestClassificationModel]
class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel)
extends MLWriter {

override protected def saveImpl(path: String): Unit = {
// Note: numTrees is not currently used, but could be nice to store for fast querying.
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numClasses" -> instance.numClasses,
"numTrees" -> instance.getNumTrees)
EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
}
}

private class RandomForestClassificationModelReader
extends MLReader[RandomForestClassificationModel] {

/** Checked against metadata when loading model */
private val className = classOf[RandomForestClassificationModel].getName
private val treeClassName = classOf[DecisionTreeClassificationModel].getName

override def load(path: String): RandomForestClassificationModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]

val trees: Array[DecisionTreeClassificationModel] = treesData.map {
case (treeMetadata, root) =>
val tree =
new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
tree
}
require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
s" trees based on metadata but found ${trees.length} trees.")

val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

/** (private[ml]) Convert a model from the old API */
def fromOld(
/** Convert a model from the old API */
private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ final class GBTRegressionModel private[ml](
extends PredictionModel[Vector, GBTRegressionModel]
with TreeEnsembleModel with Serializable {

require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")

Expand Down Expand Up @@ -213,6 +213,9 @@ final class GBTRegressionModel private[ml](
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}

/** Number of trees in ensemble */
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): GBTRegressionModel = {
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
Expand Down Expand Up @@ -258,6 +261,6 @@ private[ml] object GBTRegressionModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
Loading

0 comments on commit 89f3bef

Please sign in to comment.