Commit 108937e

pass compile

mengxr committed May 7, 2015
1 parent 8726d39 · commit 108937e

Showing 30 changed files with 244 additions and 142 deletions.
7 changes: 6 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -32,7 +32,12 @@ abstract class Model[M <: Model[M]] extends Transformer {
   * The parent estimator that produced this model.
   * Note: For ensembles' component Models, this value can be null.
   */
-  val parent: Estimator[M]
+  var parent: Estimator[M] = _
+
+  def setParent(parent: Estimator[M]): M = {
+    this.parent = parent
+    this.asInstanceOf[M]
+  }
 
   override def copy(extra: ParamMap): M = {
     // The default implementation of Params.copy doesn't work for models.
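With parent now a mutable var, the estimator is attached after construction via setParent, which returns the model typed as M so it chains. A minimal usage sketch, not part of this commit (assumes a DataFrame named training with label/features columns):

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
val model = lr.fit(training)  // fit() now calls setParent(this) on the result
assert(model.parent.eq(lr))   // parent points back at the producing estimator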
13 changes: 8 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ListBuffer
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -80,13 +81,15 @@ abstract class PipelineStage extends Params with Logging {
  * an identity transformer.
  */
 @AlphaComponent
-class Pipeline extends Estimator[PipelineModel] {
+class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
+
+  def this() = this(Identifiable.randomUID("pipeline"))
 
   /**
    * param for pipeline stages
    * @group param
    */
-  val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
+  val stages: Param[Array[PipelineStage]] = new Param(uid, "stages", "stages of the pipeline")
 
   /** @group setParam */
   def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
@@ -148,7 +151,7 @@ class Pipeline extends Estimator[PipelineModel] {
       }
     }
 
-    new PipelineModel(this, transformers.toArray)
+    new PipelineModel(uid, transformers.toArray).setParent(this)
   }
 
   override def copy(extra: ParamMap): Pipeline = {
@@ -171,7 +174,7 @@ class Pipeline extends Estimator[PipelineModel] {
  */
 @AlphaComponent
 class PipelineModel private[ml] (
-    override val parent: Pipeline,
+    val uid: String,
     val stages: Array[Transformer])
   extends Model[PipelineModel] with Logging {
@@ -190,6 +193,6 @@ class PipelineModel private[ml] (
   }
 
   override def copy(extra: ParamMap): PipelineModel = {
-    new PipelineModel(parent, stages)
+    new PipelineModel(uid, stages)
   }
 }
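A hedged sketch of how the new uid plumbing behaves from user code (assumes a DataFrame named training with the columns the stages expect; names are illustrative):

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression

// The no-arg constructor mints a uid such as "pipeline_<random suffix>".
val lr = new LogisticRegression()
val pipeline = new Pipeline().setStages(Array(lr))
println(pipeline.uid)

// fit() builds the PipelineModel from the pipeline's uid and the fitted
// stages, then links it back via setParent.
val model = pipeline.fit(training)
assert(model.parent.eq(pipeline))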
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -88,7 +88,7 @@ abstract class Predictor[
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
-    copyValues(train(dataset))
+    copyValues(train(dataset).setParent(this))
   }
 
   override def copy(extra: ParamMap): Learner = {
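The order of calls in the new line matters: train() produces a bare model, setParent records the estimator, and copyValues propagates the estimator's param values onto it. An illustrative expansion of that one-liner, as it would read inside Predictor.fit (not literal source):

// Equivalent, step by step:
val model = train(dataset)   // subclass-specific learning, returns the model
model.setParent(this)        // record the producing estimator (new in this commit)
copyValues(model)            // copy this estimator's param values onto the model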
mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
@@ -39,10 +39,12 @@ import org.apache.spark.sql.DataFrame
  * features.
  */
 @AlphaComponent
-final class DecisionTreeClassifier
+final class DecisionTreeClassifier(override val uid: String)
   extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
   with DecisionTreeParams with TreeClassifierParams {
 
+  def this() = this(Identifiable.randomUID("dtc"))
+
   // Override parameter setters from parent trait for Java API compatibility.
 
   override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
@@ -101,7 +103,7 @@ object DecisionTreeClassifier {
  */
 @AlphaComponent
 final class DecisionTreeClassificationModel private[ml] (
-    override val parent: DecisionTreeClassifier,
+    override val uid: String,
     override val rootNode: Node)
   extends PredictionModel[Vector, DecisionTreeClassificationModel]
   with DecisionTreeModel with Serializable {
@@ -114,7 +116,7 @@ final class DecisionTreeClassificationModel private[ml] (
   }
 
   override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
-    copyValues(new DecisionTreeClassificationModel(parent, rootNode), extra)
+    copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
   }
 
   override def toString: String = {
@@ -138,6 +140,6 @@ private[ml] object DecisionTreeClassificationModel {
       s"Cannot convert non-classification DecisionTreeModel (old API) to" +
       s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
     val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
-    new DecisionTreeClassificationModel(parent, rootNode)
+    new DecisionTreeClassificationModel(parent.uid, rootNode)
   }
 }
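The "dtc" string is only a readable prefix; Identifiable.randomUID appends a random suffix so independently constructed stages get distinct identities. A quick sketch:

import org.apache.spark.ml.util.Identifiable

val a = Identifiable.randomUID("dtc")  // e.g. "dtc_8de96cc14a3f"
val b = Identifiable.randomUID("dtc")
assert(a != b)                          // fresh suffix per call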
mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
@@ -44,10 +44,12 @@ import org.apache.spark.sql.DataFrame
  * Note: Multiclass labels are not currently supported.
  */
 @AlphaComponent
-final class GBTClassifier
+final class GBTClassifier(override val uid: String)
   extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
   with GBTParams with TreeClassifierParams with Logging {
 
+  def this() = this(Identifiable.randomUID("gbtc"))
+
   // Override parameter setters from parent trait for Java API compatibility.
 
   // Parameters from TreeClassifierParams:
@@ -99,7 +101,7 @@ final class GBTClassifier
    * (default = logistic)
    * @group param
    */
-  val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+  val lossType: Param[String] = new Param[String](uid, "lossType", "Loss function which GBT" +
     " tries to minimize (case-insensitive). Supported options:" +
     s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
     (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase))
@@ -160,7 +162,7 @@ object GBTClassifier {
  */
 @AlphaComponent
 final class GBTClassificationModel(
-    override val parent: GBTClassifier,
+    override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
     private val _treeWeights: Array[Double])
   extends PredictionModel[Vector, GBTClassificationModel]
@@ -184,7 +186,7 @@ final class GBTClassificationModel(
   }
 
   override def copy(extra: ParamMap): GBTClassificationModel = {
-    copyValues(new GBTClassificationModel(parent, _trees, _treeWeights), extra)
+    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra)
   }
 
   override def toString: String = {
@@ -210,6 +212,6 @@ private[ml] object GBTClassificationModel {
       // parent, fittingParamMap for each tree is null since there are no good ways to set these.
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
-    new GBTClassificationModel(parent, newTrees, oldModel.treeWeights)
+    new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights)
   }
 }
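Param declarations such as lossType now pass the owner's uid string, rather than the owner object, as the first constructor argument. A hedged sketch of the pattern as it would appear inside a Params subclass with a uid: String in scope (the param name and supported values here are illustrative, not from this file):

import org.apache.spark.ml.param.Param

// uid keys the param to its owner; the last argument is a validator.
val myLoss: Param[String] = new Param[String](uid, "myLoss",
  "loss function to minimize (case-insensitive)",
  (value: String) => Set("logistic").contains(value.toLowerCase))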
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.sql.DataFrame
@@ -41,10 +42,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
  * Currently, this class only supports binary classification.
  */
 @AlphaComponent
-class LogisticRegression
+class LogisticRegression(override val uid: String)
   extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
   with LogisticRegressionParams {
 
+  def this() = this(Identifiable.randomUID("logreg"))
+
   /** @group setParam */
   def setRegParam(value: Double): this.type = set(regParam, value)
@@ -72,7 +75,7 @@ class LogisticRegression
       .setRegParam($(regParam))
       .setNumIterations($(maxIter))
     val oldModel = lr.run(oldDataset)
-    val lrm = new LogisticRegressionModel(this, oldModel.weights, oldModel.intercept)
+    val lrm = new LogisticRegressionModel(uid, oldModel.weights, oldModel.intercept)
 
     if (handlePersistence) {
       oldDataset.unpersist()
@@ -89,7 +92,7 @@ class LogisticRegression
  */
 @AlphaComponent
 class LogisticRegressionModel private[ml] (
-    override val parent: LogisticRegression,
+    override val uid: String,
     val weights: Vector,
     val intercept: Double)
   extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
@@ -140,7 +143,7 @@ class LogisticRegressionModel private[ml] (
   }
 
   override def copy(extra: ParamMap): LogisticRegressionModel = {
-    copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
+    copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
   }
 
   override protected def raw2prediction(rawPrediction: Vector): Double = {
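End to end, train() passes the estimator's uid through to the model, while parent is wired separately by Predictor.fit. A hedged sketch (assumes a SQLContext named sqlContext is in scope):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.Vectors

// Tiny illustrative training set.
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0))
)).toDF("label", "features")

val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val model = lr.fit(training)
assert(model.uid == lr.uid)   // the same uid flows from estimator to model
assert(model.parent.eq(lr))   // parent set by Predictor.fit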
mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
@@ -41,10 +41,12 @@ import org.apache.spark.sql.DataFrame
  * features.
  */
 @AlphaComponent
-final class RandomForestClassifier
+final class RandomForestClassifier(override val uid: String)
   extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
   with RandomForestParams with TreeClassifierParams {
 
+  def this() = this(Identifiable.randomUID("rfc"))
+
   // Override parameter setters from parent trait for Java API compatibility.
 
   // Parameters from TreeClassifierParams:
@@ -118,7 +120,7 @@ object RandomForestClassifier {
  */
 @AlphaComponent
 final class RandomForestClassificationModel private[ml] (
-    override val parent: RandomForestClassifier,
+    override val uid: String,
     private val _trees: Array[DecisionTreeClassificationModel])
   extends PredictionModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
@@ -146,7 +148,7 @@ final class RandomForestClassificationModel private[ml] (
   }
 
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
-    copyValues(new RandomForestClassificationModel(parent, _trees), extra)
+    copyValues(new RandomForestClassificationModel(uid, _trees), extra)
   }
 
   override def toString: String = {
@@ -172,6 +174,6 @@ private[ml] object RandomForestClassificationModel {
       // parent, fittingParamMap for each tree is null since there are no good ways to set these.
      DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
     }
-    new RandomForestClassificationModel(parent, newTrees)
+    new RandomForestClassificationModel(parent.uid, newTrees)
   }
 }
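One consequence of copy() reusing uid: a copied model keeps the identity of the original rather than minting a new one. An illustrative sketch (assumes the training DataFrame from the earlier example):

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.param.ParamMap

val rf = new RandomForestClassifier().setNumTrees(10)
val rfModel = rf.fit(training)
val copied = rfModel.copy(ParamMap.empty)
assert(copied.uid == rfModel.uid)  // copies share the original's uid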
mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.Evaluator
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.{DataFrame, Row}
@@ -33,13 +33,16 @@ import org.apache.spark.sql.types.DoubleType
  * Evaluator for binary classification, which expects two input columns: score and label.
  */
 @AlphaComponent
-class BinaryClassificationEvaluator extends Evaluator with HasRawPredictionCol with HasLabelCol {
+class BinaryClassificationEvaluator(override val uid: String)
+  extends Evaluator with HasRawPredictionCol with HasLabelCol {
+
+  def this() = this(Identifiable.randomUID("binEval"))
 
   /**
    * param for metric name in evaluation
    * @group param
    */
-  val metricName: Param[String] = new Param(this, "metricName",
+  val metricName: Param[String] = new Param(uid, "metricName",
     "metric name in evaluation (areaUnderROC|areaUnderPR)")
 
   /** @group getParam */
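Usage is unchanged apart from the auto-generated uid. A hedged sketch (assumes a DataFrame named predictions with rawPrediction and label columns):

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val evaluator = new BinaryClassificationEvaluator()  // uid like "binEval_..."
  .setMetricName("areaUnderROC")
val auc = evaluator.evaluate(predictions)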
mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.BinaryAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -32,7 +32,10 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
 * Binarize a column of continuous features given a threshold.
 */
 @AlphaComponent
-final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
+final class Binarizer(override val uid: String)
+  extends Transformer with HasInputCol with HasOutputCol {
+
+  def this() = this(Identifiable.randomUID("binarizer"))
 
   /**
    * Param for threshold used to binarize continuous features.
@@ -41,7 +44,7 @@ final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
    * @group param
    */
   val threshold: DoubleParam =
-    new DoubleParam(this, "threshold", "threshold used to binarize continuous features")
+    new DoubleParam(uid, "threshold", "threshold used to binarize continuous features")
 
   /** @group getParam */
   def getThreshold: Double = $(threshold)
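A hedged usage sketch of the transformer itself (assumes a DataFrame df with a Double column "feature"; column names are illustrative):

import org.apache.spark.ml.feature.Binarizer

val binarizer = new Binarizer()      // uid like "binarizer_..."
  .setInputCol("feature")
  .setOutputCol("binarized")
  .setThreshold(0.5)                 // values above 0.5 map to 1.0, the rest to 0.0
val result = binarizer.transform(df)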
mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.{IntParam, ParamValidators}
+import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.types.DataType
@@ -29,14 +30,16 @@ import org.apache.spark.sql.types.DataType
 * Maps a sequence of terms to their term frequencies using the hashing trick.
 */
 @AlphaComponent
-class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+
+  def this() = this(Identifiable.randomUID("hashingTF"))
 
   /**
    * Number of features. Should be > 0.
   * (default = 2^18^)
    * @group param
    */
-  val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
+  val numFeatures = new IntParam(uid, "numFeatures", "number of features (> 0)",
     ParamValidators.gt(0))
 
   setDefault(numFeatures -> (1 << 18))
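And the analogous sketch for HashingTF (assumes a DataFrame named tokenized with a "words" column of term sequences; names are illustrative):

import org.apache.spark.ml.feature.HashingTF

val tf = new HashingTF()             // uid like "hashingTF_..."
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(1 << 20)           // must be > 0; the default is 2^18
val featurized = tf.transform(tokenized)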
(Diff truncated: the remaining changed files in this commit are not shown.)