Commit a52b101
recognize more types
mengxr committed Apr 10, 2015
1 parent 35daac2 commit a52b101
Showing 3 changed files with 20 additions and 11 deletions.
@@ -29,5 +29,5 @@ private[ml] trait Identifiable extends Serializable {
    * random hex chars.
    */
   private[ml] val uid: String =
-    this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
+    this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8)
 }
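As an aside, a minimal standalone sketch of what the generated identifier looks like after this change; the class name here is made up for illustration and is not part of the patch:

import java.util.UUID

object UidFormatExample {
  class Example

  def main(args: Array[String]): Unit = {
    // Mirrors the construction above: simple class name, an underscore, then the
    // first 8 hex chars of a random UUID.
    val uid = classOf[Example].getSimpleName + "_" + UUID.randomUUID().toString.take(8)
    println(uid) // e.g. "Example_3f0a9c12"
  }
}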
@@ -26,7 +26,7 @@ import org.apache.spark.ml.param.{HasInputCols, HasOutputCol, ParamMap}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
 import org.apache.spark.sql.{Column, DataFrame, Row}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.CreateStruct
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -48,7 +48,15 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
     val assembleFunc = udf { r: Row =>
       VectorAssembler.assemble(r.toSeq: _*)
     }
-    val args = map(inputCols).map(c => UnresolvedAttribute(c))
+    val schema = dataset.schema
+    val inputColNames = map(inputCols)
+    val args = inputColNames.map { c =>
+      schema(c).dataType match {
+        case DoubleType => UnresolvedAttribute(c)
+        case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c)
+        case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
+      }
+    }
     dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol)))
   }
 
@@ -57,10 +65,11 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
     val inputColNames = map(inputCols)
     val outputColName = map(outputCol)
     val inputDataTypes = inputColNames.map(name => schema(name).dataType)
-    for (dataType <- inputDataTypes) {
-      if (!(dataType == DoubleType || dataType.isInstanceOf[VectorUDT])) {
-        throw new IllegalArgumentException(s"Data type $dataType is not supported.")
-      }
+    inputDataTypes.foreach {
+      case _: NativeType =>
+      case t if t.isInstanceOf[VectorUDT] =>
+      case other =>
+        throw new IllegalArgumentException(s"Data type $other is not supported.")
     }
     if (schema.fieldNames.contains(outputColName)) {
       throw new IllegalArgumentException(s"Output column $outputColName already exists.")
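For intuition, the new NativeType branch above casts any other native-typed input column (for example an integer or long column) to DoubleType under a temporary alias before assembly. A rough sketch of the same idea expressed with the public DataFrame API, with a made-up helper name and alias suffix; the patch itself does this with Catalyst's Alias and Cast expressions rather than this helper:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// Sketch only: add a double-valued copy of a native-typed column (e.g. Int or Long)
// under a temporary name, mirroring the Alias(Cast(...)) expression in the patch.
def castToDouble(df: DataFrame, inputCol: String, uid: String): DataFrame =
  df.withColumn(s"${inputCol}_double_$uid", col(inputCol).cast(DoubleType))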
@@ -50,14 +50,14 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
 
   test("VectorAssembler") {
     val df = sqlContext.createDataFrame(Seq(
-      (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)))
-    )).toDF("id", "x", "y", "name", "z")
+      (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
+    )).toDF("id", "x", "y", "name", "z", "n")
     val assembler = new VectorAssembler()
-      .setInputCols(Array("x", "y", "z"))
+      .setInputCols(Array("x", "y", "z", "n"))
       .setOutputCol("features")
     assembler.transform(df).select("features").collect().foreach {
       case Row(v: Vector) =>
-        assert(v === Vectors.sparse(5, Array(1, 2, 4), Array(1.0, 2.0, 3.0)))
+        assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
     }
   }
 }
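To spell out the expected value in the updated assertion: with the new long column included, the assembled vector has six slots. The snippet below, as one would type it in the spark-shell, is an illustration of that layout and not part of the patch:

import org.apache.spark.mllib.linalg.Vectors

// Index layout of the assembled "features" vector in the test above:
//   0    -> x = 0.0 (zero, hence omitted from the sparse encoding)
//   1, 2 -> y = [1.0, 2.0]
//   3, 4 -> z, a size-2 sparse vector whose only nonzero, 3.0, lands at global index 4
//   5    -> n = 10L, cast to 10.0 by the new native-type branch
val expected = Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))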
