Skip to content

Commit

Permalink
Save default params separately in JSON.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Feb 17, 2018
1 parent d5ed210 commit 69648d6
Show file tree
Hide file tree
Showing 45 changed files with 184 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val root = loadTreeNodes(path, metadata, sparkSession)
val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
case (treeMetadata, root) =>
val tree =
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
treeMetadata.getAndSetParams(tree)
tree
}
require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
s" trees based on metadata but found ${trees.length} trees.")
val model = new GBTClassificationModel(metadata.uid,
trees, treeWeights, numFeatures)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
val Row(coefficients: Vector, intercept: Double) =
data.select("coefficients", "intercept").head()
val model = new LinearSVCModel(metadata.uid, coefficients, intercept)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1267,7 +1267,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
numClasses, isMultinomial)
}

DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ object MultilayerPerceptronClassificationModel
val weights = data.getAs[Vector](1)
val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)

DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
.head()
val model = new NaiveBayesModel(metadata.uid, pi, theta)

DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
}
val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
DefaultParamsReader.getAndSetParams(ovrModel, metadata)
metadata.getAndSetParams(ovrModel)
ovrModel.set("classifier", classifier)
ovrModel
}
Expand Down Expand Up @@ -448,7 +448,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
override def load(path: String): OneVsRest = {
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
val ovr = new OneVsRest(metadata.uid)
DefaultParamsReader.getAndSetParams(ovr, metadata)
metadata.getAndSetParams(ovr)
ovr.setClassifier(classifier)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,14 +319,14 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
case (treeMetadata, root) =>
val tree =
new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
treeMetadata.getAndSetParams(tree)
tree
}
require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
s" trees based on metadata but found ${trees.length} trees.")

val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
val dataPath = new Path(path, "data").toString
val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
val model = new BisectingKMeansModel(metadata.uid, mllibModel)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
}
val model = new GaussianMixtureModel(metadata.uid, weights, gaussians)

DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters
}
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
private object LDAParams {

/**
* Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
* Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
* formats saved with Spark 1.6, which differ from the formats in Spark 2.0+.
*
* @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with
Expand All @@ -391,7 +391,7 @@ private object LDAParams {
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
case _ => // 2.0+
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
val model = new BucketedRandomProjectionLSHModel(metadata.uid,
randUnitVectors.rowIter.toArray)

DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
24 changes: 0 additions & 24 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
/**
 * Copies this Bucketizer by delegating to `defaultCopy` with `extra`, then sets the
 * copy's parent to this instance's `parent`.
 *
 * @param extra extra params handed through to `defaultCopy`
 * @return the result of `setParent(parent)` on the copy
 */
override def copy(extra: ParamMap): Bucketizer = {
defaultCopy[Bucketizer](extra).setParent(parent)
}

override def write: MLWriter = new Bucketizer.BucketizerWriter(this)
}

@Since("1.6.0")
Expand Down Expand Up @@ -296,28 +294,6 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
}
}


/**
 * [[MLWriter]] for [[Bucketizer]] that decides which params are passed to
 * `DefaultParamsWriter.saveMetadata` when the instance is persisted.
 */
private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter {

  /**
   * Saves the metadata of `instance` under `path`.
   *
   * When `inputCols` is set, every param from `extractParamMap()` except `outputCol` is
   * JSON-encoded and passed explicitly as `paramMap`; otherwise `None` is passed.
   */
  override protected def saveImpl(path: String): Unit = {
    // SPARK-23377: Default params are saved and loaded as user-supplied params.
    // Once `inputCols` is set, the default value of `outputCol` causes an error when
    // the exclusive params are checked. As a temporary fix, the default value of
    // `outputCol` is skipped when saving the metadata if `inputCols` is set.
    // TODO: If the persistence mechanism is later changed to handle default params
    // better, this workaround can be removed.
    val paramsToSave: Option[JValue] =
      if (instance.isSet(instance.inputCols)) {
        val jsonPairs = instance.extractParamMap().toSeq.collect {
          case ParamPair(p, v) if p != instance.outputCol =>
            p.name -> parse(p.jsonEncode(v))
        }.toList
        Some(render(jsonPairs))
      } else {
        // NOTE(review): presumably saveMetadata falls back to its own param handling
        // when no explicit paramMap is given -- confirm against DefaultParamsWriter.
        None
      }
    DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramsToSave)
  }
}

/** Loads a [[Bucketizer]] from `path` by delegating to the inherited `load`. */
@Since("1.6.0")
override def load(path: String): Bucketizer = super.load(path)
}
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
val selectedFeatures = data.getAs[Seq[Int]](0).toArray
val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
val model = new ChiSqSelectorModel(metadata.uid, oldModel)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
.head()
val vocabulary = data.getAs[Seq[String]](0).toArray
val model = new CountVectorizerModel(metadata.uid, vocabulary)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] {
.select("idf")
.head()
val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf)))
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val model = new ImputerModel(metadata.uid, surrogateDF)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
.select("maxAbs")
.head()
val model = new MaxAbsScalerModel(metadata.uid, maxAbs)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
.map(tuple => (tuple(0), tuple(1))).toArray
val model = new MinHashLSHModel(metadata.uid, randCoefficients)

DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
.select("originalMin", "originalMax")
.head()
val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {
.head()
val categorySizes = data.getAs[Seq[Int]](0).toArray
val model = new OneHotEncoderModel(metadata.uid, categorySizes)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] {
new PCAModel(metadata.uid, pc.asML,
Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
}
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,35 +253,11 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui

/** Copies this [[QuantileDiscretizer]] by delegating to `defaultCopy`, passing `extra` through. */
@Since("1.6.0")
override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)

override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this)
}

@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {

/**
 * [[MLWriter]] for [[QuantileDiscretizer]] that decides which params are passed to
 * `DefaultParamsWriter.saveMetadata` when the instance is persisted.
 */
private[QuantileDiscretizer]
class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter {

  /**
   * Saves the metadata of `instance` under `path`.
   *
   * When `inputCols` is set, every param from `extractParamMap()` except `outputCol` is
   * JSON-encoded and passed explicitly as `paramMap`; otherwise `None` is passed.
   */
  override protected def saveImpl(path: String): Unit = {
    // SPARK-23377: Default params are saved and loaded as user-supplied params.
    // Once `inputCols` is set, the default value of `outputCol` causes an error when
    // the exclusive params are checked. As a temporary fix, the default value of
    // `outputCol` is skipped when saving the metadata if `inputCols` is set.
    // TODO: If the persistence mechanism is later changed to handle default params
    // better, this workaround can be removed.
    val explicitParams: Option[JValue] =
      if (!instance.isSet(instance.inputCols)) {
        // NOTE(review): presumably saveMetadata falls back to its own param handling
        // when no explicit paramMap is given -- confirm against DefaultParamsWriter.
        None
      } else {
        val encoded = instance.extractParamMap().toSeq.collect {
          case ParamPair(p, v) if p != instance.outputCol =>
            p.name -> parse(p.jsonEncode(v))
        }.toList
        Some(render(encoded))
      }
    DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = explicitParams)
  }
}

/** Loads a [[QuantileDiscretizer]] from `path` by delegating to the inherited `load`. */
@Since("1.6.0")
override def load(path: String): QuantileDiscretizer = super.load(path)
}
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {

val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)

DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down Expand Up @@ -509,7 +509,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
val columnsToPrune = data.getAs[Seq[String]](0).toSet
val pruner = new ColumnPruner(metadata.uid, columnsToPrune)

DefaultParamsReader.getAndSetParams(pruner, metadata)
metadata.getAndSetParams(pruner)
pruner
}
}
Expand Down Expand Up @@ -601,7 +601,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
val prefixesToRewrite = data.getAs[Map[String, String]](1)
val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)

DefaultParamsReader.getAndSetParams(rewriter, metadata)
metadata.getAndSetParams(rewriter)
rewriter
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
.select("std", "mean")
.head()
val model = new StandardScalerModel(metadata.uid, std, mean)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
.head()
val labels = data.getAs[Seq[String]](0).toArray
val model = new StringIndexerModel(metadata.uid, labels)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] {
val numFeatures = data.getAs[Int](0)
val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1)
val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
}

val model = new Word2VecModel(metadata.uid, oldModel)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
val dataPath = new Path(path, "data").toString
val frequentItems = sparkSession.read.parquet(dataPath)
val model = new FPGrowthModel(metadata.uid, frequentItems)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down
6 changes: 3 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ trait Params extends Identifiable with Serializable {
* this method gets called.
* @param value the default value
*/
protected final def setDefault[T](param: Param[T], value: T): this.type = {
private[ml] final def setDefault[T](param: Param[T], value: T): this.type = {
defaultParamMap.put(param -> value)
this
}
Expand Down Expand Up @@ -865,10 +865,10 @@ trait Params extends Identifiable with Serializable {
}

/** Internal param map for user-supplied values. */
private val paramMap: ParamMap = ParamMap.empty
private[ml] val paramMap: ParamMap = ParamMap.empty

/** Internal param map for default values. */
private val defaultParamMap: ParamMap = ParamMap.empty
private[ml] val defaultParamMap: ParamMap = ParamMap.empty

/** Validates that the input param belongs to this instance. */
private def shouldOwn(param: Param[_]): Unit = {
Expand Down
Loading

0 comments on commit 69648d6

Please sign in to comment.