From b97af11547c65cb20ab869e0d81c5f869a6935ec Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 20 May 2015 15:48:24 -0700
Subject: [PATCH] fix MapType in JsonRDD

---
 python/pyspark/sql/tests.py                        | 22 ++++++++++++++++++++++
 .../apache/spark/sql/json/JacksonParser.scala      |  8 ++++----
 .../org/apache/spark/sql/json/JsonRDD.scala        | 13 ++++++++-----
 3 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7e349962416c9..42e91520a6781 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -512,6 +512,28 @@ def test_save_and_load(self):
 
         shutil.rmtree(tmpPath)
 
+    def test_json_with_map(self):
+        # regression test for SPARK-7565
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        rdd = self.sc.parallelize(['{"obj": {"a": "hello"}}', '{"obj": {"b": "world"}}'], 1)
+        schema = StructType([StructField("obj", MapType(StringType(), StringType()), True)])
+        df = self.sqlCtx.jsonRDD(rdd, schema)
+        rs = [({'a': 'hello'},), ({'b': 'world'},)]
+        self.assertEqual(rs, list(map(tuple, df.collect())))
+        df.write.parquet(tmpPath)
+        df2 = self.sqlCtx.read.parquet(tmpPath)
+        self.assertEqual(rs, list(map(tuple, df2.collect())))
+
+        self.sqlCtx.setConf("spark.sql.json.useJacksonStreamingAPI", "false")
+        df3 = self.sqlCtx.jsonRDD(rdd, schema)
+        self.assertEqual(rs, list(map(tuple, df3.collect())))
+        rs = list(map(tuple, df.collect()))
+        df.write.parquet(tmpPath, 'overwrite')
+        df4 = self.sqlCtx.read.parquet(tmpPath)
+        self.assertEqual(rs, list(map(tuple, df4.collect())))
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 81611513582a8..0e223758051a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -150,10 +150,10 @@ private[sql] object JacksonParser {
   private def convertMap(
       factory: JsonFactory,
       parser: JsonParser,
-      valueType: DataType): Map[String, Any] = {
-    val builder = Map.newBuilder[String, Any]
+      valueType: DataType): Map[UTF8String, Any] = {
+    val builder = Map.newBuilder[UTF8String, Any]
     while (nextUntil(parser, JsonToken.END_OBJECT)) {
-      builder += parser.getCurrentName -> convertField(factory, parser, valueType)
+      builder += UTF8String(parser.getCurrentName) -> convertField(factory, parser, valueType)
     }
 
     builder.result()
@@ -181,7 +181,7 @@ private[sql] object JacksonParser {
       val row = new GenericMutableRow(schema.length)
       for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) {
         require(schema(corruptIndex).dataType == StringType)
-        row.update(corruptIndex, record)
+        row.update(corruptIndex, UTF8String(record))
       }
 
       Seq(row)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 4c32710a17bc7..036a4b08ca27a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,18 +20,18 @@ package org.apache.spark.sql.json
 import java.sql.Timestamp
 
 import scala.collection.Map
-import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
+import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
 
-import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException}
+import com.fasterxml.jackson.core.JsonProcessingException
 import com.fasterxml.jackson.databind.ObjectMapper
 
+import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.util.DateUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.Logging
 
 private[sql] object JsonRDD extends Logging {
 
@@ -422,7 +422,10 @@ private[sql] object JsonRDD extends Logging {
         value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
       case MapType(StringType, valueType, _) =>
         val map = value.asInstanceOf[Map[String, Any]]
-        map.mapValues(enforceCorrectType(_, valueType)).map(identity)
+        map.map {
+          case (k, v) =>
+            (UTF8String(k), enforceCorrectType(v, valueType))
+        }.map(identity)
       case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
      case DateType => toDate(value)
       case TimestampType => toTimestamp(value)
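
A note on the `.map(identity)` retained in the last JsonRDD hunk: in the Scala versions Spark used at the time (2.10/2.11), `Map.mapValues` returns a lazy view that re-applies the function on every lookup, so the old code forced a strict copy with `.map(identity)`; the new strict `map` makes that call a conservative no-op. A minimal, Spark-independent sketch of the difference (assuming Scala 2.11; `timesTen` and the call counter are illustrative only, not from the patch):

object MapValuesLaziness extends App {
  var calls = 0
  def timesTen(v: Int): Int = { calls += 1; v * 10 }

  val m = Map("a" -> 1, "b" -> 2)

  // mapValues returns a view: timesTen runs again on every access.
  val view = m.mapValues(timesTen)
  view("a"); view("a"); view("a")
  println(calls) // 3 -- recomputed per lookup

  // .map(identity) forces a strict Map: each entry is computed once.
  calls = 0
  val strict = m.mapValues(timesTen).map(identity)
  strict("a"); strict("a"); strict("a")
  println(calls) // 2 -- one call per entry, at construction time
}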