Removed Vector UDTs
jkbradley committed Nov 2, 2014
1 parent 5817b2b commit e13cd8a
Showing 2 changed files with 1 addition and 112 deletions.
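For context, the code removed here let MLlib vectors be stored directly in SchemaRDD columns through Catalyst user-defined types (UDTs). A minimal sketch of the two registration mechanisms being deleted, using only APIs that appear in the diff below (this 2014-era UDTRegistry/SQLUserDefinedType API was experimental and was later reworked):

import scala.reflect.runtime.universe.typeOf
import org.apache.spark.sql.catalyst.UDTRegistry
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType

// Mechanism 1: annotate the class with its UDT, as DenseVector does below.
// @SQLUserDefinedType(udt = classOf[DenseVectorUDT])
// class DenseVector(val values: Array[Double]) extends Vector { ... }

// Mechanism 2: register a UDT against a type at runtime. The removed code
// does this for the Vector trait and for SparseVector, where the annotation
// alone is not sufficient.
UDTRegistry.registerType(typeOf[Vector], new VectorUDT())
UDTRegistry.registerType(typeOf[SparseVector], new SparseVectorUDT())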
111 changes: 0 additions & 111 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -27,11 +27,9 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}

import org.apache.spark.SparkException
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.sql.catalyst.UDTRegistry
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.Row

/**
 * Represents a numeric vector, whose index type is Int and value type is Double.
@@ -86,12 +84,6 @@ sealed trait Vector extends Serializable {
 */
object Vectors {

  // Note: Explicit registration is only needed for Vector and SparseVector;
  // the annotation works for DenseVector.
  UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Vector], new VectorUDT())
  UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[SparseVector],
    new SparseVectorUDT())

  /**
   * Creates a dense vector from its values.
   */
@@ -202,7 +194,6 @@ object Vectors {
/**
 * A dense vector represented by a value array.
 */
@SQLUserDefinedType(udt = classOf[DenseVectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {

  override def size: Int = values.length
@@ -254,105 +245,3 @@ class SparseVector(

  private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
}

/**
 * User-defined type for [[Vector]] which allows easy interaction with SQL
 * via [[org.apache.spark.sql.SchemaRDD]].
 */
private[spark] class VectorUDT extends UserDefinedType[Vector] {

  /**
   * vectorType: 0 = dense, 1 = sparse.
   * dense, sparse: One element holds the vector, and the other is null.
   */
  override def sqlType: StructType = StructType(Seq(
    StructField("vectorType", ByteType, nullable = false),
    StructField("dense", new DenseVectorUDT, nullable = true),
    StructField("sparse", new SparseVectorUDT, nullable = true)))

  override def serialize(obj: Any): Row = {
    val row = new GenericMutableRow(3)
    obj match {
      case v: DenseVector =>
        row.setByte(0, 0)
        row.update(1, new DenseVectorUDT().serialize(obj))
        row.setNullAt(2)
      case v: SparseVector =>
        row.setByte(0, 1)
        row.setNullAt(1)
        row.update(2, new SparseVectorUDT().serialize(obj))
    }
    row
  }

  override def deserialize(datum: Any): Vector = {
    datum match {
      case row: Row =>
        require(row.length == 3,
          s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3")
        val vectorType = row.getByte(0)
        vectorType match {
          case 0 =>
            new DenseVectorUDT().deserialize(row.getAs[Row](1))
          case 1 =>
            new SparseVectorUDT().deserialize(row.getAs[Row](2))
        }
    }
  }
}
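// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the diff): the VectorUDT above encodes a
// Vector as a tagged union. Field 0 is the tag byte (0 = dense, 1 = sparse)
// and exactly one of fields 1 ("dense") and 2 ("sparse") is non-null. A
// round trip, using only APIs visible in this file:
//
//   val udt = new VectorUDT()
//   val row = udt.serialize(Vectors.dense(1.0, 0.0, 3.0))
//   row.getByte(0)        // 0, so the "dense" slot holds the data
//   udt.deserialize(row)  // returns an equivalent DenseVector
// ---------------------------------------------------------------------------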

/**
 * User-defined type for [[DenseVector]] which allows easy interaction with SQL
 * via [[org.apache.spark.sql.SchemaRDD]].
 */
private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] {

  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  override def serialize(obj: Any): Seq[Double] = {
    obj match {
      case v: DenseVector =>
        v.values.toSeq
    }
  }

  override def deserialize(datum: Any): DenseVector = {
    datum match {
      case values: Seq[_] =>
        new DenseVector(values.asInstanceOf[Seq[Double]].toArray)
    }
  }
}
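// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the diff): DenseVectorUDT maps a vector to
// a plain ArrayType(DoubleType) column, so serialization is just values.toSeq
// and deserialization copies the Seq back into an Array:
//
//   val udt = new DenseVectorUDT()
//   udt.serialize(new DenseVector(Array(1.0, 2.0)))  // Seq(1.0, 2.0)
//   udt.deserialize(Seq(1.0, 2.0))                   // an equal DenseVector
// ---------------------------------------------------------------------------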

/**
 * User-defined type for [[SparseVector]] which allows easy interaction with SQL
 * via [[org.apache.spark.sql.SchemaRDD]].
 */
private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {

  override def sqlType: StructType = StructType(Seq(
    StructField("size", IntegerType, nullable = false),
    StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = false),
    StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false)))

  override def serialize(obj: Any): Row = obj match {
    case v: SparseVector =>
      val row: GenericMutableRow = new GenericMutableRow(3)
      row.setInt(0, v.size)
      row.update(1, v.indices.toSeq)
      row.update(2, v.values.toSeq)
      row
  }

  override def deserialize(datum: Any): SparseVector = {
    datum match {
      case row: Row =>
        require(row.length == 3,
          s"SparseVectorUDT.deserialize given row with length ${row.length} but expect 3.")
        val vSize = row.getInt(0)
        val indices = row.getAs[Seq[Int]](1).toArray
        val values = row.getAs[Seq[Double]](2).toArray
        new SparseVector(vSize, indices, values)
    }
  }
}
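Unlike VectorUDT's tagged union, the removed SparseVectorUDT stores a sparse vector as a flat three-field struct (size, indices, values). A round-trip sketch, again restricted to APIs that appear in the diff above:

val sparseUdt = new SparseVectorUDT()
// Length-5 vector with non-zeros 0.5 at index 1 and 2.0 at index 3.
val sv = new SparseVector(5, Array(1, 3), Array(0.5, 2.0))
val row = sparseUdt.serialize(sv)      // struct: (size, indices, values)
val back = sparseUdt.deserialize(row)  // an equivalent SparseVector
assert(back.size == 5)

The second changed file, below, only swaps a catalyst import now that UDTRegistry is no longer referenced.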
@@ -30,7 +30,7 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.{UDTRegistry, ScalaReflectionLock}
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row}
import org.apache.spark.sql.catalyst.types.decimal._
import org.apache.spark.sql.catalyst.util.Metadata
