Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate ScalarUDF output rows and fix nulls for array_has and get_field for Map #10148

Merged
merged 22 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use arrow_array::{
};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};


use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
Expand All @@ -29,13 +31,14 @@ use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array,
cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue,
};
use datafusion_common::{exec_err, internal_err, DataFusionError};
use datafusion_common::{assert_contains, exec_err, internal_err, DataFusionError};
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};

use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
Expand Down Expand Up @@ -168,6 +171,44 @@ async fn scalar_udf() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(Int32Array::from(vec![1, 2]))],
)?;

let ctx = SessionContext::new();

ctx.register_batch("t", batch)?;

// udf that always return 1 row
let buggy_udf = Arc::new(|_: &[ColumnarValue]| {
Ok(ColumnarValue::Array(Arc::new(Int32Array::from(vec![0]))))
});

ctx.register_udf(create_udf(
"buggy_func",
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
buggy_udf,
));
assert_contains!(
ctx.sql("select buggy_func(a) from t")
.await?
.show()
.await
.err()
.unwrap()
.to_string(),
"UDF returned a different number of rows than expected"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👌 -- very nice

);
Ok(())
}

#[tokio::test]
async fn scalar_udf_zero_params() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down
59 changes: 31 additions & 28 deletions datafusion/functions-array/src/array_has.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,36 +288,39 @@ fn general_array_has_dispatch<O: OffsetSizeTrait>(
} else {
array
};

for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = if comparison_type != ComparisonType::Single {
converter.convert_columns(&[sub_arr])?
} else {
converter.convert_columns(&[element.clone()])?
};

let mut res = match comparison_type {
ComparisonType::All => sub_arr_values
.iter()
.dedup()
.all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Any => sub_arr_values
.iter()
.dedup()
.any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Single => arr_values
.iter()
.dedup()
.any(|x| x == sub_arr_values.row(row_idx)),
};

if comparison_type == ComparisonType::Any {
res |= res;
match (arr, sub_arr) {
(Some(arr), Some(sub_arr)) => {
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = if comparison_type != ComparisonType::Single {
converter.convert_columns(&[sub_arr])?
} else {
converter.convert_columns(&[element.clone()])?
};

let mut res = match comparison_type {
ComparisonType::All => sub_arr_values
.iter()
.dedup()
.all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Any => sub_arr_values
.iter()
.dedup()
.any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Single => arr_values
.iter()
.dedup()
.any(|x| x == sub_arr_values.row(row_idx)),
};

if comparison_type == ComparisonType::Any {
res |= res;
}
boolean_builder.append_value(res);
}

boolean_builder.append_value(res);
(_,_) => {boolean_builder.append_null();}
// (None, _) => boolean_builder.append_null(),
// (_, None) => {}
}
}
Ok(Arc::new(boolean_builder.finish()))
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ impl ScalarUDFImpl for GetFieldFunc {
let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;
let entries = arrow::compute::filter(map_array.entries(), &keys)?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using filter will reduce the number of input rows to the number of rows that have keys matching the input key. But we want to respect the number of input rows, and give null for any rows not having the matching key

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this

If the input is like this (two rows, each three elements)

{ a: 1, b: 2, c: 100}
{ a: 3, b: 4, c: 200}

An expression like col['c'] will still return 2 rows (but each row will have only a single element)

{ c: 100 }
{ c: 200 }

Copy link
Contributor Author

@duongcongtoai duongcongtoai Apr 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous implememtation

map_array.entries() has type of

pub struct StructArray {
    len: usize,
    data_type: DataType,
    nulls: Option<NullBuffer>,
    fields: Vec<ArrayRef>,
}

With the example above, the layout of field "fields" will be a vector of 2 array, where first array is a list of key, and second array is a list of value

[0]: ["a","b","c","a","b",c"]
[1]: [1,2,100,3,4,200]
                    let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;

with this computation, the result is a boolean aray where "key" = "c"

[false,false,true,false,false,true]

and thus this operation will reduce the number of rows into

                    let entries = arrow::compute::filter(map_array.entries(), &keys)?;
[0]: ["c,"c"]
[1]: [100,200]

Problem

However, let's add a row where the map does not have key "c" in between

{ a: 1, b: 2, c: 100}
{ a: 1, b: 2}
{ a: 3, b: 4, c: 200}

map_array.entries() underneath is represented as

[0]: ["a,"b","c","a","b","a","b","c"]
[1]: [1,2,100,1,2,3,4,200]

                    let entries = arrow::compute::filter(map_array.entries(), &keys)?;
Now rows after filtered will be
[0]: ["c","c"]
[1]: [100,200]

and the return result will be

{ c: 100 }
{ c: 200 }

instead of

{ c: 100 }
null
{ c: 200 }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect the result of evaluating col[b] on

{ a: 1, b: 2, c: 100}
{ a: 1, b: 2}
{ a: 3, b: 4, c: 200}

to be:

{ c: 100 }
null
{ c: 200 }

For example, in duckdb:

D create table foo as values (MAP {'a':1, 'b':2, 'c':100}), (MAP{ 'a':1, 'b':2}), (MAP {'a':1, 'b':2, 'c':200});
D select * from foo;
┌───────────────────────┐
│         col0          │
│ map(varchar, integer) │
├───────────────────────┤
│ {a=1, b=2, c=100}     │
│ {a=1, b=2}            │
│ {a=1, b=2, c=200}     │
└───────────────────────┘
D select col0['c'] from foo;
┌───────────┐
│ col0['c'] │
│  int32[]  │
├───────────┤
│ [100]     │
│ []        │
│ [200]     │
└───────────┘

Basically a scalar function has the invarant that each input row produces exactly 1 output row

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i also explained in this discussion: #10148 (comment)

let entries_struct_array = as_struct_array(entries.as_ref())?;
let st = entries_struct_array.column(1).clone();
println!("{:?}", st.len());
Ok(ColumnarValue::Array(entries_struct_array.column(1).clone()))
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
Expand Down
11 changes: 10 additions & 1 deletion datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,16 @@ impl PhysicalExpr for ScalarFunctionExpr {
let fun = create_physical_fun(fun)?;
(fun)(&inputs)
}
ScalarFunctionDefinition::UDF(ref fun) => fun.invoke(&inputs),
ScalarFunctionDefinition::UDF(ref fun) => {
let output = fun.invoke(&inputs)?;
if let ColumnarValue::Array(array) = &output {
if array.len() != batch.num_rows() {
return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}",
batch.num_rows(), array.len());
}
}
Ok(output)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!(
"Name function must be resolved to one of the other variants prior to physical planning"
Expand Down
15 changes: 9 additions & 6 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5169,8 +5169,9 @@ false false false true
true false true false
true false false true
false true false false
false false false false
false false false false
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test result does not look correct, because it ignore some null rows in between

NULL NULL false false
false false NULL false
false false false NULL

query BBBB
select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)),
Expand All @@ -5183,8 +5184,9 @@ false false false true
true false true false
true false false true
false true false false
false false false false
false false false false
NULL NULL false false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double checked and the arrays table has 7 rows, so I agree the correct answer has 7 output rows as well

statement ok
CREATE TABLE arrays
AS VALUES
(make_array(make_array(NULL, 2),make_array(3, NULL)), make_array(1.1, 2.2, 3.3), make_array('L', 'o', 'r', 'e', 'm')),
(make_array(make_array(3, 4),make_array(5, 6)), make_array(NULL, 5.5, 6.6), make_array('i', 'p', NULL, 'u', 'm')),
(make_array(make_array(5, 6),make_array(7, 8)), make_array(7.7, 8.8, 9.9), make_array('d', NULL, 'l', 'o', 'r')),
(make_array(make_array(7, NULL),make_array(9, 10)), make_array(10.1, NULL, 12.2), make_array('s', 'i', 't')),
(NULL, make_array(13.3, 14.4, 15.5), make_array('a', 'm', 'e', 't')),
(make_array(make_array(11, 12),make_array(13, 14)), NULL, make_array(',')),
(make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL)
;

false false NULL false
false false false NULL

query BBBB
select array_has(column1, make_array(5, 6)),
Expand All @@ -5197,8 +5199,9 @@ false false false true
true false true false
true false false true
false true false false
false false false false
false false false false
NULL NULL false false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iikewise I agree this should have 7 output rows

statement ok
CREATE TABLE fixed_size_arrays
AS VALUES
(arrow_cast(make_array(make_array(NULL, 2),make_array(3, NULL)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(1.1, 2.2, 3.3), 'FixedSizeList(3, Float64)'), arrow_cast(make_array('L', 'o', 'r', 'e', 'm'), 'FixedSizeList(5, Utf8)')),
(arrow_cast(make_array(make_array(3, 4),make_array(5, 6)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(NULL, 5.5, 6.6), 'FixedSizeList(3, Float64)'), arrow_cast(make_array('i', 'p', NULL, 'u', 'm'), 'FixedSizeList(5, Utf8)')),
(arrow_cast(make_array(make_array(5, 6),make_array(7, 8)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(7.7, 8.8, 9.9), 'FixedSizeList(3, Float64)'), arrow_cast(make_array('d', NULL, 'l', 'o', 'r'), 'FixedSizeList(5, Utf8)')),
(arrow_cast(make_array(make_array(7, NULL),make_array(9, 10)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(10.1, NULL, 12.2), 'FixedSizeList(3, Float64)'), arrow_cast(make_array('s', 'i', 't', 'a', 'b'), 'FixedSizeList(5, Utf8)')),
(NULL, arrow_cast(make_array(13.3, 14.4, 15.5), 'FixedSizeList(3, Float64)'), arrow_cast(make_array('a', 'm', 'e', 't', 'x'), 'FixedSizeList(5, Utf8)')),
(arrow_cast(make_array(make_array(11, 12),make_array(13, 14)), 'FixedSizeList(2, List(Int64))'), NULL, arrow_cast(make_array(',','a','b','c','d'), 'FixedSizeList(5, Utf8)')),
(arrow_cast(make_array(make_array(15, 16),make_array(NULL, 18)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(16.6, 17.7, 18.8), 'FixedSizeList(3, Float64)'), NULL)
;

false false NULL false
false false false NULL

query BBBBBBBBBBBBB
select array_has_all(make_array(1,2,3), make_array(1,3)),
Expand Down
Loading