diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 9aaafa34f8c03..5070032e49809 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -27,11 +27,9 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.catalyst.UDTRegistry
 import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.Row
 
 /**
  * Represents a numeric vector, whose index type is Int and value type is Double.
@@ -86,12 +84,6 @@ sealed trait Vector extends Serializable {
  */
 object Vectors {
 
-  // Note: Explicit registration is only needed for Vector and SparseVector;
-  // the annotation works for DenseVector.
-  UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Vector], new VectorUDT())
-  UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[SparseVector],
-    new SparseVectorUDT())
-
   /**
    * Creates a dense vector from its values.
    */
@@ -202,7 +194,6 @@ object Vectors {
 /**
  * A dense vector represented by a value array.
  */
-@SQLUserDefinedType(udt = classOf[DenseVectorUDT])
 class DenseVector(val values: Array[Double]) extends Vector {
 
   override def size: Int = values.length
@@ -254,105 +245,3 @@ class SparseVector(
 
   private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
 }
-
-/**
- * User-defined type for [[Vector]] which allows easy interaction with SQL
- * via [[org.apache.spark.sql.SchemaRDD]].
- */
-private[spark] class VectorUDT extends UserDefinedType[Vector] {
-
-  /**
-   * vectorType: 0 = dense, 1 = sparse.
-   * dense, sparse: One element holds the vector, and the other is null.
-   */
-  override def sqlType: StructType = StructType(Seq(
-    StructField("vectorType", ByteType, nullable = false),
-    StructField("dense", new DenseVectorUDT, nullable = true),
-    StructField("sparse", new SparseVectorUDT, nullable = true)))
-
-  override def serialize(obj: Any): Row = {
-    val row = new GenericMutableRow(3)
-    obj match {
-      case v: DenseVector =>
-        row.setByte(0, 0)
-        row.update(1, new DenseVectorUDT().serialize(obj))
-        row.setNullAt(2)
-      case v: SparseVector =>
-        row.setByte(0, 1)
-        row.setNullAt(1)
-        row.update(2, new SparseVectorUDT().serialize(obj))
-    }
-    row
-  }
-
-  override def deserialize(datum: Any): Vector = {
-    datum match {
-      case row: Row =>
-        require(row.length == 3,
-          s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3")
-        val vectorType = row.getByte(0)
-        vectorType match {
-          case 0 =>
-            new DenseVectorUDT().deserialize(row.getAs[Row](1))
-          case 1 =>
-            new SparseVectorUDT().deserialize(row.getAs[Row](2))
-        }
-    }
-  }
-}
-
-/**
- * User-defined type for [[DenseVector]] which allows easy interaction with SQL
- * via [[org.apache.spark.sql.SchemaRDD]].
- */
-private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] {
-
-  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
-
-  override def serialize(obj: Any): Seq[Double] = {
-    obj match {
-      case v: DenseVector =>
-        v.values.toSeq
-    }
-  }
-
-  override def deserialize(datum: Any): DenseVector = {
-    datum match {
-      case values: Seq[_] =>
-        new DenseVector(values.asInstanceOf[Seq[Double]].toArray)
-    }
-  }
-}
-
-/**
- * User-defined type for [[SparseVector]] which allows easy interaction with SQL
- * via [[org.apache.spark.sql.SchemaRDD]].
- */
-private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {
-
-  override def sqlType: StructType = StructType(Seq(
-    StructField("size", IntegerType, nullable = false),
-    StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = false),
-    StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false)))
-
-  override def serialize(obj: Any): Row = obj match {
-    case v: SparseVector =>
-      val row: GenericMutableRow = new GenericMutableRow(3)
-      row.setInt(0, v.size)
-      row.update(1, v.indices.toSeq)
-      row.update(2, v.values.toSeq)
-      row
-  }
-
-  override def deserialize(datum: Any): SparseVector = {
-    datum match {
-      case row: Row =>
-        require(row.length == 3,
-          s"SparseVectorUDT.deserialize given row with length ${row.length} but expect 3.")
-        val vSize = row.getInt(0)
-        val indices = row.getAs[Seq[Int]](1).toArray
-        val values = row.getAs[Seq[Double]](2).toArray
-        new SparseVector(vSize, indices, values)
-    }
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 4f0ac6ebbb604..df35577b659c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -30,7 +30,7 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.{UDTRegistry, ScalaReflectionLock}
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row}
 import org.apache.spark.sql.catalyst.types.decimal._
 import org.apache.spark.sql.catalyst.util.Metadata
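
For context on the API this diff touches: with the explicit `UDTRegistry.registerType` calls removed, a user class is tied to its UDT solely through the `@SQLUserDefinedType` annotation, and the UDT itself declares a Catalyst `sqlType` plus `serialize`/`deserialize` conversions. Below is a minimal sketch of that pattern, modeled directly on the removed `DenseVectorUDT`; the `ExamplePoint`/`ExamplePointUDT` names are hypothetical and only illustrate the annotation-based wiring visible in this revision.

```scala
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.types._

// Hypothetical user class: the annotation alone links it to its UDT,
// so no UDTRegistry.registerType call (as removed above) is needed.
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
class ExamplePoint(val x: Double, val y: Double)

// Minimal UDT following the same shape as the removed DenseVectorUDT:
// declare the Catalyst storage type and convert to/from it.
private[spark] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

  // Store the point as an array of doubles in Catalyst rows.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  override def serialize(obj: Any): Seq[Double] = obj match {
    case p: ExamplePoint => Seq(p.x, p.y)
  }

  override def deserialize(datum: Any): ExamplePoint = datum match {
    case values: Seq[_] =>
      val Seq(x, y) = values.asInstanceOf[Seq[Double]]
      new ExamplePoint(x, y)
  }
}
```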