Skip to content

Commit

Permalink
Remove isDefaultParam.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 6, 2018
1 parent 69648d6 commit 166cdbb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ private[ml] object DecisionTreeModelReadWrite {

// Get impurity to construct ImpurityCalculator for each node
val impurityType: String = {
val impurityJson: JValue = metadata.getParamValue("impurity", isDefaultParam = true)
val impurityJson: JValue = metadata.getParamValue("impurity")
Param.jsonDecode[String](compact(render(impurityJson)))
}

Expand Down Expand Up @@ -428,7 +428,7 @@ private[ml] object EnsembleModelReadWrite {

// Get impurity to construct ImpurityCalculator for each node
val impurityType: String = {
val impurityJson: JValue = metadata.getParamValue("impurity", isDefaultParam = true)
val impurityJson: JValue = metadata.getParamValue("impurity")
Param.jsonDecode[String](compact(render(impurityJson)))
}

Expand Down
45 changes: 29 additions & 16 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -363,28 +363,41 @@ private[ml] object DefaultParamsReader {
metadata: JValue,
metadataJson: String) {


private def getValueFromParams(params: JValue): Seq[(String, JValue)] = {
params match {
case JObject(pairs) => pairs
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: $metadataJson.")
}
}

/**
* Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
* This can be useful for getting a Param value before an instance of `Params`
* is available.
*
* @param isDefaultParam Whether the given param name is a default param. Default is false.
* is available. This will look up `params` first, if not existing then looking up
* `defaultParams`.
*/
def getParamValue(paramName: String, isDefaultParam: Boolean = false): JValue = {
def getParamValue(paramName: String): JValue = {
implicit val format = DefaultFormats
val paramsToLookup = if (isDefaultParam) defaultParams else params
paramsToLookup match {
case JObject(pairs) =>
val values = pairs.filter { case (pName, jsonValue) =>
pName == paramName
}.map(_._2)
assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
values.head
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: $metadataJson.")

// Looking up for `params` first.
var pairs = getValueFromParams(params)
var foundPairs = pairs.filter { case (pName, jsonValue) =>
pName == paramName
}
if (foundPairs.length == 0) {
// Looking up for `defaultParams` then.
pairs = getValueFromParams(defaultParams)
foundPairs = pairs.filter { case (pName, jsonValue) =>
pName == paramName
}
}
assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" +
s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))

foundPairs.map(_._2).head
}

/**
Expand Down

0 comments on commit 166cdbb

Please sign in to comment.