From c94cac82ea84f27fc3f95e9f9a193198ae1cccb8 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 20 Apr 2024 10:03:23 +0200 Subject: [PATCH 01/19] validate input/output of udf --- .../user_defined_scalar_functions.rs | 40 ++++++++++++++++++- .../physical-expr/src/scalar_function.rs | 13 +++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 86be887198ae..08f93cac44d0 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,12 +15,15 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::Date32Builder; use arrow::compute::kernels::numeric::add; use arrow_array::{ - Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, + Array, ArrayRef, Date32Array, Float32Array, Float64Array, Int32Array, RecordBatch, + StringArray, UInt8Array, }; use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; +use chrono::DateTime; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; @@ -29,7 +32,7 @@ use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array, cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue, }; -use datafusion_common::{exec_err, internal_err, DataFusionError}; +use datafusion_common::{assert_contains, downcast_value, exec_err, internal_err, DataFusionError}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ @@ -168,6 +171,39 @@ async fn scalar_udf() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1,2]))], + )?; + + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + // udf that always return 1 row + let buggy_udf = Arc::new(|_: &[ColumnarValue]| { + println!("here"); + Ok(ColumnarValue::Array(Arc::new(Int32Array::from(vec![0])))) + }); + + ctx.register_udf(create_udf( + "buggy_func", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + buggy_udf, + )); + assert_contains!( + ctx.sql("select buggy_func(a) from t").await?.show().await.err().unwrap().to_string(), + "UDF returned a different number of rows than expected" + ); + Ok(()) +} + #[tokio::test] async fn scalar_udf_zero_params() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index d34084236690..008fa57deba1 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -177,7 +177,18 @@ impl PhysicalExpr for ScalarFunctionExpr { let fun = create_physical_fun(fun)?; (fun)(&inputs) } - ScalarFunctionDefinition::UDF(ref fun) => fun.invoke(&inputs), + ScalarFunctionDefinition::UDF(ref fun) => { + let output = fun.invoke(&inputs)?; + let output_count = match &output { + ColumnarValue::Array(array) => array.len(), + ColumnarValue::Scalar(_) => 1, + }; + if output_count != batch.num_rows() { + return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", + batch.num_rows(), output_count); + } + Ok(output) + } ScalarFunctionDefinition::Name(_) => { internal_err!( "Name function must be resolved to one of the other variants prior to physical planning" From 5d67c73bb6bdd2f632fa9f42b0e523ee9a5882fe Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 20 Apr 2024 10:11:28 +0200 Subject: [PATCH 02/19] clip --- .../tests/user_defined/user_defined_scalar_functions.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 08f93cac44d0..1ae049034781 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::Date32Builder; + use arrow::compute::kernels::numeric::add; use arrow_array::{ - Array, ArrayRef, Date32Array, Float32Array, Float64Array, Int32Array, RecordBatch, - StringArray, UInt8Array, + Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, }; use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; -use chrono::DateTime; + use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; @@ -32,7 +31,7 @@ use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array, cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue, }; -use datafusion_common::{assert_contains, downcast_value, exec_err, internal_err, DataFusionError}; +use datafusion_common::{assert_contains, exec_err, internal_err, DataFusionError}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ From 5587b03f5c5b27b836fd9c158665268d05442fba Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 20 Apr 2024 10:14:19 +0200 Subject: [PATCH 03/19] fmt --- .../user_defined/user_defined_scalar_functions.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 1ae049034781..ff8628607cdc 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - use arrow::compute::kernels::numeric::add; use arrow_array::{ Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, @@ -176,7 +175,7 @@ async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> { let batch = RecordBatch::try_new( Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1,2]))], + vec![Arc::new(Int32Array::from(vec![1, 2]))], )?; let ctx = SessionContext::new(); @@ -197,7 +196,13 @@ async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> { buggy_udf, )); assert_contains!( - ctx.sql("select buggy_func(a) from t").await?.show().await.err().unwrap().to_string(), + ctx.sql("select buggy_func(a) from t") + .await? + .show() + .await + .err() + .unwrap() + .to_string(), "UDF returned a different number of rows than expected" ); Ok(()) From bba9bfe777edd9f391b97b6a6f917d2241ce2d4a Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 20 Apr 2024 10:16:00 +0200 Subject: [PATCH 04/19] clean garbage --- .../core/tests/user_defined/user_defined_scalar_functions.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index ff8628607cdc..86de4c92a000 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -184,7 +184,6 @@ async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> { // udf that always return 1 row let buggy_udf = Arc::new(|_: &[ColumnarValue]| { - println!("here"); Ok(ColumnarValue::Array(Arc::new(Int32Array::from(vec![0])))) }); From 9be1a5a4613340211ec702eb177664e94da4b27f Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 21 Apr 2024 08:23:03 +0200 Subject: [PATCH 05/19] don't check if output is scalar --- .../user_defined/user_defined_scalar_functions.rs | 2 ++ datafusion/physical-expr/src/scalar_function.rs | 12 +++++------- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 86de4c92a000..ddde9b9f03bd 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -22,6 +22,7 @@ use arrow_array::{ use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; +use datafusion::datasource::MemTable; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; @@ -37,6 +38,7 @@ use datafusion_expr::{ create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_sql::TableReference; use rand::{thread_rng, Rng}; use std::any::Any; use std::iter; diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 008fa57deba1..d88250474546 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -179,13 +179,11 @@ impl PhysicalExpr for ScalarFunctionExpr { } ScalarFunctionDefinition::UDF(ref fun) => { let output = fun.invoke(&inputs)?; - let output_count = match &output { - ColumnarValue::Array(array) => array.len(), - ColumnarValue::Scalar(_) => 1, - }; - if output_count != batch.num_rows() { - return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", - batch.num_rows(), output_count); + if let ColumnarValue::Array(array) = &output { + if array.len() != batch.num_rows() { + return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", + batch.num_rows(), array.len()); + } } Ok(output) } From 2a93b8c1855cc1dcc4ebfae69c2cdf1a03d6c396 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 21 Apr 2024 08:29:56 +0200 Subject: [PATCH 06/19] lint --- .../core/tests/user_defined/user_defined_scalar_functions.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index ddde9b9f03bd..5c6cff197914 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -22,7 +22,7 @@ use arrow_array::{ use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; -use datafusion::datasource::MemTable; + use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; @@ -38,7 +38,7 @@ use datafusion_expr::{ create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; -use datafusion_sql::TableReference; + use rand::{thread_rng, Rng}; use std::any::Any; use std::iter; From 974338240f5d604b70900cd183ab0ead67606c36 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 21 Apr 2024 17:44:15 +0200 Subject: [PATCH 07/19] fix array_has --- datafusion/functions-array/src/array_has.rs | 59 ++++++++++---------- datafusion/functions/src/core/getfield.rs | 2 + datafusion/sqllogictest/test_files/array.slt | 15 +++-- 3 files changed, 42 insertions(+), 34 deletions(-) diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index ee064335c1cc..42bd8c140479 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -288,36 +288,39 @@ fn general_array_has_dispatch( } else { array }; - for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = if comparison_type != ComparisonType::Single { - converter.convert_columns(&[sub_arr])? - } else { - converter.convert_columns(&[element.clone()])? - }; - - let mut res = match comparison_type { - ComparisonType::All => sub_arr_values - .iter() - .dedup() - .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), - ComparisonType::Any => sub_arr_values - .iter() - .dedup() - .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), - ComparisonType::Single => arr_values - .iter() - .dedup() - .any(|x| x == sub_arr_values.row(row_idx)), - }; - - if comparison_type == ComparisonType::Any { - res |= res; + match (arr, sub_arr) { + (Some(arr), Some(sub_arr)) => { + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? + }; + + let mut res = match comparison_type { + ComparisonType::All => sub_arr_values + .iter() + .dedup() + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), + }; + + if comparison_type == ComparisonType::Any { + res |= res; + } + boolean_builder.append_value(res); } - - boolean_builder.append_value(res); + (_,_) => {boolean_builder.append_null();} + // (None, _) => boolean_builder.append_null(), + // (_, None) => {} } } Ok(Arc::new(boolean_builder.finish())) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index b00b8ea553f2..e4195238233b 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -114,6 +114,8 @@ impl ScalarUDFImpl for GetFieldFunc { let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; let entries = arrow::compute::filter(map_array.entries(), &keys)?; let entries_struct_array = as_struct_array(entries.as_ref())?; + let st = entries_struct_array.column(1).clone(); + println!("{:?}", st.len()); Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) } (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 3456963aacfc..08887e0c742c 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5169,8 +5169,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBB select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), @@ -5183,8 +5184,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBB select array_has(column1, make_array(5, 6)), @@ -5197,8 +5199,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBBBBBBBBBBB select array_has_all(make_array(1,2,3), make_array(1,3)), From 7eebcf22a06ca8f3f6ed4cb029c95bf5cc81ecdc Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 21 Apr 2024 17:46:18 +0200 Subject: [PATCH 08/19] rm debug --- datafusion/functions/src/core/getfield.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index e4195238233b..b00b8ea553f2 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -114,8 +114,6 @@ impl ScalarUDFImpl for GetFieldFunc { let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; let entries = arrow::compute::filter(map_array.entries(), &keys)?; let entries_struct_array = as_struct_array(entries.as_ref())?; - let st = entries_struct_array.column(1).clone(); - println!("{:?}", st.len()); Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) } (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { From c972d7d27ed6e8544af3b49712ad92d685462398 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Tue, 23 Apr 2024 23:02:52 +0200 Subject: [PATCH 09/19] chore: temp code for demonstration --- datafusion/functions/src/core/getfield.rs | 74 +++++++++++++++++------ 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index b00b8ea553f2..fee3e5fe2139 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -15,14 +15,21 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Scalar, StringArray}; +use arrow::array::{ + make_builder, Array, ArrayBuilder, AsArray, BooleanArray, BooleanBuilder, Datum, + NullArray, NullBuilder, Scalar, StringArray, StringBuilder, StructArray, +}; +use arrow::compute::{is_null, FilterBuilder}; use arrow::datatypes::DataType; +use arrow::ipc::BoolBuilder; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue}; use datafusion_expr::field_util::GetFieldAccessSchema; use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use std::borrow::Borrow; +use std::sync::Arc; #[derive(Debug)] pub struct GetFieldFunc { @@ -107,29 +114,56 @@ impl ScalarUDFImpl for GetFieldFunc { ); } }; + match (array.data_type(), name) { - (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let map_array = as_map_array(array.as_ref())?; - let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); - let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; - let entries = arrow::compute::filter(map_array.entries(), &keys)?; + (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { + let map_array = as_map_array(array.as_ref())?; + let key_scalar: Scalar>> = Scalar::new(StringArray::from(vec![k.clone()])); + let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + + // TODO: how to determine the type of this builder + let mut temp_builder= StringBuilder::new(); + + for entry in 0..map_array.len(){ + let end = map_array.value_offsets()[entry + 1] as usize; + let start = map_array.value_offsets()[entry] as usize; + + // child_struct is a subset of the original struct + let child_struct = map_array.value(entry); + + // one key in this child struct matching the input key + let entries = arrow::compute::filter( + dict.borrow(), + &keys.slice(start,end-start))?; + // at least one key matched + if entries.len() != 1 { + temp_builder.append_null(); + continue + } + // basically one row after filting let entries_struct_array = as_struct_array(entries.as_ref())?; - Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) + let str = entries_struct_array + .column(1).as_any().downcast_ref::().unwrap(); + temp_builder.append_value(str.value(0)); } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { - let as_struct_array = as_struct_array(&array)?; - match as_struct_array.column_by_name(k) { - None => exec_err!( - "get indexed field {k} not found in struct"), - Some(col) => Ok(ColumnarValue::Array(col.clone())) - } + + Ok(ColumnarValue::Array(Arc::new(temp_builder.finish()))) + } + (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { + let as_struct_array = as_struct_array(&array)?; + match as_struct_array.column_by_name(k) { + None => exec_err!("get indexed field {k} not found in struct"), + Some(col) => Ok(ColumnarValue::Array(col.clone())), } - (DataType::Struct(_), name) => exec_err!( - "get indexed field is only possible on struct with utf8 indexes. \ - Tried with {name:?} index"), - (dt, name) => exec_err!( - "get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. Tried {dt:?} with {name:?} index"), } + (DataType::Struct(_), name) => exec_err!( + "get indexed field is only possible on struct with utf8 indexes. \ + Tried with {name:?} index" + ), + (dt, name) => exec_err!( + "get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {name:?} index" + ), + } } } From e5bbfaf31def98f6a9695bbb5ffcc0246243a9db Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 25 Apr 2024 21:28:17 +0200 Subject: [PATCH 10/19] getfield retains number of rows --- datafusion/functions/src/core/getfield.rs | 43 ++++++++++------------ datafusion/sqllogictest/test_files/map.slt | 1 + 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index fee3e5fe2139..1fae56fd2fcf 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -16,20 +16,15 @@ // under the License. use arrow::array::{ - make_builder, Array, ArrayBuilder, AsArray, BooleanArray, BooleanBuilder, Datum, - NullArray, NullBuilder, Scalar, StringArray, StringBuilder, StructArray, + make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, }; -use arrow::compute::{is_null, FilterBuilder}; use arrow::datatypes::DataType; -use arrow::ipc::BoolBuilder; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue}; use datafusion_expr::field_util::GetFieldAccessSchema; use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::borrow::Borrow; -use std::sync::Arc; #[derive(Debug)] pub struct GetFieldFunc { @@ -121,33 +116,33 @@ impl ScalarUDFImpl for GetFieldFunc { let key_scalar: Scalar>> = Scalar::new(StringArray::from(vec![k.clone()])); let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; - // TODO: how to determine the type of this builder - let mut temp_builder= StringBuilder::new(); + let original_data = map_array.entries().column(1).to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, + capacity); for entry in 0..map_array.len(){ let end = map_array.value_offsets()[entry + 1] as usize; let start = map_array.value_offsets()[entry] as usize; - // child_struct is a subset of the original struct - let child_struct = map_array.value(entry); - - // one key in this child struct matching the input key - let entries = arrow::compute::filter( - dict.borrow(), - &keys.slice(start,end-start))?; // at least one key matched - if entries.len() != 1 { - temp_builder.append_null(); + let find_result = + keys.slice(start, end-start). + iter().enumerate(). + find(|(_,t)| t.unwrap()); + if find_result.is_none(){ + mutable.extend_nulls(1); continue } - // basically one row after filting - let entries_struct_array = as_struct_array(entries.as_ref())?; - let str = entries_struct_array - .column(1).as_any().downcast_ref::().unwrap(); - temp_builder.append_value(str.value(0)); - } + let (idx,_) = find_result.unwrap(); - Ok(ColumnarValue::Array(Arc::new(temp_builder.finish()))) + // TODO: can this value have more than 1 column + mutable.extend(0, start+idx, start+idx+1); + } + let data = mutable.freeze(); + let data = make_array(data); + Ok(ColumnarValue::Array(data)) } (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { let as_struct_array = as_struct_array(&array)?; diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 415fabf224d7..8ff7d119c454 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -44,6 +44,7 @@ DELETE 24 query T SELECT strings['not_found'] FROM data LIMIT 1; ---- +NULL statement ok drop table data; From ed41d3a370071c14b333344bb5a063018fd9d56f Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 25 Apr 2024 21:29:02 +0200 Subject: [PATCH 11/19] rust fmt --- .../core/tests/user_defined/user_defined_scalar_functions.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 5c6cff197914..f565b9518c4b 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -22,7 +22,6 @@ use arrow_array::{ use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; - use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; From 6603135085a98fe1a7ebfb080813023e2198ba4f Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 25 Apr 2024 21:40:26 +0200 Subject: [PATCH 12/19] minor comments --- datafusion/functions/src/core/getfield.rs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 1fae56fd2fcf..a092aac159bb 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -116,6 +116,8 @@ impl ScalarUDFImpl for GetFieldFunc { let key_scalar: Scalar>> = Scalar::new(StringArray::from(vec![k.clone()])); let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + // note that this array has more entries than the expected output/input size + // because maparray is flatten let original_data = map_array.entries().column(1).to_data(); let capacity = Capacities::Array(original_data.len()); let mut mutable = @@ -123,22 +125,19 @@ impl ScalarUDFImpl for GetFieldFunc { capacity); for entry in 0..map_array.len(){ - let end = map_array.value_offsets()[entry + 1] as usize; let start = map_array.value_offsets()[entry] as usize; + let end = map_array.value_offsets()[entry + 1] as usize; - // at least one key matched - let find_result = - keys.slice(start, end-start). - iter().enumerate(). - find(|(_,t)| t.unwrap()); - if find_result.is_none(){ + let maybe_matched = + keys.slice(start, end-start). + iter().enumerate(). + find(|(_, t)| t.unwrap()); + if maybe_matched.is_none(){ mutable.extend_nulls(1); continue } - let (idx,_) = find_result.unwrap(); - - // TODO: can this value have more than 1 column - mutable.extend(0, start+idx, start+idx+1); + let (match_offset,_) = maybe_matched.unwrap(); + mutable.extend(0, start + match_offset, start + match_offset + 1); } let data = mutable.freeze(); let data = make_array(data); From cf7fac38c4c39aeeec37d3ac7eaf7ab2a3c2ee95 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 25 Apr 2024 21:43:18 +0200 Subject: [PATCH 13/19] fmt --- datafusion/functions-array/src/array_has.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index 42bd8c140479..fbc04e16172a 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -318,9 +318,10 @@ fn general_array_has_dispatch( } boolean_builder.append_value(res); } - (_,_) => {boolean_builder.append_null();} - // (None, _) => boolean_builder.append_null(), - // (_, None) => {} + (_, _) => { + boolean_builder.append_null(); + } // (None, _) => boolean_builder.append_null(), + // (_, None) => {} } } Ok(Arc::new(boolean_builder.finish())) From fc304ae1733ed7b99305ffd43be7c02d6dd60b95 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 25 Apr 2024 21:53:53 +0200 Subject: [PATCH 14/19] refactor --- .../user_defined/user_defined_scalar_functions.rs | 2 -- datafusion/functions-array/src/array_has.rs | 4 ++-- datafusion/physical-expr/src/scalar_function.rs | 12 +++++++++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index c9aa025cd358..a57ea59d359b 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -21,7 +21,6 @@ use arrow_array::{ }; use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; - use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; @@ -37,7 +36,6 @@ use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; - use rand::{thread_rng, Rng}; use std::any::Any; use std::iter; diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index fbc04e16172a..d53d33c899a1 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -318,10 +318,10 @@ fn general_array_has_dispatch( } boolean_builder.append_value(res); } + // respect null input (_, _) => { boolean_builder.append_null(); - } // (None, _) => boolean_builder.append_null(), - // (_, None) => {} + } } } Ok(Arc::new(boolean_builder.finish())) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index ff911b80be90..33a55a6b7b73 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -162,12 +162,18 @@ impl PhysicalExpr for ScalarFunctionExpr { match self.fun { ScalarFunctionDefinition::UDF(ref fun) => { let output = fun.invoke(&inputs)?; - if let ColumnarValue::Array(array) = &output { - if array.len() != batch.num_rows() { + // Only arrow_typeof can bypass this rule + if fun.name() != "arrow_typeof" { + let output_size = match array { + ColumnarValue::Array(array) => array.len(), + ColumnarValue::Scalar(_) => 1, + }; + if output_size != batch.num_rows() { return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", - batch.num_rows(), array.len()); + batch.num_rows(), output_size); } } + Ok(output) } ScalarFunctionDefinition::Name(_) => { From e70245eeaa91395e4fc4061924f8804f20037bfa Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 25 Apr 2024 21:58:09 +0200 Subject: [PATCH 15/19] compile err --- datafusion/physical-expr/src/scalar_function.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 33a55a6b7b73..8616e12a427c 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -164,7 +164,7 @@ impl PhysicalExpr for ScalarFunctionExpr { let output = fun.invoke(&inputs)?; // Only arrow_typeof can bypass this rule if fun.name() != "arrow_typeof" { - let output_size = match array { + let output_size = match &output { ColumnarValue::Array(array) => array.len(), ColumnarValue::Scalar(_) => 1, }; From cda3e3b741becd222469a40fd377de293d7c958b Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 25 Apr 2024 22:01:28 +0200 Subject: [PATCH 16/19] fmt again --- datafusion/functions-array/src/array_has.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index d53d33c899a1..e5e8add95fbe 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -318,7 +318,7 @@ fn general_array_has_dispatch( } boolean_builder.append_value(res); } - // respect null input + // respect null input (_, _) => { boolean_builder.append_null(); } From efb1c5ff9c1543c462f504c42040b0fc0e6b8b78 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Fri, 26 Apr 2024 21:48:44 +0200 Subject: [PATCH 17/19] fmt --- datafusion/physical-expr/src/scalar_function.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index ebc551f02e35..249a07b33b46 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -148,7 +148,7 @@ impl PhysicalExpr for ScalarFunctionExpr { ScalarFunctionDefinition::UDF(ref fun) => { let output = match self.args.is_empty() { true => fun.invoke_no_args(batch.num_rows()), - false => fun.invoke(&inputs) + false => fun.invoke(&inputs), }?; // Only arrow_typeof can bypass this rule if fun.name() != "arrow_typeof" { From 6c397e85ea490034c7d5383e5010ff88fb3d10c1 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Fri, 26 Apr 2024 22:13:44 +0200 Subject: [PATCH 18/19] add validate_number_of_rows for UDF --- datafusion/expr/src/udf.rs | 10 ++++++++++ datafusion/functions/src/core/arrowtypeof.rs | 4 ++++ datafusion/physical-expr/src/scalar_function.rs | 4 ++-- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index c9c11a6bbfea..027abdf5c573 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -205,6 +205,11 @@ impl ScalarUDF { pub fn short_circuits(&self) -> bool { self.inner.short_circuits() } + + /// Validate the number of rows for the function + pub fn validate_number_of_rows(&self) -> bool { + self.inner.validate_number_of_rows() + } } impl From for ScalarUDF @@ -405,6 +410,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn short_circuits(&self) -> bool { false } + + /// Most of the UDFs should have the same number of input and output rows. + fn validate_number_of_rows(&self) -> bool { + true + } } /// ScalarUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index cc5e7e619bd8..4a93dd7fa105 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -69,4 +69,8 @@ impl ScalarUDFImpl for ArrowTypeOfFunc { "{input_data_type}" )))) } + + fn validate_number_of_rows(&self) -> bool { + false + } } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 249a07b33b46..f2c1a589f3a4 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -150,8 +150,8 @@ impl PhysicalExpr for ScalarFunctionExpr { true => fun.invoke_no_args(batch.num_rows()), false => fun.invoke(&inputs), }?; - // Only arrow_typeof can bypass this rule - if fun.name() != "arrow_typeof" { + + if fun.validate_number_of_rows() { let output_size = match &output { ColumnarValue::Array(array) => array.len(), ColumnarValue::Scalar(_) => 1, From c1458c224215c3a1d4536f5ff611a971cad4169b Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 27 Apr 2024 13:57:08 +0200 Subject: [PATCH 19/19] only check for columnarvalue::array --- datafusion/expr/src/udf.rs | 10 ---------- datafusion/functions/src/core/arrowtypeof.rs | 4 ---- datafusion/physical-expr/src/scalar_function.rs | 10 +++------- 3 files changed, 3 insertions(+), 21 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 027abdf5c573..c9c11a6bbfea 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -205,11 +205,6 @@ impl ScalarUDF { pub fn short_circuits(&self) -> bool { self.inner.short_circuits() } - - /// Validate the number of rows for the function - pub fn validate_number_of_rows(&self) -> bool { - self.inner.validate_number_of_rows() - } } impl From for ScalarUDF @@ -410,11 +405,6 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn short_circuits(&self) -> bool { false } - - /// Most of the UDFs should have the same number of input and output rows. - fn validate_number_of_rows(&self) -> bool { - true - } } /// ScalarUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index 4a93dd7fa105..cc5e7e619bd8 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -69,8 +69,4 @@ impl ScalarUDFImpl for ArrowTypeOfFunc { "{input_data_type}" )))) } - - fn validate_number_of_rows(&self) -> bool { - false - } } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index f2c1a589f3a4..b9c6ff3cfefc 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -151,14 +151,10 @@ impl PhysicalExpr for ScalarFunctionExpr { false => fun.invoke(&inputs), }?; - if fun.validate_number_of_rows() { - let output_size = match &output { - ColumnarValue::Array(array) => array.len(), - ColumnarValue::Scalar(_) => 1, - }; - if output_size != batch.num_rows() { + if let ColumnarValue::Array(array) = &output { + if array.len() != batch.num_rows() { return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", - batch.num_rows(), output_size); + batch.num_rows(), array.len()); } } Ok(output)