Skip to content

Commit

Permalink
Minor: Consolidate UDF tests (#7704)
Browse files Browse the repository at this point in the history
* Minor: Consolidate user defined functions

* cleanup

* move more tests

* more

* cleanup use
  • Loading branch information
alamb authored Oct 3, 2023
1 parent b1587c1 commit 32fe176
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 206 deletions.
86 changes: 1 addition & 85 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2222,16 +2222,13 @@ mod tests {
use crate::execution::context::QueryPlanner;
use crate::execution::memory_pool::MemoryConsumer;
use crate::execution::runtime_env::RuntimeConfig;
use crate::physical_plan::expressions::AvgAccumulator;
use crate::test;
use crate::test_util::parquet_test_data;
use crate::variable::VarType;
use arrow::array::ArrayRef;
use arrow::record_batch::RecordBatch;
use arrow_schema::{Field, Schema};
use async_trait::async_trait;
use datafusion_expr::{create_udaf, create_udf, Expr, Volatility};
use datafusion_physical_expr::functions::make_scalar_function;
use datafusion_expr::Expr;
use std::fs::File;
use std::path::PathBuf;
use std::sync::Weak;
Expand Down Expand Up @@ -2330,87 +2327,6 @@ mod tests {
Ok(())
}

/// Checks identifier case handling for scalar UDFs: a function registered
/// under the upper-case name `MY_FUNC` is rejected when called unquoted and
/// accepted when the call uses a quoted identifier.
#[tokio::test]
async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
    let ctx = SessionContext::new();
    let table = test::table_with_sequence(1, 1).unwrap();
    ctx.register_table("t", table).unwrap();

    // Identity function: hands back its first argument untouched.
    let identity = make_scalar_function(|args: &[ArrayRef]| Ok(Arc::clone(&args[0])));
    let udf = create_udf(
        "MY_FUNC",
        vec![DataType::Int32],
        Arc::new(DataType::Int32),
        Volatility::Immutable,
        identity,
    );
    ctx.register_udf(udf);

    // The unquoted call fails; the error names the lower-cased `my_func`,
    // while the registration used upper case.
    let unquoted = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t").await;
    let message = unquoted.unwrap_err().to_string();
    assert!(message.contains("Error during planning: Invalid function \'my_func\'"));

    // Quoting the identifier makes the call succeed.
    let batches = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;

    let expected_table = [
        "+--------------+",
        "| MY_FUNC(t.i) |",
        "+--------------+",
        "| 1 |",
        "+--------------+",
    ];
    assert_batches_eq!(expected_table, &batches);

    Ok(())
}

/// Checks identifier case handling for aggregate UDFs: an aggregate
/// registered under the upper-case name `MY_AVG` is rejected when called
/// unquoted and accepted when the call uses a quoted identifier.
#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
    let ctx = SessionContext::new();
    let table = test::table_with_sequence(1, 1).unwrap();
    ctx.register_table("t", table).unwrap();

    // Registered with an upper-case name on purpose.
    let upper_case_avg = create_udaf(
        "MY_AVG",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
        Arc::new(vec![DataType::UInt64, DataType::Float64]),
    );
    ctx.register_udaf(upper_case_avg);

    // The unquoted call fails; the error names the lower-cased `my_avg`,
    // while the registration used upper case.
    let unquoted = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t").await;
    let message = unquoted.unwrap_err().to_string();
    assert!(message.contains("Error during planning: Invalid function \'my_avg\'"));

    // Quoting the identifier makes the call succeed.
    let batches = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?;

    let expected_table = [
        "+-------------+",
        "| MY_AVG(t.i) |",
        "+-------------+",
        "| 1.0 |",
        "+-------------+",
    ];
    assert_batches_eq!(expected_table, &batches);

    Ok(())
}

#[tokio::test]
async fn query_csv_with_custom_partition_extension() -> Result<()> {
let tmp_dir = TempDir::new()?;
Expand Down
16 changes: 2 additions & 14 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ async fn test_array_cast_expressions() -> Result<()> {

#[tokio::test]
async fn test_random_expression() -> Result<()> {
let ctx = create_ctx();
let ctx = SessionContext::new();
let sql = "SELECT random() r1";
let actual = execute(&ctx, sql).await;
let r1 = actual[0][0].parse::<f64>().unwrap();
Expand All @@ -627,7 +627,7 @@ async fn test_random_expression() -> Result<()> {

#[tokio::test]
async fn test_uuid_expression() -> Result<()> {
let ctx = create_ctx();
let ctx = SessionContext::new();
let sql = "SELECT uuid()";
let actual = execute(&ctx, sql).await;
let uuid = actual[0][0].parse::<uuid::Uuid>().unwrap();
Expand Down Expand Up @@ -886,18 +886,6 @@ async fn csv_query_nullif_divide_by_0() -> Result<()> {
Ok(())
}

/// Runs `avg` over the `custom_sqrt` UDF applied to column `c12` of
/// `aggregate_test_100` and compares the result with a known value.
#[tokio::test]
async fn csv_query_avg_sqrt() -> Result<()> {
    let ctx = create_ctx();
    register_aggregate_csv(&ctx).await?;

    let query = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100";
    let mut rows = execute(&ctx, query).await;
    rows.sort();

    let expected = vec![vec!["0.6706002946036462"]];
    assert_float_eq(&expected, &rows);

    Ok(())
}

#[tokio::test]
async fn nested_subquery() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
55 changes: 1 addition & 54 deletions datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use chrono::prelude::*;
use chrono::Duration;

use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan};
use datafusion::physical_plan::metrics::MetricValue;
use datafusion::physical_plan::ExecutionPlan;
Expand All @@ -34,15 +35,9 @@ use datafusion::prelude::*;
use datafusion::test_util;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion::{datasource::MemTable, physical_plan::collect};
use datafusion::{
error::{DataFusionError, Result},
physical_plan::ColumnarValue,
};
use datafusion::{execution::context::SessionContext, physical_plan::displayable};
use datafusion_common::cast::as_float64_array;
use datafusion_common::plan_err;
use datafusion_common::{assert_contains, assert_not_contains};
use datafusion_expr::Volatility;
use object_store::path::Path;
use std::fs::File;
use std::io::Write;
Expand Down Expand Up @@ -101,54 +96,6 @@ pub mod select;
mod sql_api;
pub mod subqueries;
pub mod timestamp;
pub mod udf;

/// Asserts that two tables of stringified floats hold the same values.
///
/// Both tables are flattened row by row and compared position by position
/// after parsing every cell as `f64`:
///
/// * if either value is NaN, both must be NaN;
/// * otherwise the absolute difference must not exceed `2.0 * f64::EPSILON`.
///
/// # Panics
///
/// Panics if the tables do not contain the same number of values, if any
/// cell cannot be parsed as `f64`, or if a pair of values differs.
fn assert_float_eq<T>(expected: &[Vec<T>], received: &[Vec<String>])
where
    T: AsRef<str>,
{
    // `zip` stops at the shorter side, so without this check a truncated (or
    // empty) `received` would pass silently.
    let expected_len: usize = expected.iter().map(Vec::len).sum();
    let received_len: usize = received.iter().map(Vec::len).sum();
    assert_eq!(
        expected_len, received_len,
        "expected {expected_len} values but received {received_len}"
    );

    for (l, r) in expected.iter().flatten().zip(received.iter().flatten()) {
        let l = l.as_ref().parse::<f64>().unwrap();
        let r = r.as_str().parse::<f64>().unwrap();

        if l.is_nan() || r.is_nan() {
            // NaN never compares equal to itself, so it is matched explicitly.
            assert!(l.is_nan() && r.is_nan());
        } else if (l - r).abs() > 2.0 * f64::EPSILON {
            panic!("{l} != {r}")
        }
    }
}

/// Builds a new `SessionContext` with the `custom_sqrt` scalar UDF
/// (`Float64` in, `Float64` out) registered on it.
fn create_ctx() -> SessionContext {
    let sqrt_udf = create_udf(
        "custom_sqrt",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        Arc::new(custom_sqrt),
    );

    let session = SessionContext::new();
    session.register_udf(sqrt_udf);
    session
}

/// Scalar UDF body: takes the element-wise square root of its first argument.
///
/// Only array inputs are handled; null entries stay null. Any other kind of
/// `ColumnarValue` hits `unimplemented!()`.
///
/// # Panics
///
/// Panics if `args` is empty, if the array cannot be viewed as a
/// `Float64Array`, or if the first argument is not an array.
fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    match &args[0] {
        ColumnarValue::Array(values) => {
            let floats = as_float64_array(values).expect("cast failed");
            let roots: Float64Array = floats.iter().map(|item| item.map(f64::sqrt)).collect();
            Ok(ColumnarValue::Array(Arc::new(roots)))
        }
        _ => unimplemented!(),
    }
}

fn create_join_context(
column_left: &str,
Expand Down
3 changes: 3 additions & 0 deletions datafusion/core/tests/user_defined/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.

/// Tests for user defined Scalar functions
mod user_defined_scalar_functions;

/// Tests for User Defined Aggregate Functions
mod user_defined_aggregates;

Expand Down
94 changes: 94 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
//! user defined aggregate functions
use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::Int32Array;
use arrow_schema::Schema;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use datafusion::datasource::MemTable;
use datafusion::{
arrow::{
array::{ArrayRef, Float64Array, TimestampNanosecondArray},
Expand All @@ -43,6 +46,8 @@ use datafusion::{
use datafusion_common::{
assert_contains, cast::as_primitive_array, exec_err, DataFusionError,
};
use datafusion_expr::create_udaf;
use datafusion_physical_expr::expressions::AvgAccumulator;

/// Test to show the contents of the setup
#[tokio::test]
Expand Down Expand Up @@ -204,6 +209,95 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}

/// Exercises the UDAF lifecycle end to end: build an aggregate with
/// `create_udaf`, register it on a context, and call it from SQL over a
/// two-partition in-memory table.
#[tokio::test]
async fn simple_udaf() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    let first_batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let second_batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![4, 5]))],
    )?;

    let ctx = SessionContext::new();

    // Each batch goes into its own partition.
    let partitions = vec![vec![first_batch], vec![second_batch]];
    let table = MemTable::try_new(schema, partitions)?;
    ctx.register_table("t", Arc::new(table))?;

    // The aggregate reuses DataFusion's own average accumulator.
    let avg_udaf = create_udaf(
        "my_avg",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
        Arc::new(vec![DataType::UInt64, DataType::Float64]),
    );
    ctx.register_udaf(avg_udaf);

    let dataframe = ctx.sql("SELECT MY_AVG(a) FROM t").await?;
    let batches = dataframe.collect().await?;

    let expected_table = [
        "+-------------+",
        "| my_avg(t.a) |",
        "+-------------+",
        "| 3.0 |",
        "+-------------+",
    ];
    assert_batches_eq!(expected_table, &batches);

    Ok(())
}

/// Checks identifier case handling for aggregate UDFs: an aggregate
/// registered under the upper-case name `MY_AVG` is rejected when called
/// unquoted and accepted when the call uses a quoted identifier.
#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
    let ctx = SessionContext::new();

    // Single-row table `t` with one Int32 column `i`.
    let single_row =
        RecordBatch::try_from_iter(vec![("i", Arc::new(Int32Array::from(vec![1])) as _)])?;
    ctx.register_batch("t", single_row).unwrap();

    // Registered with an upper-case name on purpose.
    let upper_case_avg = create_udaf(
        "MY_AVG",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
        Arc::new(vec![DataType::UInt64, DataType::Float64]),
    );
    ctx.register_udaf(upper_case_avg);

    // The unquoted call fails; the error names the lower-cased `my_avg`,
    // while the registration used upper case.
    let message = ctx
        .sql("SELECT MY_AVG(i) FROM t")
        .await
        .unwrap_err()
        .to_string();
    assert!(message.contains("Error during planning: Invalid function \'my_avg\'"));

    // Quoting the identifier makes the call succeed.
    let dataframe = ctx.sql("SELECT \"MY_AVG\"(i) FROM t").await?;
    let batches = dataframe.collect().await?;

    let expected_table = [
        "+-------------+",
        "| MY_AVG(t.i) |",
        "+-------------+",
        "| 1.0 |",
        "+-------------+",
    ];
    assert_batches_eq!(expected_table, &batches);

    Ok(())
}

/// Returns an context with a table "t" and the "first" and "time_sum"
/// aggregate functions registered.
///
Expand Down
Loading

0 comments on commit 32fe176

Please sign in to comment.