Feat/parameterized sql queries #964

Open · wants to merge 3 commits into main
137 changes: 91 additions & 46 deletions python/datafusion/context.py
@@ -30,14 +30,14 @@
from datafusion.record_batch import RecordBatchStream
from datafusion.udf import ScalarUDF, AggregateUDF, WindowUDF

import pathlib
from typing import Any, TYPE_CHECKING, Protocol
from typing_extensions import deprecated

if TYPE_CHECKING:
import pyarrow
import pandas
import polars
import pathlib
from datafusion.plan import LogicalPlan, ExecutionPlan


@@ -523,9 +523,18 @@ def register_listing_table(
file_sort_order_raw,
)

def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
def sql(
self, query: str, options: SQLOptions | None = None, **named_dfs: DataFrame
) -> DataFrame:
"""Create a :py:class:`~datafusion.DataFrame` from SQL query text.

The query string can optionally reference a DataFrame as a parameter by placing
a named variable inside curly braces. In the following example, if we have a DataFrame
called `my_df`, then the DataFrame's logical plan will be converted into an
SQL query string and substituted in its place::

ctx.sql("SELECT name FROM {df}", df=my_df)

Note: This API implements DDL statements such as ``CREATE TABLE`` and
``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with in-memory
default implementation. See
@@ -534,12 +543,20 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
Args:
query: SQL query text.
options: If provided, the query will be validated against these options.
named_dfs: When provided, used to replace parameterized query variables
in the query string.

Returns:
DataFrame representation of the SQL query.
"""
if named_dfs:
for alias, df in named_dfs.items():
df_sql = f"({df.logical_plan().to_sql()})"
query = query.replace(f"{{{alias}}}", df_sql)

Review comment: There are some annoying unintended side effects to this approach. Imagine the following query:

SELECT * FROM {alias} WHERE val="a string that happens to contain {alias} in it"

Since this code just replaces all occurrences of {alias} with an SQL query, it will do so in the WHERE part as well.
As far as I can tell, there would be no way to escape {alias} in such a way that the replacement does not occur.

This is obviously a contrived example, and it might be that this is acceptable.
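To make the concern concrete, here is a minimal self-contained sketch (not taken from the PR; the helper `naive_substitute` and the sample query are hypothetical) of how plain `str.replace` substitution also rewrites a placeholder that happens to sit inside a SQL string literal:

```python
# Hypothetical stand-in for the replacement loop above; named_dfs maps an alias
# to the SQL text that would be produced from a DataFrame's logical plan.
def naive_substitute(query: str, **named_dfs: str) -> str:
    for alias, subquery_sql in named_dfs.items():
        query = query.replace(f"{{{alias}}}", f"({subquery_sql})")
    return query


query = "SELECT * FROM {df} WHERE val = 'a string that happens to contain {df} in it'"
print(naive_substitute(query, df="SELECT name FROM people"))
# -> SELECT * FROM (SELECT name FROM people) WHERE val = 'a string that
#    happens to contain (SELECT name FROM people) in it'
```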


if options is None:
return DataFrame(self.ctx.sql(query))

return DataFrame(self.ctx.sql_with_options(query, options.options_internal))

def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
@@ -753,7 +770,7 @@ def register_parquet(
def register_csv(
self,
name: str,
path: str | pathlib.Path | list[str | pathlib.Path],
path: str | pathlib.Path | list[str] | list[pathlib.Path],
schema: pyarrow.Schema | None = None,
has_header: bool = True,
delimiter: str = ",",
@@ -917,6 +934,7 @@ def read_json(
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
file_compression_type: str | None = None,
table_name: str | None = None,
) -> DataFrame:
"""Read a line-delimited JSON data source.

@@ -929,22 +947,23 @@
selected for data input.
table_partition_cols: Partition columns.
file_compression_type: File compression type.
table_name: Name to register the table as for SQL queries.

Returns:
DataFrame representation of the read JSON files.
"""
if table_partition_cols is None:
table_partition_cols = []
return DataFrame(
self.ctx.read_json(
str(path),
schema,
schema_infer_max_records,
file_extension,
table_partition_cols,
file_compression_type,
)
if table_name is None:
table_name = self.generate_table_name(path)
self.register_json(
table_name,
path,
schema=schema,
schema_infer_max_records=schema_infer_max_records,
file_extension=file_extension,
table_partition_cols=table_partition_cols,
file_compression_type=file_compression_type,
)
return self.table(table_name)

def read_csv(
self,
@@ -956,6 +975,7 @@
file_extension: str = ".csv",
table_partition_cols: list[tuple[str, str]] | None = None,
file_compression_type: str | None = None,
table_name: str | None = None,
) -> DataFrame:
"""Read a CSV data source.

@@ -973,27 +993,24 @@
selected for data input.
table_partition_cols: Partition columns.
file_compression_type: File compression type.
table_name: Name to register the table as for SQL queries.

Returns:
DataFrame representation of the read CSV files
"""
if table_partition_cols is None:
table_partition_cols = []

path = [str(p) for p in path] if isinstance(path, list) else str(path)

return DataFrame(
self.ctx.read_csv(
path,
schema,
has_header,
delimiter,
schema_infer_max_records,
file_extension,
table_partition_cols,
file_compression_type,
)
if table_name is None:
table_name = self.generate_table_name(path)
self.register_csv(
table_name,
path,
schema=schema,
has_header=has_header,
delimiter=delimiter,
schema_infer_max_records=schema_infer_max_records,
file_extension=file_extension,
file_compression_type=file_compression_type,
)
return self.table(table_name)

def read_parquet(
self,
@@ -1004,6 +1021,7 @@
skip_metadata: bool = True,
schema: pyarrow.Schema | None = None,
file_sort_order: list[list[Expr]] | None = None,
table_name: str | None = None,
) -> DataFrame:
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.

@@ -1021,30 +1039,32 @@
the parquet reader will try to infer it based on data in the
file.
file_sort_order: Sort order for the file.
table_name: Name to register the table as for SQL queries.

Returns:
DataFrame representation of the read Parquet files
"""
if table_partition_cols is None:
table_partition_cols = []
return DataFrame(
self.ctx.read_parquet(
str(path),
table_partition_cols,
parquet_pruning,
file_extension,
skip_metadata,
schema,
file_sort_order,
)
if table_name is None:
table_name = self.generate_table_name(path)
self.register_parquet(
table_name,
path,
table_partition_cols=table_partition_cols,
parquet_pruning=parquet_pruning,
file_extension=file_extension,
skip_metadata=skip_metadata,
schema=schema,
file_sort_order=file_sort_order,
)
return self.table(table_name)

def read_avro(
self,
path: str | pathlib.Path,
schema: pyarrow.Schema | None = None,
file_partition_cols: list[tuple[str, str]] | None = None,
file_extension: str = ".avro",
table_name: str | None = None,
) -> DataFrame:
"""Create a :py:class:`DataFrame` for reading Avro data source.

@@ -1053,15 +1073,21 @@ def read_avro(
schema: The data source schema.
file_partition_cols: Partition columns.
file_extension: File extension to select.
table_name: Name to register the table as for SQL queries.

Returns:
DataFrame representation of the read Avro file
"""
if file_partition_cols is None:
file_partition_cols = []
return DataFrame(
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
if table_name is None:
table_name = self.generate_table_name(path)
self.register_avro(
table_name,
path,
schema=schema,
file_extension=file_extension,
table_partition_cols=file_partition_cols,
)
return self.table(table_name)

def read_table(self, table: Table) -> DataFrame:
"""Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table.
@@ -1075,3 +1101,22 @@ def read_table(self, table: Table) -> DataFrame:
def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
"""Execute the ``plan`` and return the results."""
return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))

def generate_table_name(
self, path: str | pathlib.Path | list[str] | list[pathlib.Path]
) -> str:
"""Generate a table name based on the file name or a uuid."""
import uuid

if isinstance(path, list):
path = path[0]

if isinstance(path, str):
path = pathlib.Path(path)

table_name = path.stem.replace(".", "_")

if self.table_exist(table_name):
table_name = uuid.uuid4().hex

return table_name
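Taken together, the `read_*` changes mean each read now registers its source as a table, named from the file stem unless `table_name` is supplied, so the same data is reachable from SQL afterwards. A hedged usage sketch (the file `data.parquet` and the name `sales` are made up for illustration):

```python
from datafusion import SessionContext

ctx = SessionContext()

# read_parquet now registers the file as a table as a side effect. "sales" is an
# explicit name; omitting table_name would derive one from the file stem ("data").
df = ctx.read_parquet("data.parquet", table_name="sales")

# Query by the registered name...
ctx.sql("SELECT COUNT(*) AS cnt FROM sales").show()

# ...or substitute the DataFrame into a parameterized query.
ctx.sql("SELECT COUNT(*) AS cnt FROM {d}", d=df).show()
```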
1 change: 1 addition & 0 deletions python/datafusion/functions.py
@@ -431,6 +431,7 @@ def window(
partition_by = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
window_frame = window_frame.window_frame if window_frame is not None else None
ctx = ctx.ctx if ctx is not None else None
return Expr(f.window(name, args, partition_by, order_by_raw, window_frame, ctx))


4 changes: 4 additions & 0 deletions python/datafusion/plan.py
@@ -98,6 +98,10 @@ def to_proto(self) -> bytes:
"""
return self._raw_plan.to_proto()

def to_sql(self) -> str:
"""Return the SQL equivalent statement for this logical plan."""
return self._raw_plan.to_sql()


class ExecutionPlan:
"""Represent nodes in the DataFusion Physical Plan."""
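For context, a short sketch of what the new `LogicalPlan.to_sql` hook exposes on the Python side (the in-memory table is illustrative, and unparsing may fail for plans the unparser does not support):

```python
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.from_pydict({"name": ["a", "b"], "val": [1, 2]})

# Unparse the DataFrame's logical plan back into a SQL string; this is the text
# that ctx.sql() splices in for each "{alias}" placeholder.
subquery = df.logical_plan().to_sql()
print(subquery)  # roughly: SELECT "name", "val" FROM <generated table name>
```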
10 changes: 10 additions & 0 deletions python/tests/test_sql.py
@@ -159,6 +159,16 @@ def test_register_parquet(ctx, tmp_path):
assert result.to_pydict() == {"cnt": [100]}


def test_parameterized_sql(ctx, tmp_path) -> None:
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
df = ctx.read_parquet(path)
result = ctx.sql(
"SELECT COUNT(a) AS cnt FROM {replaced_df}", replaced_df=df
).collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"cnt": [100]}


@pytest.mark.parametrize("path_to_str", (True, False))
def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
dir_root = tmp_path / "dataset_parquet_partitioned"
11 changes: 11 additions & 0 deletions src/functions.rs
@@ -16,6 +16,7 @@
// under the License.

use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::functions_window::all_default_window_functions;
use datafusion::logical_expr::ExprFunctionExt;
use datafusion::logical_expr::WindowFrame;
use pyo3::{prelude::*, wrap_pyfunction};
@@ -282,6 +283,16 @@ fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<WindowF
return Ok(agg_fn);
}

// search default window functions

Review comment: It is not clear to me how this relates to the rest of the change.

let window_fn = all_default_window_functions()
.iter()
.find(|v| v.name() == name || v.aliases().contains(&name.to_string()))
.map(|f| WindowFunctionDefinition::WindowUDF(f.clone()));

if let Some(window_fn) = window_fn {
return Ok(window_fn);
}

Err(DataFusionError::Common(format!("window function `{name}` not found")).into())
}

7 changes: 7 additions & 0 deletions src/sql/logical.rs
@@ -34,6 +34,7 @@ use crate::expr::table_scan::PyTableScan;
use crate::expr::unnest::PyUnnest;
use crate::expr::window::PyWindowExpr;
use crate::{context::PySessionContext, errors::py_unsupported_variant_err};
use datafusion::sql::unparser::plan_to_sql;
use datafusion::{error::DataFusionError, logical_expr::LogicalPlan};
use datafusion_proto::logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec};
use prost::Message;
@@ -153,6 +154,12 @@ impl PyLogicalPlan {
.map_err(DataFusionError::from)?;
Ok(Self::new(plan))
}

pub fn to_sql(&self) -> PyResult<String> {
plan_to_sql(&self.plan)
.map(|v| v.to_string())
.map_err(|err| PyRuntimeError::new_err(err.to_string()))
}
}

impl From<PyLogicalPlan> for LogicalPlan {