Skip to content

Commit

Permalink
Fix null handling bug; add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed May 29, 2015
1 parent 8033d4c commit 677ff27
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,46 +272,46 @@ object CatalystTypeConverters {
}
}

private object BooleanConverter extends CatalystTypeConverter[Boolean, Boolean, Boolean] {
private object BooleanConverter extends CatalystTypeConverter[Boolean, Any, Any] {
override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column)
override def toScala(catalystValue: Boolean): Boolean = catalystValue
override protected def toCatalystImpl(scalaValue: Boolean): Boolean = scalaValue
override def toScala(catalystValue: Any): Any = catalystValue
override def toCatalystImpl(scalaValue: Boolean): Boolean = scalaValue
}

private object ByteConverter extends CatalystTypeConverter[Byte, Byte, Byte] {
private object ByteConverter extends CatalystTypeConverter[Byte, Any, Any] {
override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column)
override def toScala(catalystValue: Byte): Byte = catalystValue
override protected def toCatalystImpl(scalaValue: Byte): Byte = scalaValue
override def toScala(catalystValue: Any): Any = catalystValue
override def toCatalystImpl(scalaValue: Byte): Byte = scalaValue
}

private object ShortConverter extends CatalystTypeConverter[Short, Short, Short] {
private object ShortConverter extends CatalystTypeConverter[Short, Any, Any] {
override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column)
override def toScala(catalystValue: Short): Short = catalystValue
override protected def toCatalystImpl(scalaValue: Short): Short = scalaValue
override def toScala(catalystValue: Any): Any = catalystValue
override def toCatalystImpl(scalaValue: Short): Short = scalaValue
}

private object IntConverter extends CatalystTypeConverter[Int, Int, Int] {
private object IntConverter extends CatalystTypeConverter[Int, Any, Any] {
override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column)
override def toScala(catalystValue: Int): Int = catalystValue
override protected def toCatalystImpl(scalaValue: Int): Int = scalaValue
override def toScala(catalystValue: Any): Any = catalystValue
override def toCatalystImpl(scalaValue: Int): Int = scalaValue
}

private object LongConverter extends CatalystTypeConverter[Long, Long, Long] {
private object LongConverter extends CatalystTypeConverter[Long, Any, Any] {
override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column)
override def toScala(catalystValue: Long): Long = catalystValue
override protected def toCatalystImpl(scalaValue: Long): Long = scalaValue
override def toScala(catalystValue: Any): Any = catalystValue
override def toCatalystImpl(scalaValue: Long): Long = scalaValue
}

private object FloatConverter extends CatalystTypeConverter[Float, Float, Float] {
private object FloatConverter extends CatalystTypeConverter[Float, Any, Any] {
override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column)
override def toScala(catalystValue: Float): Float = catalystValue
override protected def toCatalystImpl(scalaValue: Float): Float = scalaValue
override def toScala(catalystValue: Any): Any = catalystValue
override def toCatalystImpl(scalaValue: Float): Float = scalaValue
}

private object DoubleConverter extends CatalystTypeConverter[Double, Double, Double] {
private object DoubleConverter extends CatalystTypeConverter[Double, Any, Any] {
override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column)
override def toScala(catalystValue: Double): Double = catalystValue
override protected def toCatalystImpl(scalaValue: Double): Double = scalaValue
override def toScala(catalystValue: Any): Any = catalystValue
override def toCatalystImpl(scalaValue: Double): Double = scalaValue
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst

import org.scalatest.FunSuite

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class CatalystTypeConvertersSuite extends FunSuite {

private val simpleTypes: Seq[DataType] = Seq(
StringType,
DateType,
BooleanType,
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType)

test("null handling in rows") {
val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t)))
val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
val convertToScala = CatalystTypeConverters.createToScalaConverter(schema)

val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow)
}

test("null handling for individual values") {
for (dataType <- simpleTypes) {
assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null)
}
}
}

0 comments on commit 677ff27

Please sign in to comment.