Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-14386][ML] Changed spark.ml ensemble trees methods to return concrete types #12158

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
TreeEnsembleModel}
import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
Expand Down Expand Up @@ -190,7 +189,7 @@ final class GBTClassificationModel private[ml](
private val _treeWeights: Array[Double],
@Since("1.6.0") override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
with TreeEnsembleModel with Serializable {
with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {

require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
Expand All @@ -206,7 +205,7 @@ final class GBTClassificationModel private[ml](
this(uid, _trees, _treeWeights, -1)

@Since("1.4.0")
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
override def trees: Array[DecisionTreeRegressionModel] = _trees

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable
with Serializable {
with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with MLWritable with Serializable {

require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")

Expand All @@ -172,7 +172,7 @@ final class RandomForestClassificationModel private[ml] (
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)

@Since("1.4.0")
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
override def trees: Array[DecisionTreeClassificationModel] = _trees

// Note: We may add support for weights (based on tree performance) later on.
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
TreeRegressorParams}
import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams}
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
Expand Down Expand Up @@ -177,7 +176,7 @@ final class GBTRegressionModel private[ml](
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
with TreeEnsembleModel with Serializable {
with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {

require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
Expand All @@ -193,7 +192,7 @@ final class GBTRegressionModel private[ml](
this(uid, _trees, _treeWeights, -1)

@Since("1.4.0")
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
override def trees: Array[DecisionTreeRegressionModel] = _trees

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with RandomForestRegressionModelParams with TreeEnsembleModel with MLWritable with Serializable {
with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {

require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")

Expand All @@ -155,7 +156,7 @@ final class RandomForestRegressionModel private[ml] (
this(Identifiable.randomUID("rfr"), trees, numFeatures)

@Since("1.4.0")
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
override def trees: Array[DecisionTreeRegressionModel] = _trees

// Note: We may add support for weights (based on tree performance) later on.
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
Expand Down
14 changes: 9 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.ml.tree

import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
Expand Down Expand Up @@ -82,14 +84,16 @@ private[spark] trait DecisionTreeModel {
* Abstraction for models which are ensembles of decision trees
*
* TODO: Add support for predicting probabilities and raw predictions SPARK-3727
*
* @tparam M Type of tree model in this ensemble
*/
private[ml] trait TreeEnsembleModel {
private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {

// Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of
// DecisionTreeModel.

/** Trees in this ensemble. Warning: These have null parent Estimators. */
def trees: Array[DecisionTreeModel]
def trees: Array[M]

/**
* Number of trees in ensemble
Expand Down Expand Up @@ -148,7 +152,7 @@ private[ml] object TreeEnsembleModel {
* If -1, then numFeatures is set based on the max feature index in all trees.
* @return Feature importance values, of length numFeatures.
*/
def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = {
val totalImportances = new OpenHashMap[Int, Double]()
trees.foreach { tree =>
// Aggregate feature importance vector for this tree
Expand Down Expand Up @@ -199,7 +203,7 @@ private[ml] object TreeEnsembleModel {
* If -1, then numFeatures is set based on the max feature index in all trees.
* @return Feature importance values, of length numFeatures.
*/
def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = {
featureImportances(Array(tree), numFeatures)
}

Expand Down Expand Up @@ -386,7 +390,7 @@ private[ml] object EnsembleModelReadWrite {
* @param path Path to which to save the ensemble model.
* @param extraMetadata Metadata such as numFeatures, numClasses, numTrees.
*/
def saveImpl[M <: Params with TreeEnsembleModel](
def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]](
instance: M,
path: String,
sql: SQLContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ private[ml] object TreeTests extends SparkFunSuite {
* Check if the two models are exactly the same.
* If the models are not equal, this throws an exception.
*/
def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
def checkEqual[M <: DecisionTreeModel](a: TreeEnsembleModel[M], b: TreeEnsembleModel[M]): Unit = {
try {
a.trees.zip(b.trees).foreach { case (treeA, treeB) =>
TreeTests.checkEqual(treeA, treeB)
Expand Down