Skip to content

Commit

Permalink
Merge pull request diesel-rs#4350 from weiznich/fix/4349
Browse files Browse the repository at this point in the history
  • Loading branch information
weiznich committed Dec 5, 2024
2 parents 04a6ddf + 5bf0218 commit 2303c0b
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 34 deletions.
89 changes: 66 additions & 23 deletions diesel/src/expression/array_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ use std::marker::PhantomData;
/// `IN` expression.
///
/// The postgres backend provided a specialized implementation
/// by using `left = ANY(values)` as optimized variant instead.
/// by using `left = ANY(values)` as optimized variant instead
/// if this is possible. For cases where this is not possible
/// like for example if values is a vector of arrays we
/// generate an ordinary `IN` expression instead.
#[derive(Debug, Copy, Clone, QueryId, ValidGrouping)]
#[non_exhaustive]
pub struct In<T, U> {
Expand All @@ -47,7 +50,10 @@ pub struct In<T, U> {
/// `NOT IN` expression.0
///
/// The postgres backend provided a specialized implementation
/// by using `left = ALL(values)` as optimized variant instead.
/// by using `left != ALL(values)` as optimized variant instead
/// if this is possible. For cases where this is not possible
/// like for example if values is a vector of arrays we
/// generate a ordinary `NOT IN` expression instead
#[derive(Debug, Copy, Clone, QueryId, ValidGrouping)]
#[non_exhaustive]
pub struct NotIn<T, U> {
Expand All @@ -61,12 +67,46 @@ impl<T, U> In<T, U> {
pub(crate) fn new(left: T, values: U) -> Self {
In { left, values }
}

pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
where
DB: Backend,
T: QueryFragment<DB>,
U: QueryFragment<DB> + InExpression,
{
if self.values.is_empty() {
out.push_sql("1=0");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}

impl<T, U> NotIn<T, U> {
pub(crate) fn new(left: T, values: U) -> Self {
NotIn { left, values }
}

pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
where
DB: Backend,
T: QueryFragment<DB>,
U: QueryFragment<DB> + InExpression,
{
if self.values.is_empty() {
out.push_sql("1=1");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" NOT IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}

impl<T, U> Expression for In<T, U>
Expand Down Expand Up @@ -114,16 +154,8 @@ where
T: QueryFragment<DB>,
U: QueryFragment<DB> + InExpression,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
if self.values.is_empty() {
out.push_sql("1=0");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
self.walk_ansi_ast(out)
}
}

Expand All @@ -145,16 +177,8 @@ where
T: QueryFragment<DB>,
U: QueryFragment<DB> + InExpression,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
if self.values.is_empty() {
out.push_sql("1=1");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" NOT IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
self.walk_ansi_ast(out)
}
}

Expand Down Expand Up @@ -217,6 +241,10 @@ pub trait InExpression {
/// Returns `true` if self represents an empty collection
/// Otherwise `false` is returned.
fn is_empty(&self) -> bool;

/// Returns `true` if the values clause represents
/// bind values and each bind value is a postgres array type
fn is_array(&self) -> bool;
}

impl<ST, F, S, D, W, O, LOf, G, H, LC> AsInExpression<ST>
Expand Down Expand Up @@ -306,6 +334,10 @@ where
fn is_empty(&self) -> bool {
self.values.is_empty()
}

fn is_array(&self) -> bool {
ST::IS_ARRAY
}
}

impl<ST, I, QS> SelectableExpression<QS> for Many<ST, I>
Expand Down Expand Up @@ -345,7 +377,18 @@ where
ST: SingleValue,
I: ToSql<ST, DB>,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
self.walk_ansi_ast(out)
}
}

impl<ST, I> Many<ST, I> {
pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
where
DB: Backend + HasSqlType<ST>,
ST: SingleValue,
I: ToSql<ST, DB>,
{
out.unsafe_to_cache_prepared();
let mut first = true;
for value in &self.values {
Expand Down
3 changes: 3 additions & 0 deletions diesel/src/expression/subselect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ impl<T, ST: SqlType> InExpression for Subselect<T, ST> {
fn is_empty(&self) -> bool {
false
}
fn is_array(&self) -> bool {
false
}
}

impl<T, ST, QS> SelectableExpression<QS> for Subselect<T, ST>
Expand Down
9 changes: 7 additions & 2 deletions diesel/src/expression_methods/global_expression_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ pub trait ExpressionMethods: Expression + Sized {
/// query will use the cache (assuming the subquery
/// itself is safe to cache).
/// On PostgreSQL, this method automatically performs a `= ANY()`
/// query.
/// query if this is possible. For cases where this is not possible
/// like for example if values is a vector of arrays we
/// generate an ordinary `IN` expression instead.
///
/// # Example
///
Expand Down Expand Up @@ -149,7 +151,10 @@ pub trait ExpressionMethods: Expression + Sized {
///
/// Queries using this method will not be
/// placed in the prepared statement cache. On PostgreSQL, this
/// method automatically performs a `!= ALL()` query.
/// method automatically performs a `!= ALL()` query if this is possible.
/// For cases where this is not possible
/// like for example if values is a vector of arrays we
/// generate an ordinary `NOT IN` expression instead.
///
/// # Example
///
Expand Down
14 changes: 14 additions & 0 deletions diesel/src/pg/expression/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,15 @@ where
ST: SqlType,
{
type SqlType = ST;

fn is_empty(&self) -> bool {
false
}

fn is_array(&self) -> bool {
// we want to use the `= ANY(_)` syntax
false
}
}

impl<T, ST> AsInExpression<ST> for ArrayLiteral<T, ST>
Expand All @@ -189,6 +195,7 @@ where
ST: SqlType,
{
type InExpression = Self;

fn as_in_expression(self) -> Self::InExpression {
self
}
Expand Down Expand Up @@ -296,9 +303,15 @@ where
ST: SqlType,
{
type SqlType = ST;

fn is_empty(&self) -> bool {
false
}

fn is_array(&self) -> bool {
// we want to use the `= ANY(_)` syntax
false
}
}

impl<T, ST> AsInExpression<ST> for ArraySubselect<T, ST>
Expand All @@ -307,6 +320,7 @@ where
ST: SqlType,
{
type InExpression = Self;

fn as_in_expression(self) -> Self::InExpression {
self
}
Expand Down
31 changes: 22 additions & 9 deletions diesel/src/pg/query_builder/query_fragment_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,14 @@ where
U: QueryFragment<Pg> + InExpression,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" = ANY(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
if self.values.is_array() {
self.walk_ansi_ast(out)?;
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" = ANY(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}
Expand All @@ -80,10 +84,14 @@ where
U: QueryFragment<Pg> + InExpression,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" != ALL(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
if self.values.is_array() {
self.walk_ansi_ast(out)?;
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" != ALL(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}
Expand All @@ -92,10 +100,15 @@ impl<ST, I> QueryFragment<Pg, PgStyleArrayComparison> for Many<ST, I>
where
ST: SingleValue,
Vec<I>: ToSql<Array<ST>, Pg>,
I: ToSql<ST, Pg>,
Pg: HasSqlType<ST>,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
out.push_bind_param::<Array<ST>, Vec<I>>(&self.values)
if ST::IS_ARRAY {
self.walk_ansi_ast(out)
} else {
out.push_bind_param::<Array<ST>, Vec<I>>(&self.values)
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions diesel/src/sql_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,9 @@ pub trait SqlType: 'static {
///
/// ['is_nullable`]: is_nullable
type IsNull: OneIsNullable<is_nullable::IsNullable> + OneIsNullable<is_nullable::NotNull>;

#[doc(hidden)]
const IS_ARRAY: bool = false;
}

/// Is one value of `IsNull` nullable?
Expand Down
5 changes: 5 additions & 0 deletions diesel_derives/src/sql_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,23 @@ pub fn derive(item: DeriveInput) -> Result<TokenStream> {
let model = Model::from_item(&item, true, false)?;

let struct_name = &item.ident;
let generic_count = item.generics.params.len();
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();

let sqlite_tokens = sqlite_tokens(&item, &model);
let mysql_tokens = mysql_tokens(&item, &model);
let pg_tokens = pg_tokens(&item, &model);

let is_array = struct_name == "Array" && generic_count == 1;

Ok(wrap_in_dummy_mod(quote! {
impl #impl_generics diesel::sql_types::SqlType
for #struct_name #ty_generics
#where_clause
{
type IsNull = diesel::sql_types::is_nullable::NotNull;

const IS_ARRAY: bool = #is_array;
}

impl #impl_generics diesel::sql_types::SingleValue
Expand Down
30 changes: 30 additions & 0 deletions diesel_tests/tests/filter_operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,36 @@ fn filter_by_in_explicit_array() {
);
}

#[test]
#[cfg(feature = "postgres")]
fn filter_array_by_in() {
use crate::schema::posts::dsl::*;

let connection: &mut PgConnection = &mut connection();
let tag_combinations_to_look_for: &[&[&str]] = &[&["foo"], &["foo", "bar"], &["baz"]];
let result: Vec<i32> = posts
.filter(tags.eq_any(tag_combinations_to_look_for))
.select(id)
.load(connection)
.unwrap();
assert_eq!(result, &[] as &[i32]);
}

#[test]
#[cfg(feature = "postgres")]
fn filter_array_by_not_in() {
use crate::schema::posts::dsl::*;

let connection: &mut PgConnection = &mut connection();
let tag_combinations_to_look_for: &[&[&str]] = &[&["foo"], &["foo", "bar"], &["baz"]];
let result: Vec<i32> = posts
.filter(tags.ne_all(tag_combinations_to_look_for))
.select(id)
.load(connection)
.unwrap();
assert_eq!(result, &[] as &[i32]);
}

fn connection_with_3_users() -> TestConnection {
let mut connection = connection_with_sean_and_tess_in_users_table();
diesel::sql_query("INSERT INTO users (id, name) VALUES (3, 'Jim')")
Expand Down

0 comments on commit 2303c0b

Please sign in to comment.