Skip to content

Commit

Permalink
address comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
WeichenXu123 committed Dec 21, 2019
1 parent 22865e0 commit 05a525d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/functions.scala
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,9 +32,10 @@ object functions {
vec match {
case v: Vector => v.toArray
case v: OldVector => v.toArray
case _ => throw new IllegalArgumentException(
case v => throw new IllegalArgumentException(
"function vector_to_array requires a non-null input argument and input type must be " +
"`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`.")
"`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " +
s"but got ${ if (v == null) "null" else v.getClass.getName }.")
}
}.asNonNullable()

Expand Down
23 changes: 22 additions & 1 deletion mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,10 +17,12 @@

package org.apache.spark.ml

import org.apache.spark.SparkException
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.util.MLTest
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.sql.functions.col

class FunctionsSuite extends MLTest {

Expand All @@ -32,7 +34,7 @@ class FunctionsSuite extends MLTest {
(Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))), OldVectors.sparse(3, Seq((0, 20.0), (2, 30.0))))
).toDF("vec", "oldVec")

val result = df.select(vector_to_array('vec), vector_to_array('oldVec))
val result = df.select(vector_to_array('vec), vector_to_array('oldVec))
.as[(Seq[Double], Seq[Double])]
.collect().toSeq

Expand All @@ -41,5 +43,24 @@ class FunctionsSuite extends MLTest {
(Seq(2.0, 0.0, 3.0), Seq(20.0, 0.0, 30.0))
)
assert(result === expected)

val df2 = Seq(
(Vectors.dense(1.0, 2.0, 3.0),
OldVectors.dense(10.0, 20.0, 30.0), 1),
(null, null, 0)
).toDF("vec", "oldVec", "label")


for ((colName, valType) <- Seq(
("vec", "null"), ("oldVec", "null"), ("label", "java.lang.Integer"))) {
val thrown1 = intercept[SparkException] {
df2.select(vector_to_array(col(colName))).count
}
assert(thrown1.getCause.getMessage.contains(
"function vector_to_array requires a non-null input argument and input type must be " +
"`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " +
s"but got ${valType}"))

}
}
}

0 comments on commit 05a525d

Please sign in to comment.