Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add documentation and usability for prepared parameters #7785

Merged
merged 3 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,42 @@ impl DataFrame {
Ok(DataFrame::new(self.session_state, project_plan))
}

/// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values
/// Replace all parameters in the logical plan with the specified
/// values, in preparation for execution.
///
/// # Example
///
/// ```
/// use datafusion::prelude::*;
/// # use datafusion::{error::Result, assert_batches_eq};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// # use datafusion_common::ScalarValue;
/// let mut ctx = SessionContext::new();
/// # ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
/// let results = ctx
/// .sql("SELECT a FROM example WHERE b = $1")
/// .await?
/// // replace $1 with value 2
/// .with_param_values(vec![
/// // value at index 0 --> $1
/// ScalarValue::from(2i64)
/// ])?
/// .collect()
/// .await?;
/// assert_batches_eq!(
/// &[
/// "+---+",
/// "| a |",
/// "+---+",
/// "| 1 |",
/// "+---+",
/// ],
/// &results
/// );
/// # Ok(())
/// # }
/// ```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

pub fn with_param_values(self, param_values: Vec<ScalarValue>) -> Result<Self> {
let plan = self.plan.with_param_values(param_values)?;
Ok(Self::new(self.session_state, plan))
Expand Down
58 changes: 54 additions & 4 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

//! Expr module contains core type definition for `Expr`.

use crate::aggregate_function;
use crate::built_in_function;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
Expand All @@ -26,8 +25,10 @@ use crate::utils::{expr_to_columns, find_out_reference_exprs};
use crate::window_frame;
use crate::window_function;
use crate::Operator;
use crate::{aggregate_function, ExprSchemable};
use arrow::datatypes::DataType;
use datafusion_common::internal_err;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, DFSchema};
use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue};
use std::collections::HashSet;
use std::fmt;
Expand Down Expand Up @@ -599,10 +600,13 @@ impl InSubquery {
}
}

/// Placeholder
/// Placeholder, representing bind parameter values such as `$1`.
///
/// The type of these parameters is inferred using [`Expr::infer_placeholder_types`]
/// or can be specified directly using `PREPARE` statements.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Placeholder {
/// The identifier of the parameter (e.g, $1 or $foo)
/// The identifier of the parameter, including the leading `$` (e.g., `"$1"` or `"$foo"`)
alamb marked this conversation as resolved.
Show resolved Hide resolved
pub id: String,
/// The type the parameter will be filled in with
pub data_type: Option<DataType>,
Expand Down Expand Up @@ -1030,6 +1034,52 @@ impl Expr {
pub fn contains_outer(&self) -> bool {
    // True when `find_out_reference_exprs` reports at least one
    // expression for `self`.
    let outer_refs = find_out_reference_exprs(self);
    !outer_refs.is_empty()
}

/// Recursively find all [`Expr::Placeholder`] expressions, and
/// try to infer their [`DataType`] from the context of their use.
///
/// For example, given an expression like `<int32> = $0`, this will infer
/// `$0` to have type `int32`.
pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<Expr> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was moved from the datafusion-sql module as it is not SQL specific, but generic to expressions

self.transform(&|mut expr| {
// Default to assuming the arguments are the same type
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr {
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
};
if let Expr::Between(Between {
expr,
negated: _,
low,
high,
}) = &mut expr
{
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
}
Ok(Transformed::Yes(expr))
})
}
}

/// If `expr` is an [`Expr::Placeholder`] without a data type, set its type to
/// the type of `other` as resolved against `schema`.
///
/// Any other expression, and any placeholder that already carries a type, is
/// left untouched. Returns an error (with added context naming both
/// expressions) when the type of `other` cannot be determined.
fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> {
    if let Expr::Placeholder(Placeholder { data_type, .. }) = expr {
        // A placeholder whose type is already known is never overwritten
        if data_type.is_some() {
            return Ok(());
        }
        match other.get_type(schema) {
            Ok(inferred) => *data_type = Some(inferred),
            Err(e) => {
                return Err(e.context(format!(
                    "Can not find type of {other} needed to infer type of {expr}"
                )));
            }
        }
    }
    Ok(())
}

#[macro_export]
Expand Down
20 changes: 19 additions & 1 deletion datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::expr::{
AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
ScalarFunction, TryCast,
Placeholder, ScalarFunction, TryCast,
};
use crate::function::PartitionEvaluatorFactory;
use crate::WindowUDF;
Expand Down Expand Up @@ -80,6 +80,24 @@ pub fn ident(name: impl Into<String>) -> Expr {
Expr::Column(Column::from_name(name))
}

/// Create a placeholder value that will be filled in (such as `$1`)
///
/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
///
/// # Example
///
/// ```rust
/// # use datafusion_expr::{placeholder};
jackwener marked this conversation as resolved.
Show resolved Hide resolved
/// let p = placeholder("$1"); // $1, refers to parameter 1
/// assert_eq!(p.to_string(), "$1")
/// ```
pub fn placeholder(id: impl Into<String>) -> Expr {
    // The data type is left as `None` when the placeholder is created here.
    let id: String = id.into();
    Expr::Placeholder(Placeholder {
        id,
        data_type: None,
    })
}

/// Return a new expression `left <op> right`
pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
Expand Down
69 changes: 51 additions & 18 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -928,8 +928,40 @@ impl LogicalPlan {
}
}
}
/// Convert a prepared [`LogicalPlan`] into its inner logical plan
/// with all params replaced with their corresponding values
/// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`]
/// with the specified `param_values`.
///
/// [`LogicalPlan::Prepare`] nodes are
/// converted to their inner logical plan for execution.
///
/// # Example
/// ```
/// # use arrow::datatypes::{Field, Schema, DataType};
/// use datafusion_common::ScalarValue;
/// # use datafusion_expr::{lit, col, LogicalPlanBuilder, logical_plan::table_scan, placeholder};
/// # let schema = Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # ]);
/// // Build SELECT * FROM t1 WHERE id = $1
/// let plan = table_scan(Some("t1"), &schema, None).unwrap()
/// .filter(col("id").eq(placeholder("$1"))).unwrap()
/// .build().unwrap();
///
/// assert_eq!("Filter: t1.id = $1\
/// \n TableScan: t1",
/// plan.display_indent().to_string()
/// );
///
/// // Fill in the parameter $1 with a literal 3
/// let plan = plan.with_param_values(vec![
/// ScalarValue::from(3i32) // value at index 0 --> $1
/// ]).unwrap();
///
/// assert_eq!("Filter: t1.id = Int32(3)\
/// \n TableScan: t1",
/// plan.display_indent().to_string()
/// );
/// ```
pub fn with_param_values(
self,
param_values: Vec<ScalarValue>,
Expand Down Expand Up @@ -961,7 +993,7 @@ impl LogicalPlan {
let input_plan = prepare_lp.input;
input_plan.replace_params_with_values(&param_values)
}
_ => Ok(self),
_ => self.replace_params_with_values(&param_values),
}
}

Expand Down Expand Up @@ -1060,7 +1092,7 @@ impl LogicalPlan {
}

impl LogicalPlan {
/// applies collect to any subqueries in the plan
/// applies `op` to any subqueries in the plan
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive by cleanups

pub(crate) fn apply_subqueries<F>(&self, op: &mut F) -> datafusion_common::Result<()>
where
F: FnMut(&Self) -> datafusion_common::Result<VisitRecursion>,
Expand Down Expand Up @@ -1112,17 +1144,22 @@ impl LogicalPlan {
Ok(())
}

/// Return a logical plan with all placeholders/params (e.g $1 $2,
/// ...) replaced with corresponding values provided in the
/// params_values
/// Return a `LogicalPlan` with all placeholders (e.g. `$1`, `$2`,
/// ...) replaced with the corresponding values provided in
/// `param_values`
///
/// See [`Self::with_param_values`] for examples and usage
pub fn replace_params_with_values(
&self,
param_values: &[ScalarValue],
) -> Result<LogicalPlan> {
let new_exprs = self
.expressions()
.into_iter()
.map(|e| Self::replace_placeholders_with_values(e, param_values))
.map(|e| {
let e = e.infer_placeholder_types(self.schema())?;
Self::replace_placeholders_with_values(e, param_values)
})
.collect::<Result<Vec<_>>>()?;

let new_inputs_with_values = self
Expand Down Expand Up @@ -1219,7 +1256,9 @@ impl LogicalPlan {
// Various implementations for printing out LogicalPlans
impl LogicalPlan {
/// Return a `format`able structure that produces a single line
/// per node. For example:
/// per node.
///
/// # Example
///
/// ```text
/// Projection: employee.id
Expand Down Expand Up @@ -2321,7 +2360,7 @@ pub struct Unnest {
mod tests {
use super::*;
use crate::logical_plan::table_scan;
use crate::{col, exists, in_subquery, lit};
use crate::{col, exists, in_subquery, lit, placeholder};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::tree_node::TreeNodeVisitor;
use datafusion_common::{not_impl_err, DFSchema, TableReference};
Expand Down Expand Up @@ -2767,10 +2806,7 @@ digraph {

let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.filter(col("id").eq(Expr::Placeholder(Placeholder::new(
"".into(),
Some(DataType::Int32),
))))
.filter(col("id").eq(placeholder("")))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here is an example of the new placeholder function being much nicer to use

.unwrap()
.build()
.unwrap();
Expand All @@ -2783,10 +2819,7 @@ digraph {

let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.filter(col("id").eq(Expr::Placeholder(Placeholder::new(
"$0".into(),
Some(DataType::Int32),
))))
.filter(col("id").eq(placeholder("$0")))
.unwrap()
.build()
.unwrap();
Expand Down
48 changes: 2 additions & 46 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@ mod value;

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use arrow_schema::DataType;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{
internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::expr::InList;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr::{InList, Placeholder};
use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast,
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast,
Expand Down Expand Up @@ -122,7 +121,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut expr = self.sql_expr_to_logical_expr(sql, schema, planner_context)?;
expr = self.rewrite_partial_qualifier(expr, schema);
self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?;
let expr = infer_placeholder_types(expr, schema)?;
let expr = expr.infer_placeholder_types(schema)?;
Ok(expr)
}

Expand Down Expand Up @@ -712,49 +711,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

// Gives an untyped placeholder the data type of `other`.
//
// `expr` is only modified when it is an `Expr::Placeholder` whose `data_type`
// is still `None`; in that case the type of `other` (looked up in `schema`)
// is stored on it. A failed type lookup is returned as an error with extra
// context naming both expressions.
fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> {
    if let Expr::Placeholder(Placeholder { data_type, .. }) = expr {
        // Already typed: nothing to infer
        if data_type.is_some() {
            return Ok(());
        }
        match other.get_type(schema) {
            Ok(resolved) => *data_type = Some(resolved),
            Err(e) => {
                return Err(e.context(format!(
                    "Can not find type of {other} needed to infer type of {expr}"
                )));
            }
        }
    }
    Ok(())
}

/// Walk `expr` and try to fill in the [`DataType`] of every untyped
/// [`Expr::Placeholder`] from the expression it is used with.
///
/// Two shapes are handled: binary expressions, where each side is typed from
/// the other, and `BETWEEN`, where the low and high bounds are typed from the
/// expression being tested. All other nodes pass through unchanged.
fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result<Expr> {
    expr.transform(&|mut node| {
        match &mut node {
            // Default to assuming both operands have the same type
            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
                rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
                rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
            }
            Expr::Between(Between {
                expr: subject,
                low,
                high,
                ..
            }) => {
                rewrite_placeholder(low.as_mut(), subject.as_ref(), schema)?;
                rewrite_placeholder(high.as_mut(), subject.as_ref(), schema)?;
            }
            _ => {}
        }
        Ok(Transformed::Yes(node))
    })
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
13 changes: 13 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3684,6 +3684,19 @@ fn test_prepare_statement_should_infer_types() {
assert_eq!(actual_types, expected_types);
}

#[test]
fn test_non_prepare_statement_should_infer_types() {
// Non prepared statements (like SELECT) should also have their parameter types inferred
let sql = "SELECT 1 + $1";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

let plan = logical_plan(sql).unwrap();
let actual_types = plan.get_parameter_types().unwrap();
let expected_types = HashMap::from([
// constant 1 is inferred to be int64
("$1".to_string(), Some(DataType::Int64)),
]);
assert_eq!(actual_types, expected_types);
}

#[test]
#[should_panic(
expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\""
Expand Down