From a52b10162a4d1cca9cb86d726771b481844499c9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 10 Apr 2015 00:23:36 -0700 Subject: [PATCH] recognize more types --- .../org/apache/spark/ml/Identifiable.scala | 2 +- .../spark/ml/feature/VectorAssembler.scala | 21 +++++++++++++------ .../ml/feature/VectorAssemblerSuite.scala | 8 +++---- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index cd84b05bfb496..a50090671ae48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -29,5 +29,5 @@ private[ml] trait Identifiable extends Serializable { * random hex chars. */ private[ml] val uid: String = - this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) + this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 8823db7824919..d1b8f7e6e9295 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.param.{HasInputCols, HasOutputCol, ParamMap} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.CreateStruct +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -48,7 +48,15 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) } - val args = map(inputCols).map(c => UnresolvedAttribute(c)) + val schema = dataset.schema + val inputColNames = map(inputCols) + val args = inputColNames.map { c => + schema(c).dataType match { + case DoubleType => UnresolvedAttribute(c) + case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) + case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + } + } dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) } @@ -57,10 +65,11 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val inputColNames = map(inputCols) val outputColName = map(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) - for (dataType <- inputDataTypes) { - if (!(dataType == DoubleType || dataType.isInstanceOf[VectorUDT])) { - throw new IllegalArgumentException(s"Data type $dataType is not supported.") - } + inputDataTypes.foreach { + case _: NativeType => + case t if t.isInstanceOf[VectorUDT] => + case other => + throw new IllegalArgumentException(s"Data type $other is not supported.") } if (schema.fieldNames.contains(outputColName)) { throw new IllegalArgumentException(s"Output column $outputColName already exists.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index d2613d524c9f4..57d0278e03639 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -50,14 +50,14 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { test("VectorAssembler") { val df = sqlContext.createDataFrame(Seq( - (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0))) - )).toDF("id", "x", "y", "name", "z") + (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) + )).toDF("id", "x", "y", "name", "z", "n") val assembler = new VectorAssembler() - .setInputCols(Array("x", "y", "z")) + .setInputCols(Array("x", "y", "z", "n")) .setOutputCol("features") assembler.transform(df).select("features").collect().foreach { case Row(v: Vector) => - assert(v === Vectors.sparse(5, Array(1, 2, 4), Array(1.0, 2.0, 3.0))) + assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) } } }