Moved udt case to top of all matches. Small cleanups
jkbradley committed Nov 2, 2014
1 parent b028675 commit 7f29656
Showing 4 changed files with 19 additions and 14 deletions.
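The reordering matters because Scala checks match cases top to bottom, and several of the existing cases match the Catalyst dataType with a wildcard, so a value whose schema is a UDT could be claimed by one of those cases before the UDT case was ever tried. A minimal sketch of the failure mode, using simplified stand-ins rather than the real Catalyst classes (DataType, Decimal, UserDefinedType, and StringBackedDecimalUDT here are all hypothetical):

object MatchOrderDemo extends App {
  sealed trait DataType
  case class Decimal(v: BigDecimal)

  // Simplified stand-in for Catalyst's UserDefinedType.
  abstract class UserDefinedType[T] extends DataType {
    def serialize(obj: Any): Any
  }

  // A hypothetical UDT that overrides the default BigDecimal handling.
  object StringBackedDecimalUDT extends UserDefinedType[BigDecimal] {
    def serialize(obj: Any): Any = obj.toString
  }

  def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
    // UDT first: the BigDecimal case below matches dataType with `_`, so with
    // the old ordering the pair (BigDecimal, StringBackedDecimalUDT) hit that
    // case and serialize() never ran.
    case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
    case (d: BigDecimal, _) => Decimal(d)
    case (other, _) => other
  }

  println(convertToCatalyst(BigDecimal("1.5"), StringBackedDecimalUDT)) // "1.5"
}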
@@ -43,6 +43,8 @@ object ScalaReflection {
    * This ordering is important for UDT registration.
    */
   def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
+    // Check UDT first since UDTs can override other types
+    case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
     case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
     case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
     case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
@@ -54,18 +56,18 @@ object ScalaReflection {
       convertToCatalyst(elem, field.dataType)
     }.toArray)
     case (d: BigDecimal, _) => Decimal(d)
-    case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
     case (other, _) => other
   }
 
   /** Converts Catalyst types used internally in rows to standard Scala types */
   def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
+    // Check UDT first since UDTs can override other types
+    case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
     case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
     case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
       convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
     }
     case (d: Decimal, _: DecimalType) => d.toBigDecimal
-    case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
     case (other, _) => other
   }
@@ -77,6 +77,10 @@ private[sql] object CatalystConverter {
       parent: CatalystConverter): Converter = {
     val fieldType: DataType = field.dataType
     fieldType match {
+      // Check UDT first since UDTs can override other types
+      case udt: UserDefinedType[_] => {
+        createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent)
+      }
       // For native JVM types we use a converter with native arrays
       case ArrayType(elementType: NativeType, false) => {
         new CatalystNativeArrayConverter(elementType, fieldIndex, parent)
@@ -99,9 +103,6 @@ private[sql] object CatalystConverter {
           fieldIndex,
           parent)
       }
-      case udt: UserDefinedType[_] => {
-        createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent)
-      }
       // Strings, Shorts and Bytes do not have a corresponding type in Parquet
       // so we need to treat them separately
       case StringType => {
@@ -258,8 +259,8 @@ private[parquet] class CatalystGroupConverter(
       schema,
       index,
       parent,
-      current=null,
-      buffer=new ArrayBuffer[Row](
+      current = null,
+      buffer = new ArrayBuffer[Row](
         CatalystArrayConverter.INITIAL_ARRAY_SIZE))
 
   /**
@@ -304,7 +305,7 @@ private[parquet] class CatalystGroupConverter(
 
   override def end(): Unit = {
     if (!isRootConverter) {
-      assert(current!=null) // there should be no empty groups
+      assert(current != null) // there should be no empty groups
       buffer.append(new GenericRow(current.toArray))
       parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]]))
     }
@@ -361,7 +362,7 @@ private[parquet] class CatalystPrimitiveRowConverter(
 
   override def end(): Unit = {}
 
-  // Overriden here to avoid auto-boxing for primitive types
+  // Overridden here to avoid auto-boxing for primitive types
   override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit =
     current.setBoolean(fieldIndex, value)
 
@@ -536,7 +537,7 @@ private[parquet] class CatalystNativeArrayConverter(
   override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit =
     throw new UnsupportedOperationException
 
-  // Overriden here to avoid auto-boxing for primitive types
+  // Overridden here to avoid auto-boxing for primitive types
   override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = {
     checkGrowBuffer()
     buffer(elements) = value.asInstanceOf[NativeType]
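All of the Parquet-side hunks (here and in the two files below) handle UDTs the same way: rather than adding a UDT-specific code path, they rewrite the field's dataType to the UDT's underlying sqlType and recurse, reusing the existing converter, write, and schema-conversion logic unchanged. A rough sketch of that unwrap-and-recurse pattern, again with hypothetical stand-ins (Field, IntConverter, and HourOfDayUDT are not the real classes):

object UnwrapDemo extends App {
  sealed trait DataType
  case object IntType extends DataType
  abstract class UserDefinedType[T] extends DataType { def sqlType: DataType }

  case class Field(name: String, dataType: DataType)
  trait Converter
  case class IntConverter(field: Field) extends Converter

  def createConverter(field: Field): Converter = field.dataType match {
    // No UDT-specific converter exists: swap in the UDT's sqlType and recurse,
    // landing on whichever converter handles the underlying type.
    case udt: UserDefinedType[_] => createConverter(field.copy(dataType = udt.sqlType))
    case IntType => IntConverter(field)
    case other => sys.error(s"unsupported $other")
  }

  // A hypothetical UDT stored as a plain integer.
  object HourOfDayUDT extends UserDefinedType[Int] { val sqlType = IntType }

  println(createConverter(Field("hour", HourOfDayUDT)))
  // IntConverter(Field(hour,IntType)) -- the UDT never reaches Parquet itself
}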
@@ -174,6 +174,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
   private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
     if (value != null) {
       schema match {
+        // Check UDT first since UDTs can override other types
+        case t: UserDefinedType[_] => writeValue(t.sqlType, value)
         case t @ ArrayType(_, _) => writeArray(
           t,
           value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
@@ -183,7 +185,6 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
         case t @ StructType(_) => writeStruct(
           t,
           value.asInstanceOf[CatalystConverter.StructScalaType[_]])
-        case t: UserDefinedType[_] => writeValue(t.sqlType, value)
         case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value)
       }
     }
@@ -290,6 +290,10 @@ private[parquet] object ParquetTypesConverter extends Logging {
       builder.named(name)
     }.getOrElse {
       ctype match {
+        // Check UDT first since UDTs can override other types
+        case udt: UserDefinedType[_] => {
+          fromDataType(udt.sqlType, name, nullable, inArray)
+        }
         case ArrayType(elementType, false) => {
           val parquetElementType = fromDataType(
             elementType,
@@ -337,9 +341,6 @@ private[parquet] object ParquetTypesConverter extends Logging {
             parquetKeyType,
             parquetValueType)
         }
-        case udt: UserDefinedType[_] => {
-          fromDataType(udt.sqlType, name, nullable, inArray)
-        }
         case _ => sys.error(s"Unsupported datatype $ctype")
      }
    }
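Taken together, the four files mean a UDT only ever has to describe itself in terms of an existing SQL type: serialize/deserialize convert the user object at the ScalaReflection layer, and the Parquet layer only ever sees sqlType. A sketch of what a UDT might look like from the user's side, exercising just the three members this diff touches (sqlType, serialize, deserialize); Point and PointUDT are invented for illustration, and the real UserDefinedType API has more members than shown here:

object RoundTripDemo extends App {
  // Stand-ins for the Catalyst types referenced in the diff.
  sealed trait DataType
  case object DoubleType extends DataType
  case class ArrayType(elementType: DataType) extends DataType
  abstract class UserDefinedType[T] extends DataType {
    def sqlType: DataType
    def serialize(obj: Any): Any
    def deserialize(datum: Any): T
  }

  // A hypothetical 2-D point stored as an array of doubles.
  case class Point(x: Double, y: Double)
  object PointUDT extends UserDefinedType[Point] {
    val sqlType = ArrayType(DoubleType)
    def serialize(obj: Any): Any = obj match {
      case Point(x, y) => Seq(x, y)
    }
    def deserialize(datum: Any): Point = datum match {
      case Seq(x: Double, y: Double) => Point(x, y)
    }
  }

  // The property the reordered matches rely on: with the UDT case checked
  // first, serialize/deserialize round-trip the user object exactly.
  val p = Point(1.0, 2.0)
  assert(PointUDT.deserialize(PointUDT.serialize(p)) == p) // passes
}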
