Commit b97af11

fix MapType in JsonRDD
Davies Liu committed May 20, 2015
1 parent 6338c40 commit b97af11
Showing 3 changed files with 34 additions and 9 deletions.
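
The underlying issue (SPARK-7565): around this release, Spark SQL moved its internal representation of StringType data from java.lang.String to UTF8String, but the JSON paths still produced plain String for map keys and corrupt-record text, so a MapType column read from JSON broke downstream, e.g. when written to Parquet. A rough Scala repro, assuming a 1.4-era sc/sqlContext and a scratch path chosen here for illustration:

    import org.apache.spark.sql.types._

    val rdd = sc.parallelize(Seq("""{"obj": {"a": "hello"}}""", """{"obj": {"b": "world"}}"""))
    val schema = StructType(Seq(StructField("obj", MapType(StringType, StringType), true)))
    val df = sqlContext.jsonRDD(rdd, schema)
    df.write.parquet("/tmp/spark-7565-repro")  // failed at runtime before this commit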
22 changes: 22 additions & 0 deletions python/pyspark/sql/tests.py
@@ -512,6 +512,28 @@ def test_save_and_load(self):
 
         shutil.rmtree(tmpPath)
 
+    def test_json_with_map(self):
+        # regression test for SPARK-7565
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        rdd = self.sc.parallelize(['{"obj": {"a": "hello"}}', '{"obj": {"b": "world"}}'], 1)
+        schema = StructType([StructField("obj", MapType(StringType(), StringType()), True)])
+        df = self.sqlCtx.jsonRDD(rdd, schema)
+        rs = [({'a': 'hello'},), ({'b': 'world'},)]
+        self.assertEqual(rs, list(map(tuple, df.collect())))
+        df.write.parquet(tmpPath)
+        df2 = self.sqlCtx.read.parquet(tmpPath)
+        self.assertEqual(rs, list(map(tuple, df2.collect())))
+
+        self.sqlCtx.setConf("spark.sql.json.useJacksonStreamingAPI", "false")
+        df3 = self.sqlCtx.jsonRDD(rdd, schema)
+        self.assertEqual(rs, list(map(tuple, df3.collect())))
+        rs = list(map(tuple, df.collect()))
+        df.write.parquet(tmpPath, 'overwrite')
+        df4 = self.sqlCtx.read.parquet(tmpPath)
+        self.assertEqual(rs, list(map(tuple, df4.collect())))
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
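
The second half of the test re-runs the read with spark.sql.json.useJacksonStreamingAPI set to "false": at this point Spark had two JSON parsers, the Jackson streaming one (JacksonParser.scala) and the legacy JsonRDD path, and this commit patches both. A sketch of the same double-check in Scala, reusing rdd and schema from the repro above:

    sqlContext.setConf("spark.sql.json.useJacksonStreamingAPI", "false")  // legacy JsonRDD path
    val legacy = sqlContext.jsonRDD(rdd, schema).collect()
    sqlContext.setConf("spark.sql.json.useJacksonStreamingAPI", "true")   // Jackson streaming path (default)
    val streaming = sqlContext.jsonRDD(rdd, schema).collect()
    assert(legacy.sameElements(streaming))  // both parsers must now agree on the map column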
8 changes: 4 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -150,10 +150,10 @@ private[sql] object JacksonParser {
   private def convertMap(
       factory: JsonFactory,
       parser: JsonParser,
-      valueType: DataType): Map[String, Any] = {
-    val builder = Map.newBuilder[String, Any]
+      valueType: DataType): Map[UTF8String, Any] = {
+    val builder = Map.newBuilder[UTF8String, Any]
     while (nextUntil(parser, JsonToken.END_OBJECT)) {
-      builder += parser.getCurrentName -> convertField(factory, parser, valueType)
+      builder += UTF8String(parser.getCurrentName) -> convertField(factory, parser, valueType)
     }
 
     builder.result()
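
The signature change is the heart of the fix: convertMap now returns Map[UTF8String, Any] because the keys of a MapType(StringType, ...) column are StringType data and must use the internal string type. A minimal sketch of the invariant, assuming the 1.4-era org.apache.spark.sql.types.UTF8String and its apply factory seen in the diff (string-typed values are wrapped separately, by convertField):

    import org.apache.spark.sql.types.UTF8String

    // Keys arrive from Jackson as java.lang.String and must be wrapped before
    // the map is stored in an internal row.
    def toInternalKeys(parsed: Map[String, Any]): Map[UTF8String, Any] =
      parsed.map { case (k, v) => (UTF8String(k), v) }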
@@ -181,7 +181,7 @@ private[sql] object JacksonParser {
     val row = new GenericMutableRow(schema.length)
     for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) {
       require(schema(corruptIndex).dataType == StringType)
-      row.update(corruptIndex, record)
+      row.update(corruptIndex, UTF8String(record))
     }
 
     Seq(row)
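
The second hunk applies the same rule to corrupt-record handling: a malformed line is preserved as raw text in the configured corrupt-record column, which is required to be StringType, so that text must be stored as UTF8String too. A small usage sketch, assuming 1.4-era defaults ("_corrupt_record" is the default column name, not part of this diff):

    val bad = sc.parallelize(Seq("""{"a": 1}""", """{"a": """))  // second line is malformed
    val withCorrupt = sqlContext.read.json(bad)  // malformed text kept in "_corrupt_record"
    withCorrupt.collect()  // before the fix, this column held String while others held UTF8String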
13 changes: 8 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,18 +20,18 @@ package org.apache.spark.sql.json
 import java.sql.Timestamp
 
 import scala.collection.Map
-import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
+import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
 
-import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException}
+import com.fasterxml.jackson.core.JsonProcessingException
 import com.fasterxml.jackson.databind.ObjectMapper
 
+import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.util.DateUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.Logging
 
 private[sql] object JsonRDD extends Logging {

@@ -422,7 +422,10 @@ private[sql] object JsonRDD extends Logging {
           value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
         case MapType(StringType, valueType, _) =>
           val map = value.asInstanceOf[Map[String, Any]]
-          map.mapValues(enforceCorrectType(_, valueType)).map(identity)
+          map.map {
+            case (k, v) =>
+              (UTF8String(k), enforceCorrectType(v, valueType))
+          }.map(identity)
         case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
         case DateType => toDate(value)
         case TimestampType => toTimestamp(value)
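
One detail in the rewritten MapType case: the old code's trailing .map(identity) forced the lazy, non-serializable view returned by Scala's mapValues into a strict Map, and the rewrite keeps that defensive materialization while also converting the keys. A tiny illustration of why mapValues alone would not have been enough (plain Scala, not from the commit):

    // mapValues returns a lazy view: the function is re-applied on every access
    // and the view itself does not serialize well for shipping to executors.
    val view = Map("a" -> 1).mapValues(_ + 1)
    val strict = view.map(identity)  // materializes a strict, serializable Map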
