fix: [comet-parquet-exec] fix regressions in original comet native scan implementation (#1170)

* fix: CometScanExec was created for unsupported cases if only COMET_NATIVE_SCAN is enabled

* fix: Another try to fix test("Comet native metrics: BroadcastHashJoin")

* fix: some tests are valid only when full native scan is enabled

* Merge pull request #1 from andygrove/fix-tests-spark-cast-options
parthchandra authored Dec 13, 2024
1 parent 06cdd22 commit 8563edf
Showing 6 changed files with 28 additions and 16 deletions.
@@ -272,7 +272,7 @@ public void init() throws URISyntaxException, IOException {
requestedSchema =
CometParquetReadSupport.clipParquetSchema(
requestedSchema, sparkSchema, isCaseSensitive, useFieldId, ignoreMissingIds);
if (requestedSchema.getColumns().size() != sparkSchema.size()) {
if (requestedSchema.getFieldCount() != sparkSchema.size()) {
throw new IllegalArgumentException(
String.format(
"Spark schema has %d columns while " + "Parquet schema has %d columns",
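A note on the check above: requestedSchema.getColumns().size() counts the flattened leaf columns of the Parquet schema, while getFieldCount() counts top-level fields, which is what sparkSchema.size() measures on the Spark side. The two agree for flat schemas but diverge once nested (struct) columns are requested, so the old comparison could throw spuriously. A minimal standalone sketch with parquet-mr (not Comet code) showing the difference:

```scala
import org.apache.parquet.schema.MessageTypeParser

object FieldCountVsLeafColumns extends App {
  // Two top-level fields, but three leaf columns once the nested group is flattened.
  val schema = MessageTypeParser.parseMessageType(
    """message spark_schema {
      |  required int32 id;
      |  optional group point {
      |    optional double x;
      |    optional double y;
      |  }
      |}""".stripMargin)

  println(schema.getFieldCount())      // 2 -- matches StructType.size for the same schema
  println(schema.getColumns().size())  // 3 -- leaf columns; diverges for nested types
}
```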
@@ -245,7 +245,7 @@ public void init() throws URISyntaxException, IOException {
requestedSchema =
CometParquetReadSupport.clipParquetSchema(
requestedSchema, sparkSchema, isCaseSensitive, useFieldId, ignoreMissingIds);
if (requestedSchema.getColumns().size() != sparkSchema.size()) {
if (requestedSchema.getFieldCount() != sparkSchema.size()) {
throw new IllegalArgumentException(
String.format(
"Spark schema has %d columns while " + "Parquet schema has %d columns",
@@ -267,9 +267,9 @@ public void init() throws URISyntaxException, IOException {
// ShimFileFormat.findRowIndexColumnIndexInSchema(sparkSchema);
for (int i = 0; i < requestedSchema.getFieldCount(); i++) {
Type t = requestedSchema.getFields().get(i);
Preconditions.checkState(
t.isPrimitive() && !t.isRepetition(Type.Repetition.REPEATED),
"Complex type is not supported");
// Preconditions.checkState(
// t.isPrimitive() && !t.isRepetition(Type.Repetition.REPEATED),
// "Complex type is not supported");
String[] colPath = paths.get(i);
if (nonPartitionFields[i].name().equals(ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME())) {
// Values of ROW_INDEX_TEMPORARY_COLUMN_NAME column are always populated with
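The Preconditions.checkState that is commented out above is what previously rejected any field that was not a primitive or that was REPEATED, i.e. every complex column. Relaxing it lines up with the struct support gated in DataTypeSupport below. A standalone restatement of the removed guard using parquet-mr types (not Comet code):

```scala
import org.apache.parquet.schema.Type

// The condition the removed check enforced: only primitive, non-repeated fields pass,
// so struct, map, and list columns used to fail with "Complex type is not supported".
def wasSupportedField(t: Type): Boolean =
  t.isPrimitive && !t.isRepetition(Type.Repetition.REPEATED)
```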
16 changes: 8 additions & 8 deletions native/spark-expr/src/cast.rs
@@ -571,7 +571,7 @@ impl SparkCastOptions {
eval_mode,
timezone: timezone.to_string(),
allow_incompat,
is_adapting_schema: false,
is_adapting_schema: false
}
}

@@ -583,6 +583,7 @@ impl SparkCastOptions {
is_adapting_schema: false,
}
}

}

/// Spark-compatible cast implementation. Defers to DataFusion's cast where that is known
@@ -2087,7 +2088,7 @@ mod tests {

let timezone = "UTC".to_string();
// test casting string dictionary array to timestamp array
let cast_options = SparkCastOptions::new(EvalMode::Legacy, timezone.clone(), false);
let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false);
let result = cast_array(
dict_array,
&DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
@@ -2296,7 +2297,7 @@
fn test_cast_unsupported_timestamp_to_date() {
// Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported
let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC".to_string(), false);
let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
let result = cast_array(
Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
&DataType::Date32,
@@ -2309,7 +2310,7 @@
fn test_cast_invalid_timezone() {
let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
let cast_options =
SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone".to_string(), false);
SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
let result = cast_array(
Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
&DataType::Date32,
@@ -2335,7 +2336,7 @@
let string_array = cast_array(
c,
&DataType::Utf8,
&SparkCastOptions::new(EvalMode::Legacy, "UTC".to_owned(), false),
&SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
)
.unwrap();
let string_array = string_array.as_string::<i32>();
@@ -2400,10 +2401,9 @@
let cast_array = spark_cast(
ColumnarValue::Array(c),
&DataType::Struct(fields),
EvalMode::Legacy,
&SparkCastOptions::new(EvalMode::Legacy,
"UTC",
false,
false,
false)
)
.unwrap();
if let ColumnarValue::Array(cast_array) = cast_array {
5 changes: 4 additions & 1 deletion spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
@@ -39,7 +39,10 @@ trait DataTypeSupport {
BinaryType | StringType | _: DecimalType | DateType | TimestampType =>
true
case t: DataType if t.typeName == "timestamp_ntz" => true
case _: StructType => true
case _: StructType
if CometConf.COMET_FULL_NATIVE_SCAN_ENABLED
.get() || CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.get() =>
true
case _ => false
}

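The guard added above means struct columns are only claimed as supported when one of the scan implementations that can actually read them is enabled, so the plain native scan falls back for such schemas (the first fix listed in the commit message). A simplified restatement of the rule, with the config lookups replaced by plain booleans and the primitive-type cases abbreviated to those visible in the hunk:

```scala
import org.apache.spark.sql.types._

// Simplified sketch, not the actual trait: the booleans stand in for
// CometConf.COMET_FULL_NATIVE_SCAN_ENABLED / COMET_NATIVE_ARROW_SCAN_ENABLED lookups.
def isTypeSupported(dt: DataType, fullNativeScan: Boolean, nativeArrowScan: Boolean): Boolean =
  dt match {
    case BinaryType | StringType | _: DecimalType | DateType | TimestampType => true
    case t: DataType if t.typeName == "timestamp_ntz" => true
    case _: StructType if fullNativeScan || nativeArrowScan => true
    case _ => false
  }
```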
@@ -131,8 +131,15 @@ case class CometScanExec(
// exposed for testing
lazy val bucketedScan: Boolean = wrapped.bucketedScan

override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) =
(wrapped.outputPartitioning, wrapped.outputOrdering)
override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = {
if (bucketedScan) {
(wrapped.outputPartitioning, wrapped.outputOrdering)
} else {
val files = selectedPartitions.flatMap(partition => partition.files)
val numPartitions = files.length
(UnknownPartitioning(numPartitions), wrapped.outputOrdering)
}
}

@transient
private lazy val pushedDownFilters = getPushedDownFilters(relation, dataFilters)
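For context on the change above: a non-bucketed CometScanExec no longer forwards the wrapped scan's output partitioning and instead reports UnknownPartitioning over the number of selected files, so downstream planning treats the scan as providing no partitioning guarantee; bucketed scans are unchanged. A reduced sketch of the branch (simplified signature, not the actual class):

```scala
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}

// Reduced sketch: bucketed scans keep the wrapped plan's partitioning; otherwise only a
// partition count is reported and Spark plans exchanges as if nothing is guaranteed.
def reportedPartitioning(
    bucketedScan: Boolean,
    wrappedPartitioning: Partitioning,
    selectedFileCount: Int): Partitioning =
  if (bucketedScan) wrappedPartitioning else UnknownPartitioning(selectedFileCount)
```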
@@ -2217,6 +2217,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withSQLConf(
SQLConf.USE_V1_SOURCE_LIST.key -> v1List,
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key -> "true",
CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {

val df = spark.read.parquet(dir.toString())
@@ -2249,6 +2250,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withSQLConf(
SQLConf.USE_V1_SOURCE_LIST.key -> v1List,
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key -> "true",
CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {

val df = spark.read.parquet(dir.toString())
