From 3b83ec090bfa7fe6595cce41c6c3dd20ac524842 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Wed, 3 Dec 2014 10:09:52 +0800
Subject: [PATCH] replace TypeTag with explicit datatype

---
 .../org/apache/spark/ml/Transformer.scala      | 18 ++++++++++--------
 .../apache/spark/ml/feature/HashingTF.scala    |  5 ++++-
 .../apache/spark/ml/feature/Tokenizer.scala    |  4 +++-
 3 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 490e6609ad311..23fbd228d01cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -18,16 +18,14 @@ package org.apache.spark.ml
 
 import scala.annotation.varargs
-import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.api.java.JavaSchemaRDD
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.expressions.ScalaUdf
 import org.apache.spark.sql.catalyst.types._
 
 /**
@@ -86,7 +84,7 @@ abstract class Transformer extends PipelineStage with Params {
  * Abstract class for transformers that take one input column, apply transformation, and output the
  * result as a new column.
  */
-private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   extends Transformer with HasInputCol with HasOutputCol with Logging {
 
   def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
@@ -99,6 +97,11 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
    */
   protected def createTransformFunc(paramMap: ParamMap): IN => OUT
 
+  /**
+   * Returns the data type of the output column.
+   */
+  protected def outputDataType: DataType
+
   /**
    * Validates the input type. Throw an exception if it is invalid.
    */
@@ -111,9 +114,8 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
     if (schema.fieldNames.contains(map(outputCol))) {
       throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
     }
-    val output = ScalaReflection.schemaFor[OUT]
     val outputFields = schema.fields :+
-      StructField(map(outputCol), output.dataType, output.nullable)
+      StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive)
     StructType(outputFields)
   }
 
@@ -121,7 +123,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
     transformSchema(dataset.schema, paramMap, logging = true)
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
-    val udf = this.createTransformFunc(map)
-    dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+    val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
+    dataset.select(Star(None), udf as map(outputCol))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index b98b1755a3584..e0bfb1e484a2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -21,7 +21,8 @@ import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.sql.catalyst.types.DataType
 
 /**
  * :: AlphaComponent ::
@@ -39,4 +40,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
     val hashingTF = new feature.HashingTF(paramMap(numFeatures))
     hashingTF.transform
   }
+
+  override protected def outputDataType: DataType = new VectorUDT()
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 0a6599b64c011..9352f40f372d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.{DataType, StringType}
+import org.apache.spark.sql.{DataType, StringType, ArrayType}
 
 /**
  * :: AlphaComponent ::
@@ -36,4 +36,6 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
   protected override def validateInputType(inputType: DataType): Unit = {
     require(inputType == StringType, s"Input type must be string type but got $inputType.")
   }
+
+  override protected def outputDataType: DataType = new ArrayType(StringType, false)
 }
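
Note for reviewers: after this change, a concrete transformer declares its Catalyst output type explicitly instead of relying on a TypeTag for OUT. Below is a minimal sketch of a hypothetical subclass; the UpperCaser name and its upper-casing logic are invented for illustration, while the createTransformFunc and outputDataType signatures are the ones defined in the patch above.

    package org.apache.spark.ml.feature

    import org.apache.spark.ml.UnaryTransformer
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.sql.catalyst.types.{DataType, StringType}

    // Hypothetical example: upper-cases a string input column. It is placed in
    // org.apache.spark.ml.feature because UnaryTransformer is private[ml].
    class UpperCaser extends UnaryTransformer[String, String, UpperCaser] {

      // Element-level transform function; no runtime reflection is involved.
      override protected def createTransformFunc(paramMap: ParamMap): String => String =
        _.toUpperCase

      // Explicit output type, replacing the removed ScalaReflection.schemaFor[OUT].
      override protected def outputDataType: DataType = StringType
    }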