From 166cdbb3e95315e0feb29fb26c6c98837747e22d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 6 Mar 2018 07:50:46 +0000 Subject: [PATCH] Remove isDefaultParam. --- .../org/apache/spark/ml/tree/treeModels.scala | 4 +- .../org/apache/spark/ml/util/ReadWrite.scala | 45 ++++++++++++------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 264a086e7d795..4aa4c3617e7fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -333,7 +333,7 @@ private[ml] object DecisionTreeModelReadWrite { // Get impurity to construct ImpurityCalculator for each node val impurityType: String = { - val impurityJson: JValue = metadata.getParamValue("impurity", isDefaultParam = true) + val impurityJson: JValue = metadata.getParamValue("impurity") Param.jsonDecode[String](compact(render(impurityJson))) } @@ -428,7 +428,7 @@ private[ml] object EnsembleModelReadWrite { // Get impurity to construct ImpurityCalculator for each node val impurityType: String = { - val impurityJson: JValue = metadata.getParamValue("impurity", isDefaultParam = true) + val impurityJson: JValue = metadata.getParamValue("impurity") Param.jsonDecode[String](compact(render(impurityJson))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 31b4612e33a2a..1a8825df3499d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -363,28 +363,41 @@ private[ml] object DefaultParamsReader { metadata: JValue, metadataJson: String) { + + private def getValueFromParams(params: JValue): Seq[(String, JValue)] = { + params match { + case JObject(pairs) => pairs + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + /** * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. * This can be useful for getting a Param value before an instance of `Params` - * is available. - * - * @param isDefaultParam Whether the given param name is a default param. Default is false. + * is available. This will look up `params` first, if not existing then looking up + * `defaultParams`. */ - def getParamValue(paramName: String, isDefaultParam: Boolean = false): JValue = { + def getParamValue(paramName: String): JValue = { implicit val format = DefaultFormats - val paramsToLookup = if (isDefaultParam) defaultParams else params - paramsToLookup match { - case JObject(pairs) => - val values = pairs.filter { case (pName, jsonValue) => - pName == paramName - }.map(_._2) - assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + - s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) - values.head - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: $metadataJson.") + + // Looking up for `params` first. + var pairs = getValueFromParams(params) + var foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + if (foundPairs.length == 0) { + // Looking up for `defaultParams` then. + pairs = getValueFromParams(defaultParams) + foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } } + assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + + foundPairs.map(_._2).head } /**