From 43a45e170577198b2c424e45f7c90dfa928031a7 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Thu, 10 Jul 2014 20:10:58 -0700
Subject: [PATCH] Remove sql.util.package introduced in a previous commit.

---
 .../org/apache/spark/sql/SQLContext.scala     |   6 +-
 .../org/apache/spark/sql/json/JsonRDD.scala   |  66 ++++---
 .../org/apache/spark/sql/util/package.scala   | 175 ------------------
 .../org/apache/spark/sql/json/JsonSuite.scala |  20 +-
 4 files changed, 57 insertions(+), 210 deletions(-)
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/util/package.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 99aaffe1f5ce4..024dc337cd047 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @group userf
    */
   def applySchema[A](rdd: RDD[A], schema: StructType, f: A => Row): SchemaRDD =
-    applySchemaPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f))
+    applySchemaToPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f))
   /**
    * Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function
    *
@@ -102,7 +102,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    *
    * @group userf
    */
-  def applySchemaPartitions[A](
+  def applySchemaToPartitions[A](
       rdd: RDD[A],
       schema: StructType,
       f: Iterator[A] => Iterator[Row]): SchemaRDD =
@@ -154,7 +154,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   @Experimental
   def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
     val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))
-    applySchemaPartitions(json, schema, JsonRDD.jsonStringToRow(schema, _: Iterator[String]))
+    applySchemaToPartitions(json, schema, JsonRDD.jsonStringToRow(schema, _: Iterator[String]))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index f8aba3d543932..bec741c96b678 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.Logging
-import org.apache.spark.sql.util
 
 private[sql] object JsonRDD extends Logging {
 
@@ -271,12 +270,29 @@ private[sql] object JsonRDD extends Logging {
     }
   }
 
-  private def toDecimalValue: PartialFunction[Any, BigDecimal] = {
-    def bigIntegerToDecimalValue: PartialFunction[Any, BigDecimal] = {
-      case v: java.math.BigInteger => BigDecimal(v)
+  private def toLong(value: Any): Long = {
+    value match {
+      case value: java.lang.Integer => value.asInstanceOf[Int].toLong
+      case value: java.lang.Long => value.asInstanceOf[Long]
     }
+  }
 
-    bigIntegerToDecimalValue orElse util.toDecimalValue
+  private def toDouble(value: Any): Double = {
+    value match {
+      case value: java.lang.Integer => value.asInstanceOf[Int].toDouble
+      case value: java.lang.Long => value.asInstanceOf[Long].toDouble
+      case value: java.lang.Double => value.asInstanceOf[Double]
+    }
+  }
+
+  private def toDecimal(value: Any): BigDecimal = {
+    value match {
+      case value: java.lang.Integer => BigDecimal(value)
+      case value: java.lang.Long => BigDecimal(value)
+      case value: java.math.BigInteger => BigDecimal(value)
+      case value: java.lang.Double => BigDecimal(value)
+      case value: java.math.BigDecimal => BigDecimal(value)
+    }
   }
 
   private def toJsonArrayString(seq: Seq[Any]): String = {
@@ -287,7 +303,7 @@ private[sql] object JsonRDD extends Logging {
       element =>
         if (count > 0) builder.append(",")
         count += 1
-        builder.append(toStringValue(element))
+        builder.append(toString(element))
     }
     builder.append("]")
 
@@ -302,31 +318,37 @@ private[sql] object JsonRDD extends Logging {
       case (key, value) =>
         if (count > 0) builder.append(",")
         count += 1
-        builder.append(s"""\"${key}\":${toStringValue(value)}""")
+        builder.append(s"""\"${key}\":${toString(value)}""")
     }
     builder.append("}")
 
     builder.toString()
   }
 
-  private def toStringValue: PartialFunction[Any, String] = {
-    def complexValueToStringValue: PartialFunction[Any, String] = {
-      case v: Map[String, Any] => toJsonObjectString(v)
-      case v: Seq[Any] => toJsonArrayString(v)
+  private def toString(value: Any): String = {
+    value match {
+      case value: Map[String, Any] => toJsonObjectString(value)
+      case value: Seq[Any] => toJsonArrayString(value)
+      case value => Option(value).map(_.toString).orNull
     }
-
-    complexValueToStringValue orElse util.toStringValue
  }
 
-  private[json] def castToType: PartialFunction[(Any, DataType), Any] = {
-    def jsonSpecificCast: PartialFunction[(Any, DataType), Any] = {
-      case (v, StringType) => toStringValue(v)
-      case (v, DecimalType) => toDecimalValue(v)
-      case (v, ArrayType(elementType)) =>
-        v.asInstanceOf[Seq[Any]].map(castToType(_, elementType))
+  private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any = {
+    if (value == null) {
+      null
+    } else {
+      desiredType match {
+        case ArrayType(elementType) =>
+          value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
+        case StringType => toString(value)
+        case IntegerType => value.asInstanceOf[IntegerType.JvmType]
+        case LongType => toLong(value)
+        case DoubleType => toDouble(value)
+        case DecimalType => toDecimal(value)
+        case BooleanType => value.asInstanceOf[BooleanType.JvmType]
+        case NullType => null
+      }
     }
-
-    jsonSpecificCast orElse util.castToType
   }
 
   private def asRow(json: Map[String,Any], schema: StructType): Row = {
@@ -348,7 +370,7 @@ private[sql] object JsonRDD extends Logging {
       // Other cases
       case (StructField(name, dataType, _), i) =>
         row.update(i, json.get(name).flatMap(v => Option(v)).map(
-          castToType(_, dataType)).getOrElse(null))
+          enforceCorrectType(_, dataType)).getOrElse(null))
     }
 
     row
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/package.scala
deleted file mode 100644
index 2be4d5cf53af2..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/package.scala
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import scala.math.BigDecimal
-
-import org.apache.spark.annotation.DeveloperApi
-
-package object util {
-
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toBooleanValue: PartialFunction[Any, BooleanType.JvmType] = {
-    case v: BooleanType.JvmType => v
-    case v: ByteType.JvmType if v == 1 => true
-    case v: ByteType.JvmType if v == 0 => false
-    case v: ShortType.JvmType if v == 1 => true
-    case v: ShortType.JvmType if v == 0 => false
-    case v: IntegerType.JvmType if v == 1 => true
-    case v: IntegerType.JvmType if v == 0 => false
-    case v: LongType.JvmType if v == 1 => true
-    case v: LongType.JvmType if v == 0 => false
-    case v: StringType.JvmType if v.toLowerCase == "true" => true
-    case v: StringType.JvmType if v.toLowerCase == "false" => false
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toStringValue: PartialFunction[Any, StringType.JvmType] = {
-    case v => Option(v).map(_.toString).orNull
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toByteValue: PartialFunction[Any, ByteType.JvmType] = {
-    case v: BooleanType.JvmType => if (v) 1.toByte else 0.toByte
-    case v: ByteType.JvmType => v
-    case v: StringType.JvmType => v.toByte
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toShortValue: PartialFunction[Any, ShortType.JvmType] = {
-    case v: BooleanType.JvmType => if (v) 1.toShort else 0.toShort
-    case v: ByteType.JvmType => v.toShort
-    case v: ShortType.JvmType => v
-    case v: StringType.JvmType => v.toShort
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toIntegerValue: PartialFunction[Any, IntegerType.JvmType] = {
-    case v: BooleanType.JvmType => if (v) 1 else 0
-    case v: ByteType.JvmType => v.toInt
-    case v: ShortType.JvmType => v.toInt
-    case v: IntegerType.JvmType => v
-    case v: StringType.JvmType => v.toInt
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toLongValue: PartialFunction[Any, LongType.JvmType] = {
-    case v: BooleanType.JvmType => if (v) 1.toLong else 0.toLong
-    case v: ByteType.JvmType => v.toLong
-    case v: ShortType.JvmType => v.toLong
-    case v: IntegerType.JvmType => v.toLong
-    case v: LongType.JvmType => v
-    // We can convert a Timestamp object to a Long because a Long representation of
-    // a Timestamp object has a clear meaning
-    // (milliseconds since January 1, 1970, 00:00:00 GMT).
-    case v: TimestampType.JvmType => v.getTime
-    case v: StringType.JvmType => v.toLong
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toFloatValue: PartialFunction[Any, FloatType.JvmType] = {
-    case v: BooleanType.JvmType => if (v) 1.toFloat else 0.toFloat
-    case v: ByteType.JvmType => v.toFloat
-    case v: ShortType.JvmType => v.toFloat
-    case v: IntegerType.JvmType => v.toFloat
-    case v: FloatType.JvmType => v
-    case v: StringType.JvmType => v.toFloat
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toDoubleValue: PartialFunction[Any, DoubleType.JvmType] = {
-    case v: BooleanType.JvmType => if (v) 1.toDouble else 0.toDouble
-    case v: ByteType.JvmType => v.toDouble
-    case v: ShortType.JvmType => v.toDouble
-    case v: IntegerType.JvmType => v.toDouble
-    case v: LongType.JvmType => v.toDouble
-    case v: FloatType.JvmType => v.toDouble
-    case v: DoubleType.JvmType => v
-    case v: StringType.JvmType => v.toDouble
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toDecimalValue: PartialFunction[Any, DecimalType.JvmType] = {
-    case v: BooleanType.JvmType => if (v) BigDecimal(1) else BigDecimal(0)
-    case v: ByteType.JvmType => BigDecimal(v)
-    case v: ShortType.JvmType => BigDecimal(v)
-    case v: IntegerType.JvmType => BigDecimal(v)
-    case v: LongType.JvmType => BigDecimal(v)
-    case v: FloatType.JvmType => BigDecimal(v)
-    case v: DoubleType.JvmType => BigDecimal(v)
-    case v: TimestampType.JvmType => BigDecimal(v.getTime)
-    case v: StringType.JvmType => BigDecimal(v)
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def toTimestampValue: PartialFunction[Any, TimestampType.JvmType] = {
-    case v: LongType.JvmType => new java.sql.Timestamp(v)
-    case v: TimestampType.JvmType => v
-    case v: StringType.JvmType => java.sql.Timestamp.valueOf(v)
-  }
-
-  /**
-   * :: DeveloperApi ::
-   */
-  @DeveloperApi
-  def castToType: PartialFunction[(Any, DataType), Any] = {
-    case (null, _) => null
-    case (_, NullType) => null
-    case (v, BooleanType) => toBooleanValue(v)
-    case (v, StringType) => toStringValue(v)
-    case (v, ByteType) => toByteValue(v)
-    case (v, ShortType) => toShortValue(v)
-    case (v, IntegerType) => toIntegerValue(v)
-    case (v, LongType) => toLongValue(v)
-    case (v, FloatType) => toFloatValue(v)
-    case (v, DoubleType) => toDoubleValue(v)
-    case (v, DecimalType) => toDecimalValue(v)
-    case (v, TimestampType) => toTimestampValue(v)
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 27391e6708076..e765cfc83a397 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.LeafNode
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.json.JsonRDD.{castToType, compatibleType}
+import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.test.TestSQLContext._
 
@@ -41,19 +41,19 @@ class JsonSuite extends QueryTest {
     }
 
     val intNumber: Int = 2147483647
-    checkTypePromotion(intNumber, castToType(intNumber, IntegerType))
-    checkTypePromotion(intNumber.toLong, castToType(intNumber, LongType))
-    checkTypePromotion(intNumber.toDouble, castToType(intNumber, DoubleType))
-    checkTypePromotion(BigDecimal(intNumber), castToType(intNumber, DecimalType))
+    checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType))
+    checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
+    checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType))
+    checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType))
 
     val longNumber: Long = 9223372036854775807L
-    checkTypePromotion(longNumber, castToType(longNumber, LongType))
-    checkTypePromotion(longNumber.toDouble, castToType(longNumber, DoubleType))
-    checkTypePromotion(BigDecimal(longNumber), castToType(longNumber, DecimalType))
+    checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType))
+    checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType))
+    checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType))
 
     val doubleNumber: Double = 1.7976931348623157E308d
-    checkTypePromotion(doubleNumber.toDouble, castToType(doubleNumber, DoubleType))
-    checkTypePromotion(BigDecimal(doubleNumber), castToType(doubleNumber, DecimalType))
+    checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
+    checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType))
   }
 
   test("Get compatible type") {
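
Note on the inlined coercion helpers: the JSON parser may hand back the same numeric column as java.lang.Integer, java.lang.Long, java.math.BigInteger, or java.lang.Double depending on each record's value, while schema inference settles on a single widest type per column; toLong/toDouble/toDecimal promote every parsed value to that type when rows are built. A standalone sketch of the same promotion logic (illustrative only; the object name is invented and this code is not part of the commit):

object TypePromotionSketch {
  // Mirrors JsonRDD.toLong: widen a boxed Integer or Long to a primitive Long.
  def toLong(value: Any): Long = value match {
    case v: java.lang.Integer => v.intValue.toLong
    case v: java.lang.Long => v.longValue
  }

  // Mirrors JsonRDD.toDecimal: every numeric representation fits a DecimalType.
  def toDecimal(value: Any): BigDecimal = value match {
    case v: java.lang.Integer => BigDecimal(v.intValue)
    case v: java.lang.Long => BigDecimal(v.longValue)
    case v: java.math.BigInteger => BigDecimal(new java.math.BigDecimal(v))
    case v: java.lang.Double => BigDecimal(v.doubleValue)
    case v: java.math.BigDecimal => BigDecimal(v)
  }

  def main(args: Array[String]): Unit = {
    // Int-sized and Long-sized values both land in a LongType column.
    assert(toLong(Int.MaxValue.asInstanceOf[AnyRef]) == 2147483647L)
    assert(toLong(Long.MaxValue.asInstanceOf[AnyRef]) == 9223372036854775807L)
    // A literal wider than Long still fits a DecimalType column.
    assert(toDecimal(new java.math.BigInteger("92233720368547758070")) ==
      BigDecimal("92233720368547758070"))
    println("type promotion sketch OK")
  }
}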
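
For orientation, a minimal driver showing how the renamed entry point is reached: jsonRDD infers a schema from the input and hands whole partitions to applySchemaToPartitions, so JsonRDD.jsonStringToRow converts each partition's strings to Rows and any per-partition setup can be paid once per partition rather than per record. This is a sketch, not part of the commit; it assumes a local master and the Spark 1.0-era SchemaRDD API (registerAsTable):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object JsonUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("json-usage"))
    val sqlContext = new SQLContext(sc)

    // "a" mixes an int-sized and a long-sized literal: inference widens the
    // column to LongType, and enforceCorrectType promotes the narrower value.
    val json = sc.parallelize(Seq(
      """{"a": 1}""",
      """{"a": 9223372036854775807}"""))

    val schemaRDD = sqlContext.jsonRDD(json, 1.0)  // samplingRatio = 1.0: scan all records
    schemaRDD.registerAsTable("records")
    sqlContext.sql("SELECT a FROM records").collect().foreach(println)

    sc.stop()
  }
}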