Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
kiszk committed Nov 5, 2016
1 parent 8a9ca19 commit ba7494d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ object ExpressionEncoder {
val cls = mirror.runtimeClass(tpe)
val flat = !ScalaReflection.definedByConstructorParams(tpe)

val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
val nullSafeInput = if (flat) {
inputObject
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.dsl.plans._
Expand Down Expand Up @@ -338,6 +338,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
}
}

test("nullable of encoder serializer") {
def checkNullable[T: Encoder](nullable: Boolean*): Unit = {
assert(encoderFor[T].serializer.map(_.nullable) === nullable.toSeq)
}

// test for flat encoders
checkNullable[Int](false)
checkNullable[Option[Int]](true)
checkNullable[java.lang.Integer](true)
checkNullable[String](true)
}

test("null check for map key") {
val encoder = ExpressionEncoder[Map[String, Int]]()
val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(dataset.collect() sameElements Array(resultValue, resultValue))
}

test("SPARK-18284: Serializer should have correct nullable value") {
val df1 = sparkContext.parallelize(Seq(1, 2, 3, 4), 1).toDF()
assert(df1.schema(0).nullable == false)
val df2 = sparkContext.parallelize(Seq(Integer.valueOf(1), Integer.valueOf(2)), 1).toDF()
assert(df2.schema(0).nullable == true)
}

Seq(true, false).foreach { eager =>
def testCheckpointing(testName: String)(f: => Unit): Unit = {
test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
Expand Down

0 comments on commit ba7494d

Please sign in to comment.