Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HUDI-7305] Fix cast exception for byte/short/float partitioned field #10518

Merged
merged 7 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2334,6 +2334,43 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
})
}

test("Test various data types as partition fields") {
withRecordType()(withTempDir { tmp =>
val tableName = generateTableName
spark.sql(
s"""
|CREATE TABLE $tableName (
| id INT,
| boolean_field BOOLEAN,
| float_field FLOAT,
| byte_field BYTE,
| short_field SHORT,
| decimal_field DECIMAL(10, 5),
| date_field DATE,
| string_field STRING,
| timestamp_field TIMESTAMP
|) USING hudi
| TBLPROPERTIES (primaryKey = 'id')
| PARTITIONED BY (boolean_field, float_field, byte_field, short_field, decimal_field, date_field, string_field, timestamp_field)
|LOCATION '${tmp.getCanonicalPath}'
""".stripMargin)

// Insert data into partitioned table
spark.sql(
s"""
|INSERT INTO $tableName VALUES
|(1, TRUE, CAST(1.0 as FLOAT), 1, 1, 1234.56789, DATE '2021-01-05', 'partition1', TIMESTAMP '2021-01-05 10:00:00'),
|(2, FALSE,CAST(2.0 as FLOAT), 2, 2, 6789.12345, DATE '2021-01-06', 'partition2', TIMESTAMP '2021-01-06 11:00:00')
""".stripMargin)

checkAnswer(s"SELECT id, boolean_field FROM $tableName ORDER BY id")(
Seq(1, true),
Seq(2, false)
)
})
}


def ingestAndValidateDataDupPolicy(tableType: String, tableName: String, tmp: File,
expectedOperationtype: WriteOperationType = WriteOperationType.INSERT,
setOptions: List[String] = List.empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.datasources
import org.apache.hadoop.fs.Path
import org.apache.hudi.common.util.PartitionPathEncodeUtils.DEFAULT_PARTITION_PATH
import org.apache.hudi.spark3.internal.ReflectUtil
import org.apache.hudi.util.JFunction
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
Expand All @@ -29,10 +28,9 @@ import org.apache.spark.sql.execution.datasources.PartitioningUtils.timestampPar
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import java.lang.{Boolean => JBoolean, Double => JDouble, Long => JLong}
import java.lang.{Double => JDouble, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import java.time.ZoneId
import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.{Locale, TimeZone}
import scala.collection.convert.Wrappers.JConcurrentMapWrapper
Expand Down Expand Up @@ -259,10 +257,12 @@ object Spark3ParsePartitionUtil extends SparkParsePartitionUtil {
zoneId: ZoneId): Any = desiredType match {
case _ if value == DEFAULT_PARTITION_PATH => null
case NullType => null
case BooleanType => JBoolean.parseBoolean(value)
case StringType => UTF8String.fromString(unescapePathName(value))
case ByteType => Integer.parseInt(value).toByte
case ShortType => Integer.parseInt(value).toShort
case IntegerType => Integer.parseInt(value)
case LongType => JLong.parseLong(value)
case FloatType => JDouble.parseDouble(value).toFloat
case DoubleType => JDouble.parseDouble(value)
case _: DecimalType => Literal(new JBigDecimal(value)).value
case DateType =>
Expand All @@ -274,6 +274,8 @@ object Spark3ParsePartitionUtil extends SparkParsePartitionUtil {
}.getOrElse {
Cast(Cast(Literal(value), DateType, Some(zoneId.getId)), dt).eval()
}
case BinaryType => value.getBytes()
case BooleanType => value.toBoolean
case dt => throw new IllegalArgumentException(s"Unexpected type $dt")
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, can we write a UT for it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure.

Expand Down
Loading