From c81072dabdaaf9b9ce6fb08c764f302f639c273c Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 6 May 2015 10:05:11 -0700 Subject: [PATCH] addressed comments --- .../scala/org/apache/spark/ml/param/params.scala | 16 ++++++---------- .../ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../spark/ml/param/shared/sharedParams.scala | 2 +- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 5c02c82659f5b..6525a5a9aee52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -220,19 +220,15 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV } /** Specialized version of [[Param[Array[T]]]] for Java. */ -class ArrayParam[T : ClassTag]( - parent: Params, - name: String, - doc: String, - isValid: Array[T] => Boolean) - extends Param[Array[T]](parent, name, doc, isValid) { +class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean) + extends Param[Array[String]](parent, name, doc, isValid) { def this(parent: Params, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) - override def w(value: Array[T]): ParamPair[Array[T]] = super.w(value) + override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value) - private[param] def wCast(value: Seq[T]): ParamPair[Array[T]] = w(value.toArray) + private[param] def wCast(value: Seq[String]): ParamPair[Array[String]] = w(value.toArray) } /** @@ -328,8 +324,8 @@ trait Params extends Identifiable with Serializable { */ protected final def set[T](param: Param[T], value: T): this.type = { shouldOwn(param) - if (param.isInstanceOf[ArrayParam[_]] && value.isInstanceOf[Seq[_]]) { - paramMap.put(param.asInstanceOf[ArrayParam[Any]].wCast(value.asInstanceOf[Seq[Any]])) + if (param.isInstanceOf[StringArrayParam] && value.isInstanceOf[Seq[_]]) { + paramMap.put(param.asInstanceOf[StringArrayParam].wCast(value.asInstanceOf[Seq[String]])) } else { paramMap.put(param.w(value)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index ae0950653d8dc..aaa944a19782a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -83,7 +83,7 @@ private[shared] object SharedParamsCodeGen { case _ if c == classOf[Float] => "FloatParam" case _ if c == classOf[Double] => "DoubleParam" case _ if c == classOf[Boolean] => "BooleanParam" - case _ if c.isArray => s"ArrayParam[${getTypeString(c.getComponentType)}]" + case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam" case _ => s"Param[${getTypeString(c)}]" } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 736374114b82c..054a0123dc5b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params { * Param for input column names. * @group param */ - final val inputCols: ArrayParam[String] = new ArrayParam[String](this, "inputCols", "input column names") + final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names") /** @group getParam */ final def getInputCols: Array[String] = $(inputCols)