Skip to content

Commit

Permalink
make lead/lag throw if arg is invalid, check that the arg is int befo…
Browse files Browse the repository at this point in the history
…re casting, add tests
  • Loading branch information
Blizzara committed Jul 2, 2024
1 parent 9f37c0e commit 356d706
Showing 1 changed file with 141 additions and 20 deletions.
161 changes: 141 additions & 20 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ fn get_scalar_value_from_args(
})
}

fn get_integer(value: ScalarValue) -> Result<i64> {
if !value.data_type().is_integer() {
return Err(DataFusionError::Execution(
"Expected an integer value".to_string(),
));
}
value.cast_to(&DataType::Int64)?.try_into()
}

fn get_casted_value(
default_value: Option<ScalarValue>,
dtype: &DataType,
Expand Down Expand Up @@ -257,22 +266,17 @@ fn create_built_in_window_expr(
return exec_err!("NTILE requires a positive integer, but finds NULL");
}

if n.is_unsigned() {
let n: u64 = n.cast_to(&DataType::UInt64)?.try_into()?;
Arc::new(Ntile::new(name, n, out_data_type))
} else {
let n: i64 = n.cast_to(&DataType::Int64)?.try_into()?;
if n <= 0 {
return exec_err!("NTILE requires a positive integer");
}
Arc::new(Ntile::new(name, n as u64, out_data_type))
let n: i64 = get_integer(n)?;
if n <= 0 {
return exec_err!("NTILE requires a positive integer");
}
Arc::new(Ntile::new(name, n as u64, out_data_type))
}
BuiltInWindowFunction::Lag => {
let arg = args[0].clone();
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(|v| v.cast_to(&DataType::Int64)?.try_into())
.and_then(|v| v.ok());
.map(get_integer)
.map_or(Ok(None), |v| v.map(Some))?;
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
Arc::new(lag(
Expand All @@ -287,8 +291,8 @@ fn create_built_in_window_expr(
BuiltInWindowFunction::Lead => {
let arg = args[0].clone();
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(|v| v.cast_to(&DataType::Int64)?.try_into())
.and_then(|v| v.ok());
.map(get_integer)
.map_or(Ok(None), |v| v.map(Some))?;
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
Arc::new(lead(
Expand All @@ -302,12 +306,14 @@ fn create_built_in_window_expr(
}
BuiltInWindowFunction::NthValue => {
let arg = args[0].clone();
let n = args[1].as_any().downcast_ref::<Literal>().unwrap().value();
let n: i64 = n
.clone()
.cast_to(&DataType::Int64)?
.try_into()
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
let n = get_integer(
args[1]
.as_any()
.downcast_ref::<Literal>()
.unwrap()
.value()
.clone(),
)?;
Arc::new(NthValue::nth(
name,
arg,
Expand Down Expand Up @@ -614,8 +620,8 @@ mod tests {
use datafusion_execution::TaskContext;

use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr_common::expressions::lit;
use futures::FutureExt;

use InputOrderMode::{Linear, PartiallySorted, Sorted};

fn create_test_schema() -> Result<SchemaRef> {
Expand Down Expand Up @@ -1139,4 +1145,119 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_ntile_accepts_integer_args_but_not_others() -> Result<()> {
let schema = Schema::new(vec![Field::new("col", DataType::Int32, true)]);
for arg in [
lit(1i8),
lit(1i16),
lit(1i32),
lit(1i64),
lit(1u8),
lit(1u16),
lit(1u32),
lit(1u64),
] {
create_built_in_window_expr(
&BuiltInWindowFunction::Ntile,
vec![arg].as_slice(),
&schema,
"col".to_string(),
false,
)?;
}

assert_eq!(
create_built_in_window_expr(
&BuiltInWindowFunction::Ntile,
vec![lit(1.1)].as_slice(),
&schema,
"col".to_string(),
false,
)
.unwrap_err()
.message(),
"Expected an integer value"
);

Ok(())
}

#[tokio::test]
async fn test_nth_value_accepts_integer_args_but_not_others() -> Result<()> {
let schema = Schema::new(vec![Field::new("col", DataType::Int32, true)]);
for arg in [
lit(1i8),
lit(1i16),
lit(1i32),
lit(1i64),
lit(1u8),
lit(1u16),
lit(1u32),
lit(1u64),
] {
create_built_in_window_expr(
&BuiltInWindowFunction::NthValue,
vec![col("col", &schema)?, arg].as_slice(),
&schema,
"col".to_string(),
false,
)?;
}

assert_eq!(
create_built_in_window_expr(
&BuiltInWindowFunction::NthValue,
vec![col("col", &schema)?, lit(1.1)].as_slice(),
&schema,
"col".to_string(),
false,
)
.unwrap_err()
.message(),
"Expected an integer value"
);
Ok(())
}

#[tokio::test]
async fn test_lag_lead_accepts_integer_args_but_not_others() -> Result<()> {
let schema = Schema::new(vec![Field::new("col", DataType::Int32, true)]);
for window_function in [&BuiltInWindowFunction::Lag, &BuiltInWindowFunction::Lead]
{
for arg in [
lit(-1i8),
lit(-1i16),
lit(-1i32),
lit(-1i64),
lit(2u8),
lit(2u16),
lit(2u32),
lit(2u64),
] {
create_built_in_window_expr(
window_function,
vec![col("col", &schema)?, arg].as_slice(),
&schema,
"col".to_string(),
false,
)?;
}

assert_eq!(
create_built_in_window_expr(
window_function,
vec![col("col", &schema)?, lit(1.1)].as_slice(),
&schema,
"col".to_string(),
false,
)
.unwrap_err()
.message(),
"Expected an integer value"
);
}
Ok(())
}
}

0 comments on commit 356d706

Please sign in to comment.