This repository has been archived by the owner on Dec 20, 2018. It is now read-only.

DecimalType is mapped to avro bytes with logicalType #276

Open · wants to merge 1 commit into base: branch-3.2
2 changes: 1 addition & 1 deletion README.md
@@ -101,7 +101,7 @@ This library supports writing of all Spark SQL types into Avro. For most types,
| ---------------|-----------|
| ByteType | int |
| ShortType | int |
- | DecimalType | string |
+ | DecimalType | bytes |
| BinaryType | bytes |
| TimestampType | long |
| StructType | record |
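With this change a decimal column survives a write/read round trip as a real decimal rather than a string. A minimal usage sketch, assuming an existing SparkSession named `spark`; the path and data are illustrative:

import com.databricks.spark.avro._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Hypothetical single-column DataFrame with a decimal price.
val schema = StructType(Seq(StructField("price", DecimalType(10, 2))))
val rows = spark.sparkContext.parallelize(
  Seq(Row(new java.math.BigDecimal("19.99"))))
val df = spark.createDataFrame(rows, schema)

// The "price" field is now written as Avro bytes annotated with the
// decimal logicalType, instead of a plain string.
df.write.avro("/tmp/prices")

// Reading back yields java.math.BigDecimal values, not strings.
spark.read.avro("/tmp/prices").select("price").first.get(0)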
src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
@@ -96,7 +96,10 @@ private[avro] class AvroOutputWriter(
}
case ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | StringType | BooleanType => identity
-      case _: DecimalType => (item: Any) => if (item == null) null else item.toString
+      case decimalType: DecimalType => (item: Any) => if (item == null) null else {
+        val decimal = item.asInstanceOf[java.math.BigDecimal]
+        ByteBuffer.wrap(decimal.unscaledValue().toByteArray)
+      }
case TimestampType => (item: Any) =>
if (item == null) null else item.asInstanceOf[Timestamp].getTime
case DateType => (item: Any) =>
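For context on what those bytes contain: only the two's-complement unscaled value is written; the scale travels in the schema, not in the payload. A small sketch with an illustrative value:

// "3.14" at scale 2 has unscaled value 314, which encodes as 0x01 0x3A.
val decimal = new java.math.BigDecimal("3.14")
val bytes = decimal.unscaledValue().toByteArray
bytes.map(b => f"0x$b%02X").mkString(" ") // "0x01 0x3A"

// Reconstructing the value requires the scale from the Avro schema.
new java.math.BigDecimal(new java.math.BigInteger(bytes), 2) // 3.14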
55 changes: 52 additions & 3 deletions src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
@@ -38,6 +38,20 @@ object SchemaConverters {

case class SchemaType(dataType: DataType, nullable: Boolean)

  /**
   * Returns true when the schema is annotated with the decimal logical type
   * and carries valid scale and precision properties.
   */
private def isDecimalField(avroSchema: Schema): Boolean = {
val nullableLogicalTypeNode = avroSchema.getJsonProp("logicalType")
val logicalTypeOption = Option(nullableLogicalTypeNode).map(_.asText())
val matchLogicalType = logicalTypeOption == Some("decimal")
val hasScale = Option(decimalScaleProp(avroSchema))
.map(_.asInt(Int.MinValue)).exists(_ >= 0)
val hasPrecision = Option(decimalPrecisionProp(avroSchema))
.map(_.asInt(Int.MinValue)).exists(_ > 0)
matchLogicalType && hasScale && hasPrecision
}

/**
* This function takes an avro schema and returns a sql schema.
*/
@@ -46,7 +60,12 @@ object SchemaConverters {
case INT => SchemaType(IntegerType, nullable = false)
case STRING => SchemaType(StringType, nullable = false)
case BOOLEAN => SchemaType(BooleanType, nullable = false)
-      case BYTES => SchemaType(BinaryType, nullable = false)
+      case BYTES => if (isDecimalField(avroSchema)) {
+        SchemaType(DecimalType(
+          decimalPrecisionProp(avroSchema).asInt,
+          decimalScaleProp(avroSchema).asInt
+        ), nullable = false)
+      } else SchemaType(BinaryType, nullable = false)
case DOUBLE => SchemaType(DoubleType, nullable = false)
case FLOAT => SchemaType(FloatType, nullable = false)
case LONG => SchemaType(LongType, nullable = false)
@@ -106,6 +125,14 @@
}
}

private def decimalScaleProp(avroSchema: Schema) = {
avroSchema.getJsonProp("scale")
}

private def decimalPrecisionProp(avroSchema: Schema) = {
avroSchema.getJsonProp("precision")
}
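Together with the BYTES case above, these helpers let an annotated Avro schema surface as a Catalyst DecimalType. A sketch of that round trip through the public toSqlType, using a hand-written schema string (this PR stores the props as JSON strings; plain JSON numbers would also pass, since JsonNode.asInt coerces textual values):

import org.apache.avro.Schema
import com.databricks.spark.avro.SchemaConverters

val avroJson =
  """{"type":"bytes","logicalType":"decimal","precision":"9","scale":"2"}"""
val parsed = new Schema.Parser().parse(avroJson)

// With this change the result is DecimalType(9, 2); previously BinaryType.
SchemaConverters.toSqlType(parsed).dataType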

/**
* This function converts sparkSQL StructType into avro schema. This method uses two other
* converter methods in order to do the conversion.
@@ -170,6 +197,17 @@
bytes
}

case (decimalType: DecimalType, BYTES) =>
(item: AnyRef) =>
if (item == null) {
null
} else {
val byteBuffer = item.asInstanceOf[ByteBuffer]
val bytes = new Array[Byte](byteBuffer.remaining)
byteBuffer.get(bytes)
BigDecimal(BigInt(bytes), decimalType.scale)
}
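A detail worth noting on the read side: the bytes are copied out through remaining/get rather than array(), which stays correct when the buffer is a slice of a larger backing array. The same conversion as a standalone sketch (the helper name and fixed scale are illustrative; the real code path takes the scale from decimalType):

import java.nio.ByteBuffer

def bytesToDecimal(buffer: ByteBuffer, scale: Int): BigDecimal = {
  val bytes = new Array[Byte](buffer.remaining)
  buffer.get(bytes) // copies only the readable window and advances position
  BigDecimal(BigInt(bytes), scale)
}

bytesToDecimal(ByteBuffer.wrap(Array[Byte](0x01, 0x3A)), 2) // 3.14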

case (struct: StructType, RECORD) =>
val length = struct.fields.length
val converters = new Array[AnyRef => AnyRef](length)
@@ -323,7 +361,8 @@
case LongType => schemaBuilder.longType()
case FloatType => schemaBuilder.floatType()
case DoubleType => schemaBuilder.doubleType()
-      case _: DecimalType => schemaBuilder.stringType()
+      case decimalType: DecimalType =>
+        createBytesWithDecimalLogicalType(schemaBuilder.bytesBuilder(), decimalType)
case StringType => schemaBuilder.stringType()
case BinaryType => schemaBuilder.bytesType()
case BooleanType => schemaBuilder.booleanType()
@@ -350,6 +389,15 @@
}
}

private def createBytesWithDecimalLogicalType[T](
bytesBuilder: BytesBuilder[T], decimalType: DecimalType) = {
bytesBuilder
.prop("logicalType", "decimal")
.prop("precision", decimalType.precision.toString)
.prop("scale", decimalType.scale.toString)
.endBytes()
}
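For reference, the schema this helper emits for DecimalType(10, 2) looks roughly like the following (a sketch; prop order in toString may vary):

import org.apache.avro.SchemaBuilder

val schema = SchemaBuilder.builder()
  .bytesBuilder()
  .prop("logicalType", "decimal")
  .prop("precision", "10")
  .prop("scale", "2")
  .endBytes()

schema.toString
// {"type":"bytes","logicalType":"decimal","precision":"10","scale":"2"}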

/**
* This function is used to construct fields of the avro record, where schema of the field is
* specified by avro representation of dataType. Since builders for record fields are different
@@ -367,7 +415,8 @@
case LongType => newFieldBuilder.longType()
case FloatType => newFieldBuilder.floatType()
case DoubleType => newFieldBuilder.doubleType()
-      case _: DecimalType => newFieldBuilder.stringType()
+      case decimalType: DecimalType =>
+        createBytesWithDecimalLogicalType(newFieldBuilder.bytesBuilder(), decimalType)
case StringType => newFieldBuilder.stringType()
case BinaryType => newFieldBuilder.bytesType()
case BooleanType => newFieldBuilder.booleanType()
7 changes: 5 additions & 2 deletions src/test/scala/com/databricks/spark/avro/AvroSuite.scala
@@ -478,6 +478,9 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll {
for (i <- arrayOfByte.indices) {
arrayOfByte(i) = i.toByte
}

val decimalBytes = new java.math.BigDecimal("3.14").unscaledValue().toByteArray

val cityRDD = spark.sparkContext.parallelize(Seq(
Row("San Francisco", 12, new Timestamp(666), null, arrayOfByte),
Row("Palo Alto", null, new Timestamp(777), null, arrayOfByte),
@@ -492,9 +495,9 @@
val times = spark.read.avro(avroDir).select("Time").collect()
assert(times.map(_(0)).toSet == Set(666, 777, 42))

-    // DecimalType should be converted to string
+    // DecimalType should be converted to java.math.BigDecimal
    val decimals = spark.read.avro(avroDir).select("Decimal").collect()
-    assert(decimals.map(_(0)).contains("3.14"))
+    assert(decimals.map(_(0)).contains(new java.math.BigDecimal("3.14")))

// There should be a null entry
val length = spark.read.avro(avroDir).select("Length").collect()