Skip to content

Commit

Permalink
[HUDI-7305] Fix cast exception for byte/short/float partitioned field (
Browse files Browse the repository at this point in the history
  • Loading branch information
stream2000 authored Jan 19, 2024
1 parent 4073664 commit 696911e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2334,6 +2334,43 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
})
}

// Regression test for HUDI-7305: reading back a table partitioned by byte/short/float
// columns previously hit a cast exception when partition-path strings were converted
// back to their declared Catalyst types. One partition column is declared per data
// type so the partition-value parsing path is exercised for each of them.
test("Test various data types as partition fields") {
withRecordType()(withTempDir { tmp =>
val tableName = generateTableName
// Every column except `id` is a partition field, covering boolean, float, byte,
// short, decimal, date, string and timestamp partition-path round-trips.
spark.sql(
s"""
|CREATE TABLE $tableName (
| id INT,
| boolean_field BOOLEAN,
| float_field FLOAT,
| byte_field BYTE,
| short_field SHORT,
| decimal_field DECIMAL(10, 5),
| date_field DATE,
| string_field STRING,
| timestamp_field TIMESTAMP
|) USING hudi
| TBLPROPERTIES (primaryKey = 'id')
| PARTITIONED BY (boolean_field, float_field, byte_field, short_field, decimal_field, date_field, string_field, timestamp_field)
|LOCATION '${tmp.getCanonicalPath}'
""".stripMargin)

// Insert data into partitioned table
spark.sql(
s"""
|INSERT INTO $tableName VALUES
|(1, TRUE, CAST(1.0 as FLOAT), 1, 1, 1234.56789, DATE '2021-01-05', 'partition1', TIMESTAMP '2021-01-05 10:00:00'),
|(2, FALSE,CAST(2.0 as FLOAT), 2, 2, 6789.12345, DATE '2021-01-06', 'partition2', TIMESTAMP '2021-01-06 11:00:00')
""".stripMargin)

// Reading back forces each partition-path value to be parsed/cast to its declared
// type; before the fix this query raised a cast exception rather than returning rows.
checkAnswer(s"SELECT id, boolean_field FROM $tableName ORDER BY id")(
Seq(1, true),
Seq(2, false)
)
})
}


def ingestAndValidateDataDupPolicy(tableType: String, tableName: String, tmp: File,
expectedOperationtype: WriteOperationType = WriteOperationType.INSERT,
setOptions: List[String] = List.empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.datasources
import org.apache.hadoop.fs.Path
import org.apache.hudi.common.util.PartitionPathEncodeUtils.DEFAULT_PARTITION_PATH
import org.apache.hudi.spark3.internal.ReflectUtil
import org.apache.hudi.util.JFunction
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
Expand All @@ -29,10 +28,9 @@ import org.apache.spark.sql.execution.datasources.PartitioningUtils.timestampPar
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import java.lang.{Boolean => JBoolean, Double => JDouble, Long => JLong}
import java.lang.{Double => JDouble, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import java.time.ZoneId
import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.{Locale, TimeZone}
import scala.collection.convert.Wrappers.JConcurrentMapWrapper
Expand Down Expand Up @@ -259,10 +257,12 @@ object Spark3ParsePartitionUtil extends SparkParsePartitionUtil {
zoneId: ZoneId): Any = desiredType match {
case _ if value == DEFAULT_PARTITION_PATH => null
case NullType => null
case BooleanType => JBoolean.parseBoolean(value)
case StringType => UTF8String.fromString(unescapePathName(value))
case ByteType => Integer.parseInt(value).toByte
case ShortType => Integer.parseInt(value).toShort
case IntegerType => Integer.parseInt(value)
case LongType => JLong.parseLong(value)
case FloatType => JDouble.parseDouble(value).toFloat
case DoubleType => JDouble.parseDouble(value)
case _: DecimalType => Literal(new JBigDecimal(value)).value
case DateType =>
Expand All @@ -274,6 +274,8 @@ object Spark3ParsePartitionUtil extends SparkParsePartitionUtil {
}.getOrElse {
Cast(Cast(Literal(value), DateType, Some(zoneId.getId)), dt).eval()
}
case BinaryType => value.getBytes()
case BooleanType => value.toBoolean
case dt => throw new IllegalArgumentException(s"Unexpected type $dt")
}

Expand Down

0 comments on commit 696911e

Please sign in to comment.