diff --git a/build/mvn b/build/mvn
index efa4f9364ea52..1405983982d4c 100755
--- a/build/mvn
+++ b/build/mvn
@@ -154,4 +154,4 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}
echo "Using \`mvn\` from path: $MVN_BIN" 1>&2
# Last, call the `mvn` command as usual
-${MVN_BIN} -DzincPort=${ZINC_PORT} "$@"
+"${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@"
diff --git a/dev/run-tests.py b/dev/run-tests.py
index 5e8c8590b5c34..cd4590864b7d7 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -357,7 +357,7 @@ def build_spark_unidoc_sbt(hadoop_version):
exec_sbt(profiles_and_goals)
-def build_spark_assembly_sbt(hadoop_version):
+def build_spark_assembly_sbt(hadoop_version, checkstyle=False):
# Enable all of the profiles for the build:
build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags
sbt_goals = ["assembly/package"]
@@ -366,6 +366,9 @@ def build_spark_assembly_sbt(hadoop_version):
" ".join(profiles_and_goals))
exec_sbt(profiles_and_goals)
+ if checkstyle:
+ run_java_style_checks()
+
# Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build.
# Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the
# documentation build fails on a specific machine & environment in Jenkins but it was unable
@@ -570,11 +573,13 @@ def main():
or f.endswith("scalastyle-config.xml")
for f in changed_files):
run_scala_style_checks()
+ should_run_java_style_checks = False
if not changed_files or any(f.endswith(".java")
or f.endswith("checkstyle.xml")
or f.endswith("checkstyle-suppressions.xml")
for f in changed_files):
- run_java_style_checks()
+        # Run SBT Checkstyle after the build to avoid side effects on the build.
+ should_run_java_style_checks = True
if not changed_files or any(f.endswith("lint-python")
or f.endswith("tox.ini")
or f.endswith(".py")
@@ -603,7 +608,7 @@ def main():
detect_binary_inop_with_mima(hadoop_version)
# Since we did not build assembly/package before running dev/mima, we need to
# do it here because the tests still rely on it; see SPARK-13294 for details.
- build_spark_assembly_sbt(hadoop_version)
+ build_spark_assembly_sbt(hadoop_version, should_run_java_style_checks)
# run the test suites
run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags)
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index a0e20d39c20da..3efe2adb6e2a4 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
- encoding=None):
+ dropFieldIfAllNull=None, encoding=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -246,6 +246,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
:param samplingRatio: defines fraction of input JSON objects used for schema inferring.
If None is set, it uses the default value, ``1.0``.
+ :param dropFieldIfAllNull: whether to ignore column of all null values or empty
+ array/struct during schema inference. If None is set, it
+ uses the default value, ``false``.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
index d224332d8a6c9..023ec139652c5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
@@ -21,6 +21,9 @@
import java.io.Reader;
import javax.xml.namespace.QName;
+import javax.xml.parsers.DocumentBuilder;
+import javax.xml.parsers.DocumentBuilderFactory;
+import javax.xml.parsers.ParserConfigurationException;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpression;
@@ -37,9 +40,15 @@
* This is based on Hive's UDFXPathUtil implementation.
*/
public class UDFXPathUtil {
+ public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/";
+ public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities";
+ public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities";
+ private DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
+ private DocumentBuilder builder = null;
private XPath xpath = XPathFactory.newInstance().newXPath();
private ReusableStringReader reader = new ReusableStringReader();
private InputSource inputSource = new InputSource(reader);
+
private XPathExpression expression = null;
private String oldPath = null;
@@ -65,14 +74,31 @@ public Object eval(String xml, String path, QName qname) throws XPathExpressionE
return null;
}
+    if (builder == null) {
+ try {
+ initializeDocumentBuilderFactory();
+ builder = dbf.newDocumentBuilder();
+ } catch (ParserConfigurationException e) {
+ throw new RuntimeException(
+ "Error instantiating DocumentBuilder, cannot build xml parser", e);
+ }
+ }
+
reader.set(xml);
try {
- return expression.evaluate(inputSource, qname);
+ return expression.evaluate(builder.parse(inputSource), qname);
} catch (XPathExpressionException e) {
throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e);
+ } catch (Exception e) {
+ throw new RuntimeException("Error loading expression '" + oldPath + "'", e);
}
}
+ private void initializeDocumentBuilderFactory() throws ParserConfigurationException {
+ dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false);
+ dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false);
+ }
+
public Boolean evalBoolean(String xml, String path) throws XPathExpressionException {
return (Boolean) eval(xml, path, XPathConstants.BOOLEAN);
}
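
The hardening above works by routing the XML through a DocumentBuilder whose factory refuses to resolve external entities, so a crafted DOCTYPE can no longer pull in local files (the classic XXE vector). A minimal standalone Scala sketch of the same pattern, with a hypothetical inline document:

  import java.io.StringReader
  import javax.xml.parsers.DocumentBuilderFactory
  import org.xml.sax.InputSource

  val dbf = DocumentBuilderFactory.newInstance()
  // Mirror the two SAX features disabled in the patch: no external general or parameter entities.
  dbf.setFeature("http://xml.org/sax/features/external-general-entities", false)
  dbf.setFeature("http://xml.org/sax/features/external-parameter-entities", false)
  val builder = dbf.newDocumentBuilder()
  // A document declaring <!ENTITY ... SYSTEM "file:///..."> would now parse with the entity
  // left unresolved instead of leaking file contents.
  val doc = builder.parse(new InputSource(new StringReader("<foo>bar</foo>")))
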
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index 0b95a8821b05a..b47ec0b72c638 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -132,7 +132,7 @@ object Encoders {
* - primitive types: boolean, int, double, etc.
* - boxed types: Boolean, Integer, Double, etc.
* - String
- * - java.math.BigDecimal
+ * - java.math.BigDecimal, java.math.BigInteger
* - time related: java.sql.Date, java.sql.Timestamp
* - collection types: only array and java.util.List currently, map support is in progress
* - nested java bean.
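
Since `java.math.BigInteger` is now documented as a supported bean field type, a bean encoder over a class with a BigInteger property should resolve it natively rather than rejecting the bean. A small sketch, assuming a hypothetical `LedgerEntry` bean (the BigInteger field is expected to surface as a zero-scale decimal column):

  import java.math.BigInteger
  import scala.beans.BeanProperty
  import org.apache.spark.sql.Encoders

  class LedgerEntry extends Serializable {
    @BeanProperty var id: BigInteger = _   // expected to map to decimal(38,0)
    @BeanProperty var label: String = _
  }

  val encoder = Encoders.bean(classOf[LedgerEntry])
  encoder.schema.printTreeString()
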
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index efc2882f0a3d3..cbea3c017a265 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -128,7 +128,7 @@ object ExpressionEncoder {
case b: BoundReference if b == originalInputObject => newInputObject
})
- if (enc.flat) {
+ val serializerExpr = if (enc.flat) {
newSerializer.head
} else {
// For non-flat encoder, the input object is not top level anymore after being combined to
@@ -146,6 +146,7 @@ object ExpressionEncoder {
Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
If(nullCheck, Literal.create(null, struct.dataType), struct)
}
+ Alias(serializerExpr, s"_${index + 1}")()
}
val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
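
The `Alias(serializerExpr, s"_${index + 1}")()` added above names each component of a combined encoder, so a hand-built tuple encoder now yields the same `_1`/`_2` field names as the implicit product encoder. A short sketch of the expected effect:

  import org.apache.spark.sql.Encoders

  val enc = Encoders.tuple(Encoders.STRING, Encoders.tuple(Encoders.STRING, Encoders.STRING))
  // With the fix, enc.schema should read: struct<_1: string, _2: struct<_1: string, _2: string>>
  enc.schema.printTreeString()
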
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 2ff12acb2946f..c081772116f84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -73,6 +73,9 @@ private[sql] class JSONOptions(
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
+ // Whether to ignore column of all null values or empty array/struct during schema inference
+ val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)
+
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala
index c4cde7091154b..0fec15bc42c17 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala
@@ -77,6 +77,27 @@ class UDFXPathUtilSuite extends SparkFunSuite {
assert(ret == "foo")
}
+ test("embedFailure") {
+ import org.apache.commons.io.FileUtils
+ import java.io.File
+ val secretValue = String.valueOf(Math.random)
+ val tempFile = File.createTempFile("verifyembed", ".tmp")
+ tempFile.deleteOnExit()
+ val fname = tempFile.getAbsolutePath
+
+ FileUtils.writeStringToFile(tempFile, secretValue)
+
+ val xml =
+      s"""<?xml version="1.0" encoding="utf-8"?>
+        |<!DOCTYPE test [
+        |    <!ENTITY embed SYSTEM "$fname">
+        |]>
+        |<foo>&embed;</foo>
+ """.stripMargin
+ val evaled = new UDFXPathUtil().evalString(xml, "/foo")
+ assert(evaled.isEmpty)
+ }
+
test("number eval") {
var ret =
       util.evalNumber("<a><b>true</b><b>false</b><b>b3</b><c>c1</c><c>-77</c></a>", "a/c[2]")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
index bfa18a0919e45..c6f6d3abb860c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
@@ -40,8 +40,9 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
// Test error message for invalid XML document
val e1 = intercept[RuntimeException] { testExpr("/a>", "a", null.asInstanceOf[T]) }
- assert(e1.getCause.getMessage.contains("Invalid XML document") &&
- e1.getCause.getMessage.contains("/a>"))
+ assert(e1.getCause.getCause.getMessage.contains(
+ "XML document structures must start and end within the same entity."))
+ assert(e1.getMessage.contains("/a>"))
// Test error message for invalid xpath
val e2 = intercept[RuntimeException] { testExpr("", "!#$", null.asInstanceOf[T]) }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
index a79080a249ec8..926396414816c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
@@ -23,10 +23,9 @@
* A mix in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to report statistics to Spark.
*
- * Statistics are reported to the optimizer before a projection or any filters are pushed to the
- * DataSourceReader. Implementations that return more accurate statistics based on projection and
- * filters will not improve query performance until the planner can push operators before getting
- * stats.
+ * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader.
+ * Implementations that return more accurate statistics based on pushed operators will not improve
+ * query performance until the planner can push operators before getting stats.
*/
@InterfaceStability.Evolving
public interface SupportsReportStatistics extends DataSourceReader {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index de6be5f76e15a..ec9352a7fa055 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -381,6 +381,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* that should be used for parsing.
*
`samplingRatio` (default is 1.0): defines fraction of input JSON objects used
* for schema inferring.
+ * `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
+ * empty array/struct during schema inference.
*
*
* @since 2.0.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 0b4dd76c7d860..997cf92449c68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.vectorized._
import org.apache.spark.sql.types._
@@ -169,8 +169,8 @@ case class InMemoryTableScanExec(
// But the cached version could alias output, so we need to replace output.
override def outputPartitioning: Partitioning = {
relation.cachedPlan.outputPartitioning match {
- case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning]
- case _ => relation.cachedPlan.outputPartitioning
+ case e: Expression => updateAttribute(e).asInstanceOf[Partitioning]
+ case other => other
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index e7eed95a560a3..f6edc7bfb3750 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -75,7 +75,7 @@ private[sql] object JsonInferSchema {
// active SparkSession and `SQLConf.get` may point to the wrong configs.
val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger)
- canonicalizeType(rootType) match {
+ canonicalizeType(rootType, configOptions) match {
case Some(st: StructType) => st
case _ =>
// canonicalizeType erases all empty structs, including the only one we want to keep
@@ -181,33 +181,33 @@ private[sql] object JsonInferSchema {
}
/**
- * Convert NullType to StringType and remove StructTypes with no fields
+ * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields,
+ * drops NullTypes or converts them to StringType based on provided options.
*/
- private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match {
- case at @ ArrayType(elementType, _) =>
- for {
- canonicalType <- canonicalizeType(elementType)
- } yield {
- at.copy(canonicalType)
- }
+ private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match {
+ case at: ArrayType =>
+ canonicalizeType(at.elementType, options)
+ .map(t => at.copy(elementType = t))
case StructType(fields) =>
- val canonicalFields: Array[StructField] = for {
- field <- fields
- if field.name.length > 0
- canonicalType <- canonicalizeType(field.dataType)
- } yield {
- field.copy(dataType = canonicalType)
+ val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f =>
+ canonicalizeType(f.dataType, options)
+ .map(t => f.copy(dataType = t))
}
-
- if (canonicalFields.length > 0) {
- Some(StructType(canonicalFields))
+ // SPARK-8093: empty structs should be deleted
+ if (canonicalFields.isEmpty) {
+ None
} else {
- // per SPARK-8093: empty structs should be deleted
+ Some(StructType(canonicalFields))
+ }
+
+ case NullType =>
+ if (options.dropFieldIfAllNull) {
None
+ } else {
+ Some(StringType)
}
- case NullType => Some(StringType)
case other => Some(other)
}
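
With the NullType branch now driven by `dropFieldIfAllNull`, the same input infers two different schemas depending on the option. A quick sketch, assuming an active `SparkSession` named `spark`:

  import spark.implicits._

  val jsonLines = Seq("""{"a":null, "b":1}""", """{"a":null, "b":null}""").toDS()

  spark.read.json(jsonLines).printSchema()
  // default: a is kept and typed as string, b is bigint

  spark.read.option("dropFieldIfAllNull", "true").json(jsonLines).printSchema()
  // option enabled: the all-null column a is dropped, only b remains
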
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index e08af218513fd..7613eb210c659 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -23,17 +23,24 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
+import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema}
-import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics}
+import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics}
import org.apache.spark.sql.types.StructType
+/**
+ * A logical plan representing a data source v2 scan.
+ *
+ * @param source An instance of a [[DataSourceV2]] implementation.
+ * @param options The options for this scan. Used to create fresh [[DataSourceReader]].
+ * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh
+ * [[DataSourceReader]].
+ */
case class DataSourceV2Relation(
source: DataSourceV2,
output: Seq[AttributeReference],
options: Map[String, String],
- userSpecifiedSchema: Option[StructType] = None)
+ userSpecifiedSchema: Option[StructType])
extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat {
import DataSourceV2Relation._
@@ -42,14 +49,7 @@ case class DataSourceV2Relation(
override def simpleString: String = "RelationV2 " + metadataString
- lazy val v2Options: DataSourceOptions = makeV2Options(options)
-
- def newReader: DataSourceReader = userSpecifiedSchema match {
- case Some(userSchema) =>
- source.asReadSupportWithSchema.createReader(userSchema, v2Options)
- case None =>
- source.asReadSupport.createReader(v2Options)
- }
+ def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema)
override def computeStats(): Statistics = newReader match {
case r: SupportsReportStatistics =>
@@ -139,83 +139,26 @@ object DataSourceV2Relation {
source.getClass.getSimpleName
}
}
- }
-
- private def makeV2Options(options: Map[String, String]): DataSourceOptions = {
- new DataSourceOptions(options.asJava)
- }
- private def schema(
- source: DataSourceV2,
- v2Options: DataSourceOptions,
- userSchema: Option[StructType]): StructType = {
- val reader = userSchema match {
- case Some(s) =>
- source.asReadSupportWithSchema.createReader(s, v2Options)
- case _ =>
- source.asReadSupport.createReader(v2Options)
+ def createReader(
+ options: Map[String, String],
+ userSpecifiedSchema: Option[StructType]): DataSourceReader = {
+ val v2Options = new DataSourceOptions(options.asJava)
+ userSpecifiedSchema match {
+ case Some(s) =>
+ asReadSupportWithSchema.createReader(s, v2Options)
+ case _ =>
+ asReadSupport.createReader(v2Options)
+ }
}
- reader.readSchema()
}
def create(
source: DataSourceV2,
options: Map[String, String],
- userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = {
- val output = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes
- DataSourceV2Relation(source, output, options, userSpecifiedSchema)
- }
-
- def pushRequiredColumns(
- relation: DataSourceV2Relation,
- reader: DataSourceReader,
- struct: StructType): Seq[AttributeReference] = {
- reader match {
- case projectionSupport: SupportsPushDownRequiredColumns =>
- projectionSupport.pruneColumns(struct)
- // return the output columns from the relation that were projected
- val attrMap = relation.output.map(a => a.name -> a).toMap
- projectionSupport.readSchema().map(f => attrMap(f.name))
- case _ =>
- relation.output
- }
- }
-
- def pushFilters(
- reader: DataSourceReader,
- filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
- reader match {
- case r: SupportsPushDownCatalystFilters =>
- val postScanFilters = r.pushCatalystFilters(filters.toArray)
- val pushedFilters = r.pushedCatalystFilters()
- (postScanFilters, pushedFilters)
-
- case r: SupportsPushDownFilters =>
- // A map from translated data source filters to original catalyst filter expressions.
- val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression]
- // Catalyst filter expression that can't be translated to data source filters.
- val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression]
-
- for (filterExpr <- filters) {
- val translated = DataSourceStrategy.translateFilter(filterExpr)
- if (translated.isDefined) {
- translatedFilterToExpr(translated.get) = filterExpr
- } else {
- untranslatableExprs += filterExpr
- }
- }
-
- // Data source filters that need to be evaluated again after scanning. which means
- // the data source cannot guarantee the rows returned can pass these filters.
- // As a result we must return it so Spark can plan an extra filter operator.
- val postScanFilters =
- r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr)
- // The filters which are marked as pushed to this data source
- val pushedFilters = r.pushedFilters().map(translatedFilterToExpr)
-
- (untranslatableExprs ++ postScanFilters, pushedFilters)
-
- case _ => (filters, Nil)
- }
+ userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = {
+ val reader = source.createReader(options, userSpecifiedSchema)
+ DataSourceV2Relation(
+ source, reader.readSchema().toAttributes, options, userSpecifiedSchema)
}
}
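
Because `userSpecifiedSchema` no longer has a default value, call sites of `DataSourceV2Relation.create` must state explicitly whether a user schema is involved, e.g. (with hypothetical `source` and `options` values):

  val relation = DataSourceV2Relation.create(source, options, userSpecifiedSchema = None)

Making the argument explicit keeps the read-support vs. read-support-with-schema decision visible at every caller.
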
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 8bf858c38d76c..182aa2906cf1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -17,51 +17,115 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.sql.{execution, Strategy}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet}
+import scala.collection.mutable
+
+import org.apache.spark.sql.{sources, Strategy}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
+import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
object DataSourceV2Strategy extends Strategy {
- override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
- val projectSet = AttributeSet(project.flatMap(_.references))
- val filterSet = AttributeSet(filters.flatMap(_.references))
-
- val projection = if (filterSet.subsetOf(projectSet) &&
- AttributeSet(relation.output) == projectSet) {
- // When the required projection contains all of the filter columns and column pruning alone
- // can produce the required projection, push the required projection.
- // A final projection may still be needed if the data source produces a different column
- // order or if it cannot prune all of the nested columns.
- relation.output
- } else {
- // When there are filter columns not already in the required projection or when the required
- // projection is more complicated than column pruning, base column pruning on the set of
- // all columns needed by both.
- (projectSet ++ filterSet).toSeq
- }
- val reader = relation.newReader
+ /**
+   * Pushes down filters to the data source reader.
+   *
+   * @return pushed filters and post-scan filters.
+ */
+ private def pushFilters(
+ reader: DataSourceReader,
+ filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+ reader match {
+ case r: SupportsPushDownCatalystFilters =>
+ val postScanFilters = r.pushCatalystFilters(filters.toArray)
+ val pushedFilters = r.pushedCatalystFilters()
+ (pushedFilters, postScanFilters)
+
+ case r: SupportsPushDownFilters =>
+ // A map from translated data source filters to original catalyst filter expressions.
+ val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression]
+      // Catalyst filter expressions that can't be translated to data source filters.
+ val untranslatableExprs = mutable.ArrayBuffer.empty[Expression]
+
+ for (filterExpr <- filters) {
+ val translated = DataSourceStrategy.translateFilter(filterExpr)
+ if (translated.isDefined) {
+ translatedFilterToExpr(translated.get) = filterExpr
+ } else {
+ untranslatableExprs += filterExpr
+ }
+ }
+
+      // Data source filters that need to be evaluated again after scanning, which means
+      // the data source cannot guarantee the rows returned can pass these filters.
+      // As a result we must return them so Spark can plan an extra filter operator.
+ val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray)
+ .map(translatedFilterToExpr)
+ // The filters which are marked as pushed to this data source
+ val pushedFilters = r.pushedFilters().map(translatedFilterToExpr)
+ (pushedFilters, untranslatableExprs ++ postScanFilters)
+
+ case _ => (Nil, filters)
+ }
+ }
- val output = DataSourceV2Relation.pushRequiredColumns(relation, reader,
- projection.asInstanceOf[Seq[AttributeReference]].toStructType)
+ /**
+ * Applies column pruning to the data source, w.r.t. the references of the given expressions.
+ *
+ * @return new output attributes after column pruning.
+ */
+ // TODO: nested column pruning.
+ private def pruneColumns(
+ reader: DataSourceReader,
+ relation: DataSourceV2Relation,
+ exprs: Seq[Expression]): Seq[AttributeReference] = {
+ reader match {
+ case r: SupportsPushDownRequiredColumns =>
+ val requiredColumns = AttributeSet(exprs.flatMap(_.references))
+ val neededOutput = relation.output.filter(requiredColumns.contains)
+ if (neededOutput != relation.output) {
+ r.pruneColumns(neededOutput.toStructType)
+ val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
+ r.readSchema().toAttributes.map {
+ // We have to keep the attribute id during transformation.
+ a => a.withExprId(nameToAttr(a.name).exprId)
+ }
+ } else {
+ relation.output
+ }
+
+ case _ => relation.output
+ }
+ }
- val (postScanFilters, pushedFilters) = DataSourceV2Relation.pushFilters(reader, filters)
- logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}")
- logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}")
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
+ val reader = relation.newReader()
+ // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
+ // `postScanFilters` need to be evaluated after the scan.
+ // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
+ val (pushedFilters, postScanFilters) = pushFilters(reader, filters)
+ val output = pruneColumns(reader, relation, project ++ postScanFilters)
+ logInfo(
+ s"""
+ |Pushing operators to ${relation.source.getClass}
+ |Pushed Filters: ${pushedFilters.mkString(", ")}
+ |Post-Scan Filters: ${postScanFilters.mkString(",")}
+ |Output: ${output.mkString(", ")}
+ """.stripMargin)
val scan = DataSourceV2ScanExec(
output, relation.source, relation.options, pushedFilters, reader)
- val filter = postScanFilters.reduceLeftOption(And)
- val withFilter = filter.map(execution.FilterExec(_, scan)).getOrElse(scan)
+ val filterCondition = postScanFilters.reduceLeftOption(And)
+ val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
val withProjection = if (withFilter.output != project) {
- execution.ProjectExec(project, withFilter)
+ ProjectExec(project, withFilter)
} else {
withFilter
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index 09f79a2de0ba0..1a5b7599bb7d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -24,7 +24,7 @@ import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
@@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
}
override def outputPartitioning: Partitioning = child.outputPartitioning match {
- case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr))
+ case e: Expression => updateAttr(e).asInstanceOf[Partitioning]
case other => other
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index ae93965bc50ed..ef8dc3a325a33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -270,6 +270,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* per file
* `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
* that should be used for parsing.
+ * `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
+ * empty array/struct during schema inference.
*
*
* @since 2.0.0
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index c132cab1b38cf..2c695fc58fd8c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -34,6 +34,7 @@
import org.junit.*;
import org.junit.rules.ExpectedException;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.sql.*;
@@ -336,6 +337,23 @@ public void testTupleEncoder() {
Assert.assertEquals(data5, ds5.collectAsList());
}
+ @Test
+ public void testTupleEncoderSchema() {
+    Encoder<Tuple2<String, Tuple2<String, String>>> encoder =
+      Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING()));
+    List<Tuple2<String, Tuple2<String, String>>> data = Arrays.asList(tuple2("1", tuple2("a", "b")),
+      tuple2("2", tuple2("c", "d")));
+    Dataset<Row> ds1 = spark.createDataset(data, encoder).toDF("value1", "value2");
+
+    JavaPairRDD<String, Tuple2<String, String>> pairRDD = jsc.parallelizePairs(data);
+    Dataset<Row> ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder)
+ .toDF("value1", "value2");
+
+ Assert.assertEquals(ds1.schema(), ds2.schema());
+ Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(),
+ ds2.select(expr("value2._1")).collectAsList());
+ }
+
@Test
public void testNestedTupleEncoder() {
// test ((int, string), string)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 81b7e18773f81..6982c22f4771d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -83,25 +83,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
}.sum
}
- test("withColumn doesn't invalidate cached dataframe") {
- var evalCount = 0
- val myUDF = udf((x: String) => { evalCount += 1; "result" })
- val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s"))
- df.cache()
-
- df.collect()
- assert(evalCount === 1)
-
- df.collect()
- assert(evalCount === 1)
-
- val df2 = df.withColumn("newColumn", lit(1))
- df2.collect()
-
- // We should not reevaluate the cached dataframe
- assert(evalCount === 1)
- }
-
test("cache temp table") {
withTempView("tempTable") {
testData.select('key).createOrReplaceTempView("tempTable")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index e0561ee2797a5..82a93f74dd76c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql
+import org.scalatest.concurrent.TimeLimits
+import org.scalatest.time.SpanSugar._
+
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.storage.StorageLevel
-class DatasetCacheSuite extends QueryTest with SharedSQLContext {
+class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits {
import testImplicits._
test("get storage level") {
@@ -96,4 +99,37 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
agged.unpersist()
assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.")
}
+
+ test("persist and then withColumn") {
+ val df = Seq(("test", 1)).toDF("s", "i")
+ val df2 = df.withColumn("newColumn", lit(1))
+
+ df.cache()
+ assertCached(df)
+ assertCached(df2)
+
+ df.count()
+ assertCached(df2)
+
+ df.unpersist()
+ assert(df.storageLevel == StorageLevel.NONE)
+ }
+
+ test("cache UDF result correctly") {
+ val expensiveUDF = udf({x: Int => Thread.sleep(10000); x})
+ val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a"))
+ val df2 = df.agg(sum(df("b")))
+
+ df.cache()
+ df.count()
+ assertCached(df2)
+
+ // udf has been evaluated during caching, and thus should not be re-evaluated here
+ failAfter(5 seconds) {
+ df2.collect()
+ }
+
+ df.unpersist()
+ assert(df.storageLevel == StorageLevel.NONE)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 562a756b50ecd..4f3e3de973eea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1467,13 +1467,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
intercept[NullPointerException](ds.as[(Int, Int)].collect())
}
+ test("SPARK-24548: Dataset with tuple encoders should have correct schema") {
+ val encoder = Encoders.tuple(newStringEncoder,
+ Encoders.tuple(newStringEncoder, newStringEncoder))
+
+ val data = Seq(("a", ("1", "2")), ("b", ("3", "4")))
+ val rdd = sparkContext.parallelize(data)
+
+ val ds1 = spark.createDataset(rdd)
+ val ds2 = spark.createDataset(rdd)(encoder)
+ assert(ds1.schema == ds2.schema)
+ checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*)
+ }
+
test("SPARK-24571: filtering of string values by char literal") {
val df = Seq("Amsterdam", "San Francisco", "X").toDF("city")
checkAnswer(df.where('city === 'X'), Seq(Row("X")))
checkAnswer(
df.where($"city".contains(new java.lang.Character('A'))),
Seq(Row("Amsterdam")))
- }
+ }
}
case class TestDataUnion(x: Int, y: Int, z: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 37d468739c613..d254345e8fa54 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,13 +18,13 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
@@ -703,6 +703,66 @@ class PlannerSuite extends SharedSQLContext {
Range(1, 2, 1, 1)))
df.queryExecution.executedPlan.execute()
}
+
+ test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " +
+ "and InMemoryTableScanExec") {
+ def checkOutputPartitioningRewrite(
+ plans: Seq[SparkPlan],
+ expectedPartitioningClass: Class[_]): Unit = {
+ assert(plans.size == 1)
+ val plan = plans.head
+ val partitioning = plan.outputPartitioning
+ assert(partitioning.getClass == expectedPartitioningClass)
+ val partitionedAttrs = partitioning.asInstanceOf[Expression].references
+ assert(partitionedAttrs.subsetOf(plan.outputSet))
+ }
+
+ def checkReusedExchangeOutputPartitioningRewrite(
+ df: DataFrame,
+ expectedPartitioningClass: Class[_]): Unit = {
+ val reusedExchange = df.queryExecution.executedPlan.collect {
+ case r: ReusedExchangeExec => r
+ }
+ checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass)
+ }
+
+ def checkInMemoryTableScanOutputPartitioningRewrite(
+ df: DataFrame,
+ expectedPartitioningClass: Class[_]): Unit = {
+ val inMemoryScan = df.queryExecution.executedPlan.collect {
+ case m: InMemoryTableScanExec => m
+ }
+ checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass)
+ }
+
+ // ReusedExchange is HashPartitioning
+ val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+ val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+ checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning])
+
+ // ReusedExchange is RangePartitioning
+ val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
+ val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
+ checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning])
+
+ // InMemoryTableScan is HashPartitioning
+ Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
+ checkInMemoryTableScanOutputPartitioningRewrite(
+ Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning])
+
+ // InMemoryTableScan is RangePartitioning
+ spark.range(1, 100, 1, 10).toDF().persist()
+ checkInMemoryTableScanOutputPartitioningRewrite(
+ spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning])
+
+ // InMemoryTableScan is PartitioningCollection
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist()
+ checkInMemoryTableScanOutputPartitioningRewrite(
+ Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"),
+ classOf[PartitioningCollection])
+ }
+ }
}
// Used for unit-testing EnsureRequirements
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 4b3921c61a000..a8a4a524a97f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2427,4 +2427,53 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()),
Row(badJson))
}
+
+ test("SPARK-23772 ignore column of all null values or empty array during schema inference") {
+ withTempPath { tempDir =>
+ val path = tempDir.getAbsolutePath
+
+ // primitive types
+ Seq(
+ """{"a":null, "b":1, "c":3.0}""",
+ """{"a":null, "b":null, "c":"string"}""",
+ """{"a":null, "b":null, "c":null}""")
+ .toDS().write.text(path)
+ var df = spark.read.format("json")
+ .option("dropFieldIfAllNull", true)
+ .load(path)
+ var expectedSchema = new StructType()
+ .add("b", LongType).add("c", StringType)
+ assert(df.schema === expectedSchema)
+ checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil)
+
+ // arrays
+ Seq(
+ """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""",
+ """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""",
+ """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""")
+ .toDS().write.mode("overwrite").text(path)
+ df = spark.read.format("json")
+ .option("dropFieldIfAllNull", true)
+ .load(path)
+ expectedSchema = new StructType()
+ .add("a", ArrayType(LongType))
+ assert(df.schema === expectedSchema)
+ checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: Row(null) :: Nil)
+
+ // structs
+ Seq(
+ """{"a":{"a1": 1, "a2":"string"}, "b":{}}""",
+ """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""",
+ """{"a":null, "b":null}""")
+ .toDS().write.mode("overwrite").text(path)
+ df = spark.read.format("json")
+ .option("dropFieldIfAllNull", true)
+ .load(path)
+ expectedSchema = new StructType()
+ .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType)
+ :: Nil))
+ assert(df.schema === expectedSchema)
+ checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil)
+ }
+ }
}