Skip to content

Commit

Permalink
feat: support array_append (#1072)
Browse files Browse the repository at this point in the history
* feat: support array_append

* formatted code

* rewrite array_append plan to match Spark behaviour and fix bug in QueryPlan serde

* remove unwrap

* Fix for Spark 3.3

* refactor array_append binary expression serde code

* Disabled array_append test for Spark 4.0+
  • Loading branch information
NoeB authored Nov 13, 2024
1 parent 712658e commit 9657b75
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 2 deletions.
31 changes: 30 additions & 1 deletion native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ use datafusion::{
},
prelude::SessionContext,
};
use datafusion_functions_nested::concat::ArrayAppend;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};

use datafusion_comet_proto::{
Expand All @@ -107,7 +108,8 @@ use datafusion_common::{
};
use datafusion_expr::expr::find_df_window_func;
use datafusion_expr::{
AggregateUDF, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
AggregateUDF, ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion_physical_expr::expressions::{Literal, StatsType};
use datafusion_physical_expr::window::WindowExpr;
Expand Down Expand Up @@ -691,6 +693,33 @@ impl PhysicalPlanner {
expr.ordinal as usize,
)))
}
ExprStruct::ArrayAppend(expr) => {
let left =
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
let right =
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
let return_type = left.data_type(&input_schema)?;
let args = vec![Arc::clone(&left), right];
let datafusion_array_append =
Arc::new(ScalarUDF::new_from_impl(ArrayAppend::new()));
let array_append_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
"array_append",
datafusion_array_append,
args,
return_type,
));

let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(left));
let null_literal_expr: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new(ScalarValue::Null));

let case_expr = CaseExpr::try_new(
None,
vec![(is_null_expr, null_literal_expr)],
Some(array_append_expr),
)?;
Ok(Arc::new(case_expr))
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ message Expr {
ToJson to_json = 55;
ListExtract list_extract = 56;
GetArrayStructFields get_array_struct_fields = 57;
BinaryExpr array_append = 58;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2237,7 +2237,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
None
}

case _ if expr.prettyName == "array_append" =>
createBinaryExpr(
expr.children(0),
expr.children(1),
inputs,
(builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
Expand Down
24 changes: 24 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2313,4 +2313,28 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

  test("array_append") {
    // array_append was added in Spark 3.4; from Spark 4.0 the optimizer rewrites it to
    // ArrayInsert, so this test only runs on [3.4, 4.0).
    assume(isSpark34Plus && !isSpark40Plus)
    // Exercise both dictionary-encoded and plain Parquet encodings.
    Seq(true, false).foreach { dictionaryEnabled =>
      withTempDir { dir =>
        val path = new Path(dir.toURI.toString, "test.parquet")
        makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
        spark.read.parquet(path.toString).createOrReplaceTempView("t1");
        // Boolean element appended to a single-element array.
        checkSparkAnswerAndOperator(spark.sql("Select array_append(array(_1),false) from t1"))
        // Integer-typed columns with an integer literal.
        checkSparkAnswerAndOperator(
          spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1"))
        // Appending a null element must match Spark's null semantics.
        checkSparkAnswerAndOperator(
          spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM t1"));
        // Floating-point columns with an explicit DOUBLE cast.
        checkSparkAnswerAndOperator(
          spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"));
        // String element.
        checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_8), 'test') FROM t1"));
        // Append a column value to an array built from the same column.
        checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_19), _19) FROM t1"));
        // The base array itself may be NULL (CASE with no ELSE); result must then be NULL,
        // matching the null-propagating CASE wrapper in the native planner.
        checkSparkAnswerAndOperator(
          spark.sql("SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
      }

    }
  }
}

0 comments on commit 9657b75

Please sign in to comment.