From e8f34c3ecd50fc3b5dcc4f491c7817d5ecfb02be Mon Sep 17 00:00:00 2001 From: stream2000 <18889897088@163.com> Date: Fri, 19 Jan 2024 10:12:43 +0800 Subject: [PATCH] [HUDI-7305] Fix cast exception for byte/short/float partitioned field (#10518) --- .../spark/sql/hudi/TestInsertTable.scala | 37 +++++++++++++++++++ .../Spark3ParsePartitionUtil.scala | 10 +++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala index e7324a1354fe5..ef62a69477228 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala @@ -2242,6 +2242,43 @@ class TestInsertTable extends HoodieSparkSqlTestBase { }) } + test("Test various data types as partition fields") { + withRecordType()(withTempDir { tmp => + val tableName = generateTableName + spark.sql( + s""" + |CREATE TABLE $tableName ( + | id INT, + | boolean_field BOOLEAN, + | float_field FLOAT, + | byte_field BYTE, + | short_field SHORT, + | decimal_field DECIMAL(10, 5), + | date_field DATE, + | string_field STRING, + | timestamp_field TIMESTAMP + |) USING hudi + | TBLPROPERTIES (primaryKey = 'id') + | PARTITIONED BY (boolean_field, float_field, byte_field, short_field, decimal_field, date_field, string_field, timestamp_field) + |LOCATION '${tmp.getCanonicalPath}' + """.stripMargin) + + // Insert data into partitioned table + spark.sql( + s""" + |INSERT INTO $tableName VALUES + |(1, TRUE, CAST(1.0 as FLOAT), 1, 1, 1234.56789, DATE '2021-01-05', 'partition1', TIMESTAMP '2021-01-05 10:00:00'), + |(2, FALSE,CAST(2.0 as FLOAT), 2, 2, 6789.12345, DATE '2021-01-06', 'partition2', TIMESTAMP '2021-01-06 11:00:00') + """.stripMargin) + + checkAnswer(s"SELECT id, boolean_field FROM $tableName ORDER BY id")( + Seq(1, true), + Seq(2, false) + ) + }) + } + + def ingestAndValidateDataDupPolicy(tableType: String, tableName: String, tmp: File, expectedOperationtype: WriteOperationType = WriteOperationType.INSERT, setOptions: List[String] = List.empty, diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala index ebe92a5a32a91..fca21d202a99c 100644 --- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala +++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.datasources import org.apache.hadoop.fs.Path import org.apache.hudi.common.util.PartitionPathEncodeUtils.DEFAULT_PARTITION_PATH import org.apache.hudi.spark3.internal.ReflectUtil -import org.apache.hudi.util.JFunction import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} @@ -29,10 +28,9 @@ import org.apache.spark.sql.execution.datasources.PartitioningUtils.timestampPar import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import java.lang.{Boolean => JBoolean, Double => JDouble, Long => JLong} +import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.time.ZoneId -import java.util import java.util.concurrent.ConcurrentHashMap import java.util.{Locale, TimeZone} import scala.collection.convert.Wrappers.JConcurrentMapWrapper @@ -259,10 +257,12 @@ object Spark3ParsePartitionUtil extends SparkParsePartitionUtil { zoneId: ZoneId): Any = desiredType match { case _ if value == DEFAULT_PARTITION_PATH => null case NullType => null - case BooleanType => JBoolean.parseBoolean(value) case StringType => UTF8String.fromString(unescapePathName(value)) + case ByteType => Integer.parseInt(value).toByte + case ShortType => Integer.parseInt(value).toShort case IntegerType => Integer.parseInt(value) case LongType => JLong.parseLong(value) + case FloatType => JDouble.parseDouble(value).toFloat case DoubleType => JDouble.parseDouble(value) case _: DecimalType => Literal(new JBigDecimal(value)).value case DateType => @@ -274,6 +274,8 @@ object Spark3ParsePartitionUtil extends SparkParsePartitionUtil { }.getOrElse { Cast(Cast(Literal(value), DateType, Some(zoneId.getId)), dt).eval() } + case BinaryType => value.getBytes() + case BooleanType => value.toBoolean case dt => throw new IllegalArgumentException(s"Unexpected type $dt") }