From 7bc9a8c6249300ded31ea931c463d0a8f798e193 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 25 Aug 2015 01:06:36 -0700 Subject: [PATCH] [SPARK-10195] [SQL] Data sources Filter should not expose internal types Spark SQL's data sources API exposes Catalyst's internal types through its Filter interfaces. This is a problem because types like UTF8String are not stable developer APIs and should not be exposed to third-parties. This issue caused incompatibilities when upgrading our `spark-redshift` library to work against Spark 1.5.0. To avoid these issues in the future we should only expose public types through these Filter objects. This patch accomplishes this by using CatalystTypeConverters to add the appropriate conversions. Author: Josh Rosen Closes #8403 from JoshRosen/datasources-internal-vs-external-types. --- .../datasources/DataSourceStrategy.scala | 67 ++++++++++--------- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../datasources/parquet/ParquetFilters.scala | 19 +++--- .../spark/sql/sources/FilteredScanSuite.scala | 7 ++ 4 files changed, 54 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2a4c40db8bb66..6c1ef6a6df887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{InternalRow, expressions} +import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical @@ -344,45 +345,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { */ protected[sql] def selectFilters(filters: Seq[Expression]) = { def translate(predicate: Expression): Option[Filter] = predicate match { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => - Some(sources.EqualTo(a.name, v)) - case expressions.EqualTo(Literal(v, _), a: Attribute) => - Some(sources.EqualTo(a.name, v)) - - case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => - Some(sources.EqualNullSafe(a.name, v)) - case expressions.EqualNullSafe(Literal(v, _), a: Attribute) => - Some(sources.EqualNullSafe(a.name, v)) - - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => - Some(sources.GreaterThan(a.name, v)) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => - Some(sources.LessThan(a.name, v)) - - case expressions.LessThan(a: Attribute, Literal(v, _)) => - Some(sources.LessThan(a.name, v)) - case expressions.LessThan(Literal(v, _), a: Attribute) => - Some(sources.GreaterThan(a.name, v)) - - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => - Some(sources.GreaterThanOrEqual(a.name, v)) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => - Some(sources.LessThanOrEqual(a.name, v)) - - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => - Some(sources.LessThanOrEqual(a.name, v)) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => - Some(sources.GreaterThanOrEqual(a.name, v)) + case expressions.EqualTo(a: Attribute, Literal(v, t)) => + Some(sources.EqualTo(a.name, convertToScala(v, t))) + case expressions.EqualTo(Literal(v, t), a: Attribute) => + Some(sources.EqualTo(a.name, convertToScala(v, t))) + + case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) => + Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) + case expressions.EqualNullSafe(Literal(v, t), a: Attribute) => + Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) + + case expressions.GreaterThan(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThan(a.name, convertToScala(v, t))) + case expressions.GreaterThan(Literal(v, t), a: Attribute) => + Some(sources.LessThan(a.name, convertToScala(v, t))) + + case expressions.LessThan(a: Attribute, Literal(v, t)) => + Some(sources.LessThan(a.name, convertToScala(v, t))) + case expressions.LessThan(Literal(v, t), a: Attribute) => + Some(sources.GreaterThan(a.name, convertToScala(v, t))) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) case expressions.InSet(a: Attribute, set) => - Some(sources.In(a.name, set.toArray)) + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(a.name, set.toArray.map(toScala))) // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => val hSet = list.map(e => e.eval(EmptyRow)) - Some(sources.In(a.name, hSet.toArray)) + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(a.name, hSet.toArray.map(toScala))) case expressions.IsNull(a: Attribute) => Some(sources.IsNull(a.name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index e537d631f4559..730d88b024cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -262,7 +262,7 @@ private[sql] class JDBCRDD( * Converts value to SQL expression. */ private def compileValue(value: Any): Any = value match { - case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" + case stringValue: String => s"'${escapeSql(stringValue)}'" case _ => value } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index c74c8388632f5..c6b3fe7900da8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -32,7 +32,6 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" @@ -65,7 +64,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), @@ -86,7 +85,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -104,7 +103,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.lt(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -121,7 +121,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.ltEq(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -138,7 +139,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.gt(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -155,7 +157,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.gtEq(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -177,7 +180,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), - SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes)))) + SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))))) case BinaryType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index c81c3d3982805..68ce37c00077e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources import scala.language.existentials import org.apache.spark.rdd.RDD +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -78,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL case StringStartsWith("c", v) => _.startsWith(v) case StringEndsWith("c", v) => _.endsWith(v) case StringContains("c", v) => _.contains(v) + case EqualTo("c", v: String) => _.equals(v) + case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters") + case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s) case _ => (c: String) => true } @@ -237,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution