Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Initial support for SQL ARRAY literals and the UNNEST table function #16330

Merged
merged 6 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};

use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry};
use crate::sql_expr::{parse_sql_expr, process_join_constraint};
use crate::sql_expr::{parse_sql_array, parse_sql_expr, process_join_constraint};
use crate::table_functions::PolarsTableFunctions;

/// The SQLContext is the main entry point for executing SQL queries.
Expand Down Expand Up @@ -748,6 +748,7 @@ impl SQLContext {
alias,
} => {
polars_ensure!(!(*lateral), ComputeError: "LATERAL not supported");

if let Some(alias) = alias {
let lf = self.execute_query_no_ctes(subquery)?;
self.table_map.insert(alias.name.value.clone(), lf.clone());
Expand All @@ -756,6 +757,67 @@ impl SQLContext {
polars_bail!(ComputeError: "derived tables must have aliases");
}
},
TableFactor::UNNEST {
alias,
array_exprs,
with_offset,
with_offset_alias: _,
} => {
if let Some(alias) = alias {
let table_name = alias.name.value.clone();
let column_names: Vec<Option<&str>> = alias
.columns
.iter()
.map(|c| {
if c.value.is_empty() {
None
} else {
Some(c.value.as_str())
}
})
.collect();

let column_values: Vec<Series> = array_exprs
.iter()
.map(|arr| parse_sql_array(arr, self))
.collect::<Result<_, _>>()?;

polars_ensure!(!column_names.is_empty(),
ComputeError:
"UNNEST table alias must also declare column names, eg: {} (a,b,c)", alias.name.to_string()
);
if column_names.len() != column_values.len() {
let plural = if column_values.len() > 1 { "s" } else { "" };
polars_bail!(
ComputeError:
"UNNEST table alias requires {} column name{}, found {}", column_values.len(), plural, column_names.len()
);
}
let column_series: Vec<Series> = column_values
.iter()
.zip(column_names.iter())
.map(|(s, name)| {
if let Some(name) = name {
s.clone().with_name(name)
} else {
s.clone()
}
})
.collect();

let lf = DataFrame::new(column_series)?.lazy();
if *with_offset {
// TODO: make a PR to `sqlparser-rs` to support 'ORDINALITY'
// (note that 'OFFSET' is BigQuery-specific syntax, not PostgreSQL)
polars_bail!(ComputeError: "UNNEST tables do not (yet) support WITH OFFSET/ORDINALITY");
}
self.table_map.insert(table_name.clone(), lf.clone());
Ok((table_name.clone(), lf))
} else {
polars_bail!(ComputeError: "UNNEST table must have an alias");
}
},

// Support bare table, optional with alias for now
_ => polars_bail!(ComputeError: "not yet implemented: {}", relation),
}
Expand Down
129 changes: 75 additions & 54 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use polars_core::export::regex;
use polars_core::prelude::*;
use polars_error::to_compute_err;
use polars_lazy::prelude::*;
use polars_ops::series::SeriesReshape;
use polars_plan::prelude::typed_lit;
use polars_plan::prelude::LiteralValue::Null;
use rand::distributions::Alphanumeric;
Expand Down Expand Up @@ -185,6 +186,28 @@ pub(crate) struct SQLExprVisitor<'a> {
}

impl SQLExprVisitor<'_> {
/// Convert a slice of SQL literal expressions into a flat `Series`.
///
/// Supports plain values and unary-prefixed values (eg: `-1`); nested
/// array literals are not yet supported and raise a `ComputeError`.
fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
    let mut array_elements = Vec::with_capacity(elements.len());
    for e in elements {
        let av = match e {
            SQLExpr::Value(v) => self.visit_any_value(v, None)?,
            SQLExpr::UnaryOp { op, expr } => {
                if let SQLExpr::Value(v) = expr.as_ref() {
                    self.visit_any_value(v, Some(op))?
                } else {
                    return Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e));
                }
            },
            SQLExpr::Array(_) => {
                // TODO: nested arrays (handle FnMut issues)
                return Err(polars_err!(ComputeError: "SQL interface does not yet support nested array literals:\n{:?}", e));
            },
            _ => return Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e)),
        };
        array_elements.push(av);
    }
    Series::from_any_values("", &array_elements, true)
}

fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
match expr {
SQLExpr::AllOp {
Expand All @@ -197,6 +220,7 @@ impl SQLExprVisitor<'_> {
compare_op,
right,
} => self.visit_any(left, compare_op, right),
SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
SQLExpr::ArrayAgg(expr) => self.visit_arr_agg(expr),
SQLExpr::Between {
expr,
Expand All @@ -220,7 +244,12 @@ impl SQLExprVisitor<'_> {
expr,
list,
negated,
} => self.visit_in_list(expr, list, *negated),
} => {
let expr = self.visit_expr(expr)?;
let elems = self.visit_array_expr(list, false, Some(&expr))?;
let is_in = expr.is_in(elems);
Ok(if *negated { is_in.not() } else { is_in })
},
SQLExpr::InSubquery {
expr,
subquery,
Expand Down Expand Up @@ -615,6 +644,38 @@ impl SQLExprVisitor<'_> {
}
}

/// Visit a SQL `ARRAY` list (including `IN` values).
///
/// * `result_as_element` — if true, implode the parsed values into a single
///   list element (an ARRAY literal); otherwise return the flat series
///   (eg: the RHS of an `IN` clause).
/// * `dtype_expr_match` — optional column expression whose schema dtype may
///   be used to coerce string literals (implicit temporal strings).
fn visit_array_expr(
    &mut self,
    elements: &[SQLExpr],
    result_as_element: bool,
    dtype_expr_match: Option<&Expr>,
) -> PolarsResult<Expr> {
    let mut elems = self.array_expr_to_series(elements)?;

    // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
    // (not yet as versatile as the temporal string conversions in visit_binary_op)
    if let (Some(Expr::Column(name)), Some(schema)) =
        (dtype_expr_match, self.active_schema.as_ref())
    {
        if elems.dtype() == &DataType::String {
            // bind the dtype once instead of re-querying the schema and unwrapping
            if let Some(dtype) = schema.get(name) {
                if matches!(dtype, DataType::Date | DataType::Time | DataType::Datetime(_, _)) {
                    elems = elems.strict_cast(&dtype.clone())?;
                }
            }
        }
    }
    // if we are parsing the list as an element in a series, implode.
    // otherwise, return the series as-is.
    let res = if result_as_element {
        elems.implode()?.into_series()
    } else {
        elems
    };
    Ok(lit(res))
}

/// Visit a SQL `CAST` or `TRY_CAST` expression.
///
/// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
Expand Down Expand Up @@ -810,59 +871,6 @@ impl SQLExprVisitor<'_> {
Ok(base.implode())
}

/// Visit a SQL `IN` expression
// NOTE(review): this method is removed by this PR in favour of
// `visit_array_expr`, which shares literal-parsing with SQL ARRAY syntax.
fn visit_in_list(
    &mut self,
    expr: &SQLExpr,
    list: &[SQLExpr],
    negated: bool,
) -> PolarsResult<Expr> {
    // resolve the LHS of the IN clause to a polars expression
    let expr = self.visit_expr(expr)?;
    // parse each list element; only plain values and unary-prefixed
    // values (eg: -1) are supported here
    let list = list
        .iter()
        .map(|e| {
            if let SQLExpr::Value(v) = e {
                let av = self.visit_any_value(v, None)?;
                Ok(av)
            } else if let SQLExpr::UnaryOp {op, expr} = e {
                match expr.as_ref() {
                    SQLExpr::Value(v) => {
                        let av = self.visit_any_value(v, Some(op))?;
                        Ok(av)
                    },
                    _ => Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e))
                }
            }else{
                Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e))
            }
        })
        .collect::<PolarsResult<Vec<_>>>()?;

    let mut s = Series::from_any_values("", &list, true)?;

    // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
    // (not yet as versatile as the temporal string conversions in visit_binary_op)
    if s.dtype() == &DataType::String {
        // handle implicit temporal string comparisons, eg: "dt >= '2024-04-30'"
        if let Expr::Column(name) = &expr {
            if self.active_schema.is_some() {
                let schema = self.active_schema.as_ref().unwrap();
                let left_dtype = schema.get(name);
                if let Some(DataType::Date | DataType::Time | DataType::Datetime(_, _)) =
                    left_dtype
                {
                    // cast the string list to the column's temporal dtype
                    s = s.strict_cast(&left_dtype.unwrap().clone())?;
                }
            }
        }
    }
    if negated {
        Ok(expr.is_in(lit(s)).not())
    } else {
        Ok(expr.is_in(lit(s)))
    }
}

/// Visit a SQL subquery inside and `IN` expression.
fn visit_in_subquery(
&mut self,
Expand Down Expand Up @@ -1115,6 +1123,19 @@ pub(crate) fn parse_sql_expr(
visitor.visit_expr(expr)
}

pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
match expr {
SQLExpr::Array(arr) => {
let mut visitor = SQLExprVisitor {
ctx,
active_schema: None,
};
visitor.array_expr_to_series(arr.elem.as_slice())
},
_ => polars_bail!(ComputeError: "Expected array expression, found {:?}", expr),
}
}

fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
Ok(match field {
DateTimeField::Millennium => expr.dt().millennium(),
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-sql/tests/functions_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,23 @@ fn test_array_to_string() {
.unwrap();
assert!(df_sql.equals(&df_expected));
}

#[test]
fn test_array_literal() {
    // an ARRAY literal in a projection should yield a single list row
    let mut context = SQLContext::new();
    context.register("df", DataFrame::empty().lazy());

    let df_sql = context
        .execute("SELECT [100,200,300] AS arr FROM df")
        .unwrap()
        .collect()
        .unwrap();

    let df_expected = df! {
        "arr" => &[100i64, 200, 300],
    }
    .unwrap()
    .lazy()
    .select(&[col("arr").implode()])
    .collect()
    .unwrap();

    assert!(df_sql.equals(&df_expected));
    assert!(df_sql.height() == 1);
}
97 changes: 97 additions & 0 deletions py-polars/tests/unit/sql/test_array.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal


def test_array_literals() -> None:
    # ARRAY literals project as list columns; ARRAY_AGG nests them once more
    query = """
    SELECT
      a1, a2, ARRAY_AGG(a1) AS a3, ARRAY_AGG(a2) AS a4
    FROM (
      SELECT
        [10,20,30] AS a1,
        ['a','b','c'] AS a2,
      FROM df
    ) tbl
    """
    with pl.SQLContext(df=None, eager=True) as ctx:
        res = ctx.execute(query)

    expected = pl.DataFrame(
        {
            "a1": [[10, 20, 30]],
            "a2": [["a", "b", "c"]],
            "a3": [[[10, 20, 30]]],
            "a4": [[["a", "b", "c"]]],
        }
    )
    assert_frame_equal(res, expected)


def test_array_to_string() -> None:
data = {"values": [["aa", "bb"], [None, "cc"], ["dd", None]]}
res = pl.DataFrame(data).sql(
Expand All @@ -25,3 +55,70 @@ def test_array_to_string() -> None:
}
),
)


@pytest.mark.parametrize("array_keyword", ["ARRAY", ""])
def test_unnest_table_function(array_keyword: str) -> None:
    # UNNEST accepts arrays with or without the ARRAY keyword prefix
    query = f"""
    SELECT * FROM
    UNNEST(
        {array_keyword}[1, 2, 3, 4],
        {array_keyword}['ww','xx','yy','zz'],
        {array_keyword}[23.0, 24.5, 28.0, 27.5]
    ) AS tbl (x,y,z);
    """
    with pl.SQLContext(df=None, eager=True) as ctx:
        res = ctx.execute(query)

    expected = pl.DataFrame(
        {
            "x": [1, 2, 3, 4],
            "y": ["ww", "xx", "yy", "zz"],
            "z": [23.0, 24.5, 28.0, 27.5],
        }
    )
    assert_frame_equal(res, expected)


def test_unnest_table_function_errors() -> None:
    # each invalid UNNEST usage should raise a ComputeError with a clear message
    failure_cases = [
        (
            'SELECT * FROM UNNEST([1, 2, 3]) AS "frame data"',
            r'UNNEST table alias must also declare column names, eg: "frame data" \(a,b,c\)',
        ),
        (
            "SELECT * FROM UNNEST([1, 2, 3]) AS tbl (a, b)",
            "UNNEST table alias requires 1 column name, found 2",
        ),
        (
            "SELECT * FROM UNNEST([1,2,3], [3,4,5]) AS tbl (a)",
            "UNNEST table alias requires 2 column names, found 1",
        ),
        (
            "SELECT * FROM UNNEST([1, 2, 3])",
            r"UNNEST table must have an alias",
        ),
        (
            "SELECT * FROM UNNEST([1, 2, 3]) tbl (colx) WITH OFFSET",
            r"UNNEST tables do not \(yet\) support WITH OFFSET/ORDINALITY",
        ),
    ]
    with pl.SQLContext(df=None, eager=True) as ctx:
        for query, match in failure_cases:
            with pytest.raises(ComputeError, match=match):
                ctx.execute(query)

    # nested array literals are not yet supported anywhere in the interface
    with pytest.raises(
        ComputeError,
        match="SQL interface does not yet support nested array literals",
    ):
        pl.sql_expr("[[1,2,3]] AS nested")