From 1e87a4525ebaa17cef7974651c113110fc483d9d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 17 May 2015 11:45:12 -0700 Subject: [PATCH 01/13] WIP refactoring of CatalystTypeConverters --- .../sql/catalyst/CatalystTypeConverters.scala | 496 +++++++++--------- 1 file changed, 239 insertions(+), 257 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 75a493b248f6e..f30a414c977c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} +import java.math.{BigDecimal => JavaBigDecimal} +import java.sql.Date import java.util.{Map => JavaMap} import scala.collection.mutable.HashMap @@ -34,183 +36,263 @@ object CatalystTypeConverters { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - /** - * Converts Scala objects to catalyst rows / types. This method is slow, and for batch - * conversion you should be using converter produced by createToCatalystConverter. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. - */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => - udt.serialize(obj) - - case (o: Option[_], _) => - o.map(convertToCatalyst(_, dataType)).orNull + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any] = { + val converter = dataType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => UDTConverter(udt) + case arrayType: ArrayType => ArrayConverter(arrayType.elementType) + case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) + case structType: StructType => StructConverter(structType) + case StringType => StringConverter + case DateType => DateConverter + case dt: DecimalType => BigDecimalConverter + case BooleanType => BooleanConverter + case ByteType => ByteConverter + case ShortType => ShortConverter + case IntegerType => IntConverter + case LongType => LongConverter + case FloatType => FloatConverter + case DoubleType => DoubleConverter + case _ => IdentityConverter + } + converter.asInstanceOf[CatalystTypeConverter[Any, Any]] + } - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToCatalyst(_, arrayType.elementType)) + private abstract class CatalystTypeConverter[ScalaType, CatalystType] extends Serializable { - case (jit: JavaIterable[_], arrayType: ArrayType) => { - val iter = jit.iterator - var listOfItems: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - listOfItems :+= convertToCatalyst(item, arrayType.elementType) + final def toCatalyst(maybeScalaValue: Any): Any = { + maybeScalaValue match { + case None => null + case null => null + case Some(scalaValue: ScalaType) => toCatalystImpl(scalaValue) + case scalaValue: ScalaType => toCatalystImpl(scalaValue) } - listOfItems } - case (s: Array[_], arrayType: ArrayType) => - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) - - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> 
convertToCatalyst(v, mapType.valueType) - } + final def toScala(row: Row, column: Int): Any = { + if (row.isNullAt(column)) null else toScalaImpl(row, column) + } - case (jmap: JavaMap[_, _], mapType: MapType) => - val iter = jmap.entrySet.iterator - var listOfEntries: List[(Any, Any)] = List() - while (iter.hasNext) { - val entry = iter.next() - listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), - convertToCatalyst(entry.getValue, mapType.valueType)) - } - listOfEntries.toMap - - case (p: Product, structType: StructType) => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) - idx += 1 - } - new GenericRowWithSchema(ar, structType) + def toScala(catalystValue: CatalystType): ScalaType + protected def toCatalystImpl(scalaValue: ScalaType): CatalystType + protected def toScalaImpl(row: Row, column: Int): ScalaType + } - case (d: String, _) => - UTF8String(d) + private abstract class PrimitiveCatalystTypeConverter[T] extends CatalystTypeConverter[T, T] { + override final def toScala(catalystValue: T): T = catalystValue + override final def toCatalystImpl(scalaValue: T): T = scalaValue + } - case (d: BigDecimal, _) => - Decimal(d) + private object IdentityConverter extends CatalystTypeConverter[Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toScalaImpl(row: Row, column: Int): Any = row(column) + } - case (d: java.math.BigDecimal, _) => - Decimal(d) + private case class UDTConverter(udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) + override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column)) + } - case (d: java.sql.Date, _) => - DateUtils.fromJavaDate(d) + // Converter for array, seq, iterables. + private case class ArrayConverter( + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any]] { + + private[this] val elementConverter = getConverterForType(elementType) + + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + scalaValue match { + case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) + case s: Seq[_] => s.map(elementConverter.toCatalyst) + case i: JavaIterable[_] => + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter.toCatalyst(item) + } + convertedIterable + } + } - case (r: Row, structType: StructType) => - val converters = structType.fields.map { - f => (item: Any) => convertToCatalyst(item, f.dataType) + override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala) } - convertRowWithConverters(r, structType, converters) + } - case (other, _) => - other + override def toScalaImpl(row: Row, column: Int): Seq[Any] = + toScala(row(column).asInstanceOf[Seq[Any]]) } - /** - * Creates a converter function that will convert Scala objects to the specified catalyst type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. 
- */ - private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { - def extractOption(item: Any): Any = item match { - case opt: Option[_] => opt.orNull - case other => other - } + private case class MapConverter( + keyType: DataType, + valueType: DataType + ) extends CatalystTypeConverter[Any, Map[Any, Any]] { - dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item) => extractOption(item) match { - case null => null - case other => udt.serialize(other) + private[this] val keyConverter = getConverterForType(keyType) + private[this] val valueConverter = getConverterForType(valueType) + + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v) } - case arrayType: ArrayType => - val elementConverter = createToCatalystConverter(arrayType.elementType) - (item: Any) => { - extractOption(item) match { - case a: Array[_] => a.toSeq.map(elementConverter) - case s: Seq[_] => s.map(elementConverter) - case i: JavaIterable[_] => { - val iter = i.iterator - var convertedIterable: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - convertedIterable :+= elementConverter(item) - } - convertedIterable - } - case null => null - } + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + val key = keyConverter.toCatalyst(entry.getKey) + convertedMap(key) = valueConverter.toCatalyst(entry.getValue) } + convertedMap + } - case mapType: MapType => - val keyConverter = createToCatalystConverter(mapType.keyType) - val valueConverter = createToCatalystConverter(mapType.valueType) - (item: Any) => { - extractOption(item) match { - case m: Map[_, _] => - m.map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - - case jmap: JavaMap[_, _] => - val iter = jmap.entrySet.iterator - val convertedMap: HashMap[Any, Any] = HashMap() - while (iter.hasNext) { - val entry = iter.next() - convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) - } - convertedMap - - case null => null - } + override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.map { case (k, v) => + keyConverter.toScala(k) -> valueConverter.toScala(v) } + } + } - case structType: StructType => - val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) - (item: Any) => { - extractOption(item) match { - case r: Row => - convertRowWithConverters(r, structType, converters) - - case p: Product => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = converters(idx)(iter.next()) - idx += 1 - } - new GenericRowWithSchema(ar, structType) - - case null => - null - } + override def toScalaImpl(row: Row, column: Int): Map[Any, Any] = + toScala(row(column).asInstanceOf[Map[Any, Any]]) + } + + private case class StructConverter( + structType: StructType) extends CatalystTypeConverter[Any, Row] { + + private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) } + + override def toCatalystImpl(scalaValue: Any): Row = scalaValue match { + case row: Row => + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toCatalyst(row(idx)) + idx += 1 } + 
new GenericRowWithSchema(ar, structType) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx).toCatalyst(iter.next()) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + } - case dateType: DateType => (item: Any) => extractOption(item) match { - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case other => other + override def toScala(row: Row): Row = { + if (row == null) { + null + } else { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toScala(row, idx) + idx += 1 + } + new GenericRowWithSchema(ar, structType) } + } - case dataType: StringType => (item: Any) => extractOption(item) match { - case s: String => UTF8String(s) - case other => other - } + override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row]) + } - case _ => - (item: Any) => extractOption(item) match { - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) - case other => other - } + private object StringConverter extends CatalystTypeConverter[Any, Any] { + override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { + case str: String => UTF8String(str) + case utf8: UTF8String => utf8 + } + override def toScala(catalystValue: Any): String = catalystValue match { + case null => null + case str: String => str + case utf8: UTF8String => utf8.toString() } + override def toScalaImpl(row: Row, column: Int): String = row(column).toString + } + + private object DateConverter extends CatalystTypeConverter[Date, Any] { + override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue) + override def toScala(catalystValue: Any): Date = + if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int]) + override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column)) + } + + private object BigDecimalConverter extends CatalystTypeConverter[Any, Decimal] { + override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { + case d: BigDecimal => Decimal(d) + case d: JavaBigDecimal => Decimal(d) + case d: Decimal => d + } + override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal + override def toScalaImpl(row: Row, column: Int): JavaBigDecimal = row.get(column) match { + case d: JavaBigDecimal => d + case d: Decimal => d.toJavaBigDecimal + } + } + + private object BooleanConverter extends PrimitiveCatalystTypeConverter[Boolean] { + override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column) + } + + private object ByteConverter extends PrimitiveCatalystTypeConverter[Byte] { + override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column) + } + + private object ShortConverter extends PrimitiveCatalystTypeConverter[Short] { + override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column) + } + + private object IntConverter extends PrimitiveCatalystTypeConverter[Int] { + override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column) + } + + private object LongConverter extends PrimitiveCatalystTypeConverter[Long] { + override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column) + } + + private object FloatConverter extends PrimitiveCatalystTypeConverter[Float] { + override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column) + } + + private object DoubleConverter extends 
PrimitiveCatalystTypeConverter[Double] { + override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column) + } + + /** + * Converts Scala objects to catalyst rows / types. This method is slow, and for batch + * conversion you should be using converter produced by createToCatalystConverter. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = { + // Check UDT first since UDTs can override other types + dataType match { + case udt: UserDefinedType[_] => udt.serialize(scalaValue) + case option: Option[_] => option.map(convertToCatalyst(_, dataType)).orNull + case _ => getConverterForType(dataType).toCatalyst(scalaValue) + } + } + + /** + * Creates a converter function that will convert Scala objects to the specified catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + getConverterForType(dataType).toCatalyst } /** @@ -221,10 +303,10 @@ object CatalystTypeConverters { * This is used to create an RDD or test results with correct types for Catalyst. */ def convertToCatalyst(a: Any): Any = a match { - case s: String => UTF8String(s) - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) + case s: String => StringConverter.toCatalyst(s) + case d: Date => DateConverter.toCatalyst(d) + case d: BigDecimal => BigDecimalConverter.toCatalyst(d) + case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) case seq: Seq[Any] => seq.map(convertToCatalyst) case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray @@ -238,33 +320,8 @@ object CatalystTypeConverters { * This method is slow, and for batch conversion you should be using converter * produced by createToScalaConverter. */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => - udt.deserialize(d) - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToScala(_, arrayType.elementType)) - - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } - - case (r: Row, s: StructType) => - convertRowToScala(r, s) - - case (d: Decimal, _: DecimalType) => - d.toJavaBigDecimal - - case (i: Int, DateType) => - DateUtils.toJavaDate(i) - - case (s: UTF8String, StringType) => - s.toString() - - case (other, _) => - other + def convertToScala(catalystValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toScala(catalystValue) } /** @@ -272,82 +329,7 @@ object CatalystTypeConverters { * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. 
*/ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item: Any) => if (item == null) null else udt.deserialize(item) - - case arrayType: ArrayType => - val elementConverter = createToScalaConverter(arrayType.elementType) - (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) - - case mapType: MapType => - val keyConverter = createToScalaConverter(mapType.keyType) - val valueConverter = createToScalaConverter(mapType.valueType) - (item: Any) => if (item == null) { - null - } else { - item.asInstanceOf[Map[_, _]].map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - } - - case s: StructType => - val converters = s.fields.map(f => createToScalaConverter(f.dataType)) - (item: Any) => { - if (item == null) { - null - } else { - convertRowWithConverters(item.asInstanceOf[Row], s, converters) - } - } - - case _: DecimalType => - (item: Any) => item match { - case d: Decimal => d.toJavaBigDecimal - case other => other - } - - case DateType => - (item: Any) => item match { - case i: Int => DateUtils.toJavaDate(i) - case other => other - } - - case StringType => - (item: Any) => item match { - case s: UTF8String => s.toString() - case other => other - } - - case other => - (item: Any) => item - } - - def convertRowToScala(r: Row, schema: StructType): Row = { - val ar = new Array[Any](r.size) - var idx = 0 - while (idx < r.size) { - ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) - idx += 1 - } - new GenericRowWithSchema(ar, schema) - } - - /** - * Converts a row by applying the provided set of converter functions. It is used for both - * toScala and toCatalyst conversions. 
- */ - private[sql] def convertRowWithConverters( - row: Row, - schema: StructType, - converters: Array[Any => Any]): Row = { - val ar = new Array[Any](row.size) - var idx = 0 - while (idx < row.size) { - ar(idx) = converters(idx)(row(idx)) - idx += 1 - } - new GenericRowWithSchema(ar, schema) + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + getConverterForType(dataType).toScala } } From 7ca7fcb633b94d2af8ef7a09f3eaf78a9e8aa2f2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 17 May 2015 16:14:10 -0700 Subject: [PATCH 02/13] Comments and cleanup --- .../sql/catalyst/CatalystTypeConverters.scala | 55 ++++++++++++++----- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index f30a414c977c3..e7a2f5e9e0cda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -21,6 +21,7 @@ import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} import java.sql.Date import java.util.{Map => JavaMap} +import javax.annotation.Nullable import scala.collection.mutable.HashMap @@ -38,7 +39,6 @@ object CatalystTypeConverters { private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any] = { val converter = dataType match { - // Check UDT first since UDTs can override other types case udt: UserDefinedType[_] => UDTConverter(udt) case arrayType: ArrayType => ArrayConverter(arrayType.elementType) case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) @@ -58,26 +58,58 @@ object CatalystTypeConverters { converter.asInstanceOf[CatalystTypeConverter[Any, Any]] } + /** + * Converts a Scala type to its Catalyst equivalent (and vice versa). + */ private abstract class CatalystTypeConverter[ScalaType, CatalystType] extends Serializable { - final def toCatalyst(maybeScalaValue: Any): Any = { + /** + * Converts a Scala type to its Catalyst equivalent while automatically handling nulls + * and Options. + */ + final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = { maybeScalaValue match { - case None => null - case null => null - case Some(scalaValue: ScalaType) => toCatalystImpl(scalaValue) + case opt: Option[ScalaType] => + if (opt.isDefined) { + toCatalystImpl(opt.get) + } else { + null.asInstanceOf[CatalystType] + } + case null => null.asInstanceOf[CatalystType] case scalaValue: ScalaType => toCatalystImpl(scalaValue) } } + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + */ final def toScala(row: Row, column: Int): Any = { if (row.isNullAt(column)) null else toScalaImpl(row, column) } - def toScala(catalystValue: CatalystType): ScalaType + /** + * Convert a Catalyst value to its Scala equivalent. + */ + def toScala(@Nullable catalystValue: CatalystType): ScalaType + + /** + * Converts a Scala value to its Catalyst equivalent. + * @param scalaValue the Scala value, guaranteed not to be null. + * @return the Catalyst value. + */ protected def toCatalystImpl(scalaValue: ScalaType): CatalystType + + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + * This method will only be called on non-null columns. 
+ */ protected def toScalaImpl(row: Row, column: Int): ScalaType } + /** + * Convenience wrapper to write type converters for primitives. We use a converter for primitives + * so that we can use type-specific field accessors when converting Catalyst rows to Scala rows. + */ private abstract class PrimitiveCatalystTypeConverter[T] extends CatalystTypeConverter[T, T] { override final def toScala(catalystValue: T): T = catalystValue override final def toCatalystImpl(scalaValue: T): T = scalaValue @@ -278,16 +310,11 @@ object CatalystTypeConverters { * This ordering is important for UDT registration. */ def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = { - // Check UDT first since UDTs can override other types - dataType match { - case udt: UserDefinedType[_] => udt.serialize(scalaValue) - case option: Option[_] => option.map(convertToCatalyst(_, dataType)).orNull - case _ => getConverterForType(dataType).toCatalyst(scalaValue) - } + getConverterForType(dataType).toCatalyst(scalaValue) } /** - * Creates a converter function that will convert Scala objects to the specified catalyst type. + * Creates a converter function that will convert Scala objects to the specified Catalyst type. * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. */ @@ -296,7 +323,7 @@ object CatalystTypeConverters { } /** - * Converts Scala objects to catalyst rows / types. + * Converts Scala objects to Catalyst rows / types. * * Note: This should be called before do evaluation on Row * (It does not support UDT) From ae3278d59e86f3bb41718ff5c7599c378363fc04 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 24 May 2015 17:57:26 -0700 Subject: [PATCH 03/13] Throw ClassCastException errors during inbound conversions. --- .../sql/catalyst/CatalystTypeConverters.scala | 83 +++++++++++-------- 1 file changed, 47 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index e7a2f5e9e0cda..586d8167b9272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -37,7 +37,7 @@ object CatalystTypeConverters { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any] = { + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { val converter = dataType match { case udt: UserDefinedType[_] => UDTConverter(udt) case arrayType: ArrayType => ArrayConverter(arrayType.elementType) @@ -55,13 +55,18 @@ object CatalystTypeConverters { case DoubleType => DoubleConverter case _ => IdentityConverter } - converter.asInstanceOf[CatalystTypeConverter[Any, Any]] + converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] } /** * Converts a Scala type to its Catalyst equivalent (and vice versa). + * + * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst. + * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala. + * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type. 
*/ - private abstract class CatalystTypeConverter[ScalaType, CatalystType] extends Serializable { + private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType] + extends Serializable { /** * Converts a Scala type to its Catalyst equivalent while automatically handling nulls @@ -69,67 +74,59 @@ object CatalystTypeConverters { */ final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = { maybeScalaValue match { - case opt: Option[ScalaType] => + case opt: Option[ScalaInputType] => if (opt.isDefined) { toCatalystImpl(opt.get) } else { null.asInstanceOf[CatalystType] } case null => null.asInstanceOf[CatalystType] - case scalaValue: ScalaType => toCatalystImpl(scalaValue) + case scalaValue: ScalaInputType => toCatalystImpl(scalaValue) } } /** * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. */ - final def toScala(row: Row, column: Int): Any = { - if (row.isNullAt(column)) null else toScalaImpl(row, column) + final def toScala(row: Row, column: Int): ScalaOutputType = { + if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column) } /** * Convert a Catalyst value to its Scala equivalent. */ - def toScala(@Nullable catalystValue: CatalystType): ScalaType + def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType /** * Converts a Scala value to its Catalyst equivalent. * @param scalaValue the Scala value, guaranteed not to be null. * @return the Catalyst value. */ - protected def toCatalystImpl(scalaValue: ScalaType): CatalystType + protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType /** * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. * This method will only be called on non-null columns. */ - protected def toScalaImpl(row: Row, column: Int): ScalaType - } - - /** - * Convenience wrapper to write type converters for primitives. We use a converter for primitives - * so that we can use type-specific field accessors when converting Catalyst rows to Scala rows. - */ - private abstract class PrimitiveCatalystTypeConverter[T] extends CatalystTypeConverter[T, T] { - override final def toScala(catalystValue: T): T = catalystValue - override final def toCatalystImpl(scalaValue: T): T = scalaValue + protected def toScalaImpl(row: Row, column: Int): ScalaOutputType } - private object IdentityConverter extends CatalystTypeConverter[Any, Any] { + private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = scalaValue override def toScala(catalystValue: Any): Any = catalystValue override def toScalaImpl(row: Row, column: Int): Any = row(column) } - private case class UDTConverter(udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any] { + private case class UDTConverter( + udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column)) } - // Converter for array, seq, iterables. + /** Converter for arrays, sequences, and Java iterables. 
*/ private case class ArrayConverter( - elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any]] { + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { private[this] val elementConverter = getConverterForType(elementType) @@ -162,8 +159,8 @@ object CatalystTypeConverters { private case class MapConverter( keyType: DataType, - valueType: DataType - ) extends CatalystTypeConverter[Any, Map[Any, Any]] { + valueType: DataType) + extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] { private[this] val keyConverter = getConverterForType(keyType) private[this] val valueConverter = getConverterForType(valueType) @@ -200,7 +197,7 @@ object CatalystTypeConverters { } private case class StructConverter( - structType: StructType) extends CatalystTypeConverter[Any, Row] { + structType: StructType) extends CatalystTypeConverter[Any, Row, Row] { private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) } @@ -242,7 +239,7 @@ object CatalystTypeConverters { override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row]) } - private object StringConverter extends CatalystTypeConverter[Any, Any] { + private object StringConverter extends CatalystTypeConverter[Any, String, Any] { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String(str) case utf8: UTF8String => utf8 @@ -255,14 +252,14 @@ object CatalystTypeConverters { override def toScalaImpl(row: Row, column: Int): String = row(column).toString } - private object DateConverter extends CatalystTypeConverter[Date, Any] { + private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue) override def toScala(catalystValue: Any): Date = if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int]) override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column)) } - private object BigDecimalConverter extends CatalystTypeConverter[Any, Decimal] { + private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) @@ -275,32 +272,46 @@ object CatalystTypeConverters { } } - private object BooleanConverter extends PrimitiveCatalystTypeConverter[Boolean] { + private object BooleanConverter extends CatalystTypeConverter[Boolean, Boolean, Boolean] { override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column) + override def toScala(catalystValue: Boolean): Boolean = catalystValue + override protected def toCatalystImpl(scalaValue: Boolean): Boolean = scalaValue } - private object ByteConverter extends PrimitiveCatalystTypeConverter[Byte] { + private object ByteConverter extends CatalystTypeConverter[Byte, Byte, Byte] { override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column) + override def toScala(catalystValue: Byte): Byte = catalystValue + override protected def toCatalystImpl(scalaValue: Byte): Byte = scalaValue } - private object ShortConverter extends PrimitiveCatalystTypeConverter[Short] { + private object ShortConverter extends CatalystTypeConverter[Short, Short, Short] { override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column) + override def toScala(catalystValue: Short): Short = catalystValue + override protected 
def toCatalystImpl(scalaValue: Short): Short = scalaValue } - private object IntConverter extends PrimitiveCatalystTypeConverter[Int] { + private object IntConverter extends CatalystTypeConverter[Int, Int, Int] { override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column) + override def toScala(catalystValue: Int): Int = catalystValue + override protected def toCatalystImpl(scalaValue: Int): Int = scalaValue } - private object LongConverter extends PrimitiveCatalystTypeConverter[Long] { + private object LongConverter extends CatalystTypeConverter[Long, Long, Long] { override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column) + override def toScala(catalystValue: Long): Long = catalystValue + override protected def toCatalystImpl(scalaValue: Long): Long = scalaValue } - private object FloatConverter extends PrimitiveCatalystTypeConverter[Float] { + private object FloatConverter extends CatalystTypeConverter[Float, Float, Float] { override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column) + override def toScala(catalystValue: Float): Float = catalystValue + override protected def toCatalystImpl(scalaValue: Float): Float = scalaValue } - private object DoubleConverter extends PrimitiveCatalystTypeConverter[Double] { + private object DoubleConverter extends CatalystTypeConverter[Double, Double, Double] { override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column) + override def toScala(catalystValue: Double): Double = catalystValue + override protected def toCatalystImpl(scalaValue: Double): Double = scalaValue } /** From 9c0e4e18bf3a05ac925429d46339a871a24659ac Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 24 May 2015 18:36:04 -0700 Subject: [PATCH 04/13] Remove last use of convertToScala(). --- .../spark/sql/catalyst/CatalystTypeConverters.scala | 9 --------- .../spark/sql/catalyst/expressions/generators.scala | 10 +++++++--- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 586d8167b9272..b838f92255554 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -353,15 +353,6 @@ object CatalystTypeConverters { case other => other } - /** - * Converts Catalyst types used internally in rows to standard Scala types - * This method is slow, and for batch conversion you should be using converter - * produced by createToScalaConverter. - */ - def convertToScala(catalystValue: Any, dataType: DataType): Any = { - getConverterForType(dataType).toScala(catalystValue) - } - /** * Creates a converter function that will convert Catalyst types to Scala type. * Typical use case would be converting a collection of rows that have the same schema. 
You will diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 634138010fd21..c8f9aacc0baad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -71,12 +71,16 @@ case class UserDefinedGenerator( children: Seq[Expression]) extends Generator { + private[this] val inputRow: InterpretedProjection = new InterpretedProjection(children) + private[this] val convertToScala: (Row) => Row = { + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + CatalystTypeConverters.createToScalaConverter(inputSchema) + }.asInstanceOf[(Row => Row)] + override def eval(input: Row): TraversableOnce[Row] = { // TODO(davies): improve this // Convert the objects into Scala Type before calling function, we need schema to support UDT - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) - val inputRow = new InterpretedProjection(children) - function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) + function(convertToScala(inputRow(input))) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" From 85bba9dcbb95260f25ccfa5bdcd14510f02713ff Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 24 May 2015 18:40:45 -0700 Subject: [PATCH 05/13] Fix wrong input data in InMemoryColumnarQuerySuite The schema declares an array of booleans, but we passed an array of integers instead. --- .../apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 56591d9dba29e..055453e688e73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -173,7 +173,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { new Timestamp(i), (1 to i).toSeq, (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, - Row((i - 0.25).toFloat, (1 to i).toSeq)) + Row((i - 0.25).toFloat, Seq(true, false, null))) } createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. From 8033d4c6a7c7e983e4e5d258b11750256eb3c0d9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 24 May 2015 18:58:53 -0700 Subject: [PATCH 06/13] Fix serialization error in UserDefinedGenerator. 
--- .../sql/catalyst/expressions/generators.scala | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index c8f9aacc0baad..c8ffcc3135c7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.catalyst.expressions +import java.io.{ObjectInputStream, IOException} + import scala.collection.Map import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * An expression that produces zero or more rows given a single input row. @@ -71,11 +74,24 @@ case class UserDefinedGenerator( children: Seq[Expression]) extends Generator { - private[this] val inputRow: InterpretedProjection = new InterpretedProjection(children) - private[this] val convertToScala: (Row) => Row = { - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) - CatalystTypeConverters.createToScalaConverter(inputSchema) - }.asInstanceOf[(Row => Row)] + @transient private[this] var inputRow: InterpretedProjection = _ + @transient private[this] var convertToScala: (Row) => Row = _ + + private def initializeConverters(): Unit = { + inputRow = new InterpretedProjection(children) + convertToScala = { + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + CatalystTypeConverters.createToScalaConverter(inputSchema) + }.asInstanceOf[(Row => Row)] + } + + initializeConverters() + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + initializeConverters() + } override def eval(input: Row): TraversableOnce[Row] = { // TODO(davies): improve this From 677ff27c77b20fdd74132ed771e0589797eb9a06 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 24 May 2015 18:59:07 -0700 Subject: [PATCH 07/13] Fix null handling bug; add tests. 
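
With the primitive converters typed as CatalystTypeConverter[T, T, T], the base class
had to produce null.asInstanceOf[T] for null inputs and null columns, which for a
primitive T such as Int silently unboxes to the default value (0, false, ...) instead
of staying null. Widening the Scala-facing type parameters of the primitive converters
to Any lets nulls survive the round trip. The new suite pins this down; for
illustration, it expects

    CatalystTypeConverters.createToScalaConverter(IntegerType)(null)  // null, not 0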
--- .../sql/catalyst/CatalystTypeConverters.scala | 42 +++++++-------- .../CatalystTypeConvertersSuite.scala | 52 +++++++++++++++++++ 2 files changed, 73 insertions(+), 21 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index b838f92255554..0880bd499b2ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -272,46 +272,46 @@ object CatalystTypeConverters { } } - private object BooleanConverter extends CatalystTypeConverter[Boolean, Boolean, Boolean] { + private object BooleanConverter extends CatalystTypeConverter[Boolean, Any, Any] { override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column) - override def toScala(catalystValue: Boolean): Boolean = catalystValue - override protected def toCatalystImpl(scalaValue: Boolean): Boolean = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toCatalystImpl(scalaValue: Boolean): Boolean = scalaValue } - private object ByteConverter extends CatalystTypeConverter[Byte, Byte, Byte] { + private object ByteConverter extends CatalystTypeConverter[Byte, Any, Any] { override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column) - override def toScala(catalystValue: Byte): Byte = catalystValue - override protected def toCatalystImpl(scalaValue: Byte): Byte = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toCatalystImpl(scalaValue: Byte): Byte = scalaValue } - private object ShortConverter extends CatalystTypeConverter[Short, Short, Short] { + private object ShortConverter extends CatalystTypeConverter[Short, Any, Any] { override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column) - override def toScala(catalystValue: Short): Short = catalystValue - override protected def toCatalystImpl(scalaValue: Short): Short = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toCatalystImpl(scalaValue: Short): Short = scalaValue } - private object IntConverter extends CatalystTypeConverter[Int, Int, Int] { + private object IntConverter extends CatalystTypeConverter[Int, Any, Any] { override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column) - override def toScala(catalystValue: Int): Int = catalystValue - override protected def toCatalystImpl(scalaValue: Int): Int = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toCatalystImpl(scalaValue: Int): Int = scalaValue } - private object LongConverter extends CatalystTypeConverter[Long, Long, Long] { + private object LongConverter extends CatalystTypeConverter[Long, Any, Any] { override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column) - override def toScala(catalystValue: Long): Long = catalystValue - override protected def toCatalystImpl(scalaValue: Long): Long = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toCatalystImpl(scalaValue: Long): Long = scalaValue } - private object FloatConverter extends CatalystTypeConverter[Float, Float, Float] { + private object FloatConverter extends CatalystTypeConverter[Float, Any, Any] { override def 
toScalaImpl(row: Row, column: Int): Float = row.getFloat(column) - override def toScala(catalystValue: Float): Float = catalystValue - override protected def toCatalystImpl(scalaValue: Float): Float = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toCatalystImpl(scalaValue: Float): Float = scalaValue } - private object DoubleConverter extends CatalystTypeConverter[Double, Double, Double] { + private object DoubleConverter extends CatalystTypeConverter[Double, Any, Any] { override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column) - override def toScala(catalystValue: Double): Double = catalystValue - override protected def toCatalystImpl(scalaValue: Double): Double = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toCatalystImpl(scalaValue: Double): Double = scalaValue } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala new file mode 100644 index 0000000000000..4c36357a7d21b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst + +import org.scalatest.FunSuite + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class CatalystTypeConvertersSuite extends FunSuite { + + private val simpleTypes: Seq[DataType] = Seq( + StringType, + DateType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + + test("null handling in rows") { + val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) + val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) + val convertToScala = CatalystTypeConverters.createToScalaConverter(schema) + + val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null)) + assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow) + } + + test("null handling for individual values") { + for (dataType <- simpleTypes) { + assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null) + } + } +} From 6ad0ebbd0e8381cf2fab7249c24746df0dba73a4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 25 May 2015 16:23:13 -0700 Subject: [PATCH 08/13] Fix JavaHashingTFSuite ClassCastException --- .../org/apache/spark/ml/feature/JavaHashingTFSuite.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index da2218056307e..599e9cfd23ad4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -55,9 +55,9 @@ public void tearDown() { @Test public void hashingTF() { JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(0.0, "Hi I heard about Spark"), + RowFactory.create(0.0, "I wish Java could use case classes"), + RowFactory.create(1.0, "Logistic regression models are neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), From 3f7b2d8c60d2cdfd78ab1669108c6d2b80731abf Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 26 May 2015 16:03:52 -0700 Subject: [PATCH 09/13] Initialize converters lazily so that the attributes are resolved first --- .../sql/catalyst/expressions/generators.scala | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index c8ffcc3135c7f..b6191eafba71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,13 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ObjectInputStream, IOException} - import scala.collection.Map import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils /** * An expression that produces zero or more rows given a single input row. 
@@ -85,16 +82,10 @@ case class UserDefinedGenerator( }.asInstanceOf[(Row => Row)] } - initializeConverters() - - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { - ois.defaultReadObject() - initializeConverters() - } - override def eval(input: Row): TraversableOnce[Row] = { - // TODO(davies): improve this + if (inputRow == null) { + initializeConverters() + } // Convert the objects into Scala Type before calling function, we need schema to support UDT function(convertToScala(inputRow(input))) } From 6edf7f84b76a40c607ac37cb0286b5309f23f5db Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 27 May 2015 12:04:01 -0700 Subject: [PATCH 10/13] Re-add convertToScala(), since a Hive test still needs it --- .../spark/sql/catalyst/CatalystTypeConverters.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 0880bd499b2ac..7823181de3115 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -353,6 +353,15 @@ object CatalystTypeConverters { case other => other } + /** + * Converts Catalyst types used internally in rows to standard Scala types + * This method is slow, and for batch conversion you should be using converter + * produced by createToScalaConverter. + */ + def convertToScala(catalystValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toScala(catalystValue) + } + /** * Creates a converter function that will convert Catalyst types to Scala type. * Typical use case would be converting a collection of rows that have the same schema. You will From 598959343f90ab1c5056fca7a93c44ddd1a36dbe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 29 May 2015 15:47:20 -0700 Subject: [PATCH 11/13] Use new SparkFunSuite base in CatalystTypeConvertersSuite --- .../spark/sql/catalyst/CatalystTypeConvertersSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index 4c36357a7d21b..8c380d37cb7f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.types._ -class CatalystTypeConvertersSuite extends FunSuite { +class CatalystTypeConvertersSuite extends SparkFunSuite { private val simpleTypes: Seq[DataType] = Seq( StringType, From befc613b3214b8b302f1ad0c61b7dc9f2dd0f1f2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 31 May 2015 17:08:48 -0700 Subject: [PATCH 12/13] Add tests to document Option-handling behavior. 
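
Converters produced by createToCatalystConverter unwrap Options, while the generic
one-argument convertToCatalyst passes Options through untouched. The tests record this
asymmetry rather than change it; for example, with the current converters

    CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123))  // 123
    CatalystTypeConverters.createToCatalystConverter(IntegerType)(None)       // null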
--- .../sql/catalyst/CatalystTypeConvertersSuite.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index 8c380d37cb7f0..df0f04563edcf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -48,4 +48,15 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null) } } + + test("option handling in convertToCatalyst") { + // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with + // createToCatalystConverter but it may not actually matter as this is only called internally + // in a handful of places where we don't expect to receive Options. + assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123)) + } + + test("option handling in createToCatalystConverter") { + assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) + } } From 740341b1780af43d12a713d70d11002ae87b4736 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 31 May 2015 17:09:17 -0700 Subject: [PATCH 13/13] Optimize method dispatch for primitive type conversions --- .../sql/catalyst/CatalystTypeConverters.scala | 86 ++++++++++++------- 1 file changed, 55 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 7823181de3115..2e7b4c236d8f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -37,6 +37,19 @@ object CatalystTypeConverters { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map + private def isPrimitive(dataType: DataType): Boolean = { + dataType match { + case BooleanType => true + case ByteType => true + case ShortType => true + case IntegerType => true + case LongType => true + case FloatType => true + case DoubleType => true + case _ => false + } + } + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { val converter = dataType match { case udt: UserDefinedType[_] => UDTConverter(udt) @@ -73,15 +86,17 @@ object CatalystTypeConverters { * and Options. 
*/ final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = { - maybeScalaValue match { - case opt: Option[ScalaInputType] => - if (opt.isDefined) { - toCatalystImpl(opt.get) - } else { - null.asInstanceOf[CatalystType] - } - case null => null.asInstanceOf[CatalystType] - case scalaValue: ScalaInputType => toCatalystImpl(scalaValue) + if (maybeScalaValue == null) { + null.asInstanceOf[CatalystType] + } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) { + val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]] + if (opt.isDefined) { + toCatalystImpl(opt.get) + } else { + null.asInstanceOf[CatalystType] + } + } else { + toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType]) } } @@ -272,46 +287,37 @@ object CatalystTypeConverters { } } - private object BooleanConverter extends CatalystTypeConverter[Boolean, Any, Any] { + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { + final override def toScala(catalystValue: Any): Any = catalystValue + final override def toCatalystImpl(scalaValue: T): Any = scalaValue + } + + private object BooleanConverter extends PrimitiveConverter[Boolean] { override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column) - override def toScala(catalystValue: Any): Any = catalystValue - override def toCatalystImpl(scalaValue: Boolean): Boolean = scalaValue } - private object ByteConverter extends CatalystTypeConverter[Byte, Any, Any] { + private object ByteConverter extends PrimitiveConverter[Byte] { override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column) - override def toScala(catalystValue: Any): Any = catalystValue - override def toCatalystImpl(scalaValue: Byte): Byte = scalaValue } - private object ShortConverter extends CatalystTypeConverter[Short, Any, Any] { + private object ShortConverter extends PrimitiveConverter[Short] { override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column) - override def toScala(catalystValue: Any): Any = catalystValue - override def toCatalystImpl(scalaValue: Short): Short = scalaValue } - private object IntConverter extends CatalystTypeConverter[Int, Any, Any] { + private object IntConverter extends PrimitiveConverter[Int] { override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column) - override def toScala(catalystValue: Any): Any = catalystValue - override def toCatalystImpl(scalaValue: Int): Int = scalaValue } - private object LongConverter extends CatalystTypeConverter[Long, Any, Any] { + private object LongConverter extends PrimitiveConverter[Long] { override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column) - override def toScala(catalystValue: Any): Any = catalystValue - override def toCatalystImpl(scalaValue: Long): Long = scalaValue } - private object FloatConverter extends CatalystTypeConverter[Float, Any, Any] { + private object FloatConverter extends PrimitiveConverter[Float] { override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column) - override def toScala(catalystValue: Any): Any = catalystValue - override def toCatalystImpl(scalaValue: Float): Float = scalaValue } - private object DoubleConverter extends CatalystTypeConverter[Double, Any, Any] { + private object DoubleConverter extends PrimitiveConverter[Double] { override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column) - override def toScala(catalystValue: Any): Any = catalystValue - override def toCatalystImpl(scalaValue: Double): Double = scalaValue } /** @@ 
-330,7 +336,25 @@ object CatalystTypeConverters { * call this function once to get a converter, and apply it to every row. */ private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { - getConverterForType(dataType).toCatalyst + if (isPrimitive(dataType)) { + // Although the `else` branch here is capable of handling inbound conversion of primitives, + // we add some special-case handling for those types here. The motivation for this relates to + // Java method invocation costs: if we have rows that consist entirely of primitive columns, + // then returning the same conversion function for all of the columns means that the call site + // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in + // a measurable performance impact. Note that this optimization will be unnecessary if we + // use code generation to construct Scala Row -> Catalyst Row converters. + def convert(maybeScalaValue: Any): Any = { + if (maybeScalaValue.isInstanceOf[Option[Any]]) { + maybeScalaValue.asInstanceOf[Option[Any]].orNull + } else { + maybeScalaValue + } + } + convert + } else { + getConverterForType(dataType).toCatalyst + } } /**