From 503a36e5fe997c939cf34a312fcbd3181077054e Mon Sep 17 00:00:00 2001 From: Trangar Date: Tue, 14 Feb 2023 20:56:40 +0100 Subject: [PATCH] Added regexp support in sqlite (#2189) * CHANGELOG: mention that users should upgrade CLI * Added regexp support in sqlite * Added a with_regexp function to sqliteconnectoptions * Fixed tests * Undo CHANGELOG.md change --------- Co-authored-by: Austin Bonander Co-authored-by: Victor Koenders --- Cargo.lock | 2 + Cargo.toml | 1 + sqlx-core/Cargo.toml | 1 - sqlx-sqlite/Cargo.toml | 5 + sqlx-sqlite/src/connection/establish.rs | 23 ++- sqlx-sqlite/src/lib.rs | 3 + sqlx-sqlite/src/options/mod.rs | 36 +++- sqlx-sqlite/src/regexp.rs | 237 ++++++++++++++++++++++++ 8 files changed, 298 insertions(+), 10 deletions(-) create mode 100644 sqlx-sqlite/src/regexp.rs diff --git a/Cargo.lock b/Cargo.lock index 4738340651..8c036d61ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3231,7 +3231,9 @@ dependencies = [ "libsqlite3-sys", "log", "percent-encoding", + "regex", "serde", + "sqlx", "sqlx-core", "time 0.3.17", "tracing", diff --git a/Cargo.toml b/Cargo.toml index da437616d6..ce4598323e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -109,6 +109,7 @@ mac_address = ["sqlx-core/mac_address", "sqlx-macros?/mac_address", "sqlx-postgr rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] time = ["sqlx-core/time", "sqlx-macros?/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] +regexp = ["sqlx-sqlite?/regexp"] [workspace.dependencies] # Driver crates diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index f220c55dd8..4a8f55330c 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -70,7 +70,6 @@ futures-intrusive = "0.4.0" futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] } generic-array = { version = "0.14.4", default-features = false, optional = true } hex = "0.4.3" - log = { version = "0.4.14", default-features = false } memchr = { version = "2.4.1", default-features = false } num-bigint = { version = "0.4.0", default-features = false, optional = true, features = ["std"] } diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index e5749192f8..c32a29650a 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -18,6 +18,7 @@ offline = ["sqlx-core/offline", "serde"] migrate = ["sqlx-core/migrate"] chrono = ["dep:chrono", "bitflags"] +regexp = ["dep:regex"] [dependencies] futures-core = { version = "0.3.19", default-features = false } @@ -44,6 +45,7 @@ log = "0.4.17" tracing = { version = "0.1.37", features = ["log"] } serde = { version = "1.0.145", features = ["derive"], optional = true } +regex = { version = "1.5.5", optional = true } [dependencies.libsqlite3-sys] version = "0.25.1" @@ -58,3 +60,6 @@ features = [ [dependencies.sqlx-core] version = "=0.6.2" path = "../sqlx-core" + +[dev-dependencies] +sqlx = { version = "0.6.2", path = "..", default-features = false, features = ["macros", "runtime-tokio", "tls-none"] } \ No newline at end of file diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index b380fce5fb..c5425dd19b 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -44,6 +44,8 @@ pub struct EstablishParams { extensions: IndexMap>, pub(crate) thread_name: String, pub(crate) command_channel_size: usize, + #[cfg(feature = "regexp")] + register_regexp_function: bool, } impl EstablishParams { @@ -145,6 +147,8 @@ impl EstablishParams { extensions, thread_name: (options.thread_name)(THREAD_ID.fetch_add(1, Ordering::AcqRel)), command_channel_size: options.command_channel_size, + #[cfg(feature = "regexp")] + register_regexp_function: options.register_regexp_function, }) } @@ -238,12 +242,10 @@ impl EstablishParams { &err_msg, )))); } - } - - // Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION - // on by disabling the flag again once we've loaded all the requested modules. - // Fail-fast (via `?`) if disabling the extension loader didn't work for some reason, - // avoids an unexpected state going undetected. + } // Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION + // on by disabling the flag again once we've loaded all the requested modules. + // Fail-fast (via `?`) if disabling the extension loader didn't work for some reason, + // avoids an unexpected state going undetected. unsafe { Self::sqlite3_set_load_extension( handle.as_ptr(), @@ -252,6 +254,15 @@ impl EstablishParams { } } + #[cfg(feature = "regexp")] + if self.register_regexp_function { + // configure a `regexp` function for sqlite, it does not come with one by default + let status = crate::regexp::register(handle.as_ptr()); + if status != SQLITE_OK { + return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + } + } + // Configure a busy timeout // This causes SQLite to automatically sleep in increasing intervals until the time // when there is something locked during [sqlite3_step]. diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index e3a425554c..2ddc576301 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -72,6 +72,9 @@ mod value; #[cfg(feature = "any")] pub mod any; +#[cfg(feature = "regexp")] +mod regexp; + #[cfg(feature = "migrate")] mod migrate; diff --git a/sqlx-sqlite/src/options/mod.rs b/sqlx-sqlite/src/options/mod.rs index 9b171e7770..43d2939d6a 100644 --- a/sqlx-sqlite/src/options/mod.rs +++ b/sqlx-sqlite/src/options/mod.rs @@ -38,7 +38,7 @@ use sqlx_core::IndexMap; /// ```rust,no_run /// # use sqlx_core::connection::ConnectOptions; /// # use sqlx_core::error::Error; -/// use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; +/// # use sqlx_sqlite::{SqliteConnectOptions, SqliteJournalMode}; /// use std::str::FromStr; /// /// # fn main() { @@ -79,6 +79,9 @@ pub struct SqliteConnectOptions { pub(crate) serialized: bool, pub(crate) thread_name: Arc String + Send + Sync + 'static>>, + + #[cfg(feature = "regexp")] + pub(crate) register_regexp_function: bool, } impl Default for SqliteConnectOptions { @@ -185,6 +188,8 @@ impl SqliteConnectOptions { thread_name: Arc::new(DebugFn(|id| format!("sqlx-sqlite-worker-{}", id))), command_channel_size: 50, row_channel_size: 50, + #[cfg(feature = "regexp")] + register_regexp_function: false, } } @@ -431,8 +436,8 @@ impl SqliteConnectOptions { /// will be loaded in the order they are added. /// ```rust,no_run /// # use sqlx_core::error::Error; - /// use std::str::FromStr; - /// use sqlx::sqlite::SqliteConnectOptions; + /// # use std::str::FromStr; + /// # use sqlx_sqlite::SqliteConnectOptions; /// # fn options() -> Result { /// let options = SqliteConnectOptions::from_str("sqlite://data.db")? /// .extension("vsv") @@ -458,4 +463,29 @@ impl SqliteConnectOptions { .insert(extension_name.into(), Some(entry_point.into())); self } + + /// Register a regexp function that allows using regular expressions in queries. + /// + /// ``` + /// # use std::str::FromStr; + /// # use sqlx::{ConnectOptions, Connection, Row}; + /// # use sqlx_sqlite::SqliteConnectOptions; + /// # async fn run() -> sqlx::Result<()> { + /// let mut sqlite = SqliteConnectOptions::from_str("sqlite://:memory:")? + /// .with_regexp() + /// .connect() + /// .await?; + /// let tables = sqlx::query("SELECT name FROM sqlite_schema WHERE name REGEXP 'foo(\\d+)bar'") + /// .fetch_all(&mut sqlite) + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// This uses the [`regex`] crate, and is only enabled when you enable the `regex` feature is enabled on sqlx + #[cfg(feature = "regexp")] + pub fn with_regexp(mut self) -> Self { + self.register_regexp_function = true; + self + } } diff --git a/sqlx-sqlite/src/regexp.rs b/sqlx-sqlite/src/regexp.rs new file mode 100644 index 0000000000..93f4038899 --- /dev/null +++ b/sqlx-sqlite/src/regexp.rs @@ -0,0 +1,237 @@ +#![deny(missing_docs, clippy::pedantic)] +#![allow(clippy::cast_sign_loss)] // some lengths returned from sqlite3 are `i32`, but rust needs `usize` + +//! Here be dragons +//! +//! We need to register a custom REGEX implementation for sqlite +//! some useful resources: +//! - rusqlite has an example implementation: +//! - sqlite supports registering custom C functions: +//! - sqlite also supports a `A REGEXP B` syntax, but ONLY if the user implements `regex(B, A)` +//! - Note that A and B are indeed swapped: the regex comes first, the field comes second +//! - +//! - sqlx has a way to safely get a sqlite3 pointer: +//! - +//! - + +use libsqlite3_sys as ffi; +use log::error; +use regex::Regex; +use std::sync::Arc; + +/// The function name for sqlite3. This must be "regexp\0" +static FN_NAME: &[u8] = b"regexp\0"; + +/// Register the regex function with sqlite. +/// +/// Returns the result code of `sqlite3_create_function_v2` +pub fn register(sqlite3: *mut ffi::sqlite3) -> i32 { + unsafe { + ffi::sqlite3_create_function_v2( + // the database connection + sqlite3, + // the function name. Must be up to 255 bytes, and 0-terminated + FN_NAME.as_ptr().cast(), + // the number of arguments this function accepts. We want 2 arguments: The regex and the field + 2, + // we want all our strings to be UTF8, and this function will return the same output with the same inputs + ffi::SQLITE_UTF8 | ffi::SQLITE_DETERMINISTIC, + // pointer to user data. We're not using user data + std::ptr::null_mut(), + // xFunc to be executed when we are invoked + Some(sqlite3_regexp_func), + // xStep, should be NULL for scalar functions + None, + // xFinal, should be NULL for scalar functions + None, + // xDestroy, called when this function is deregistered. Should be used to clean up our pointer to user-data + None, + ) + } +} + +/// A function to be called on each invocation of `regex(REGEX, FIELD)` from sqlite3 +/// +/// - `ctx`: a pointer to the current sqlite3 context +/// - `n_arg`: The length of `args` +/// - `args`: the arguments of this function call +unsafe extern "C" fn sqlite3_regexp_func( + ctx: *mut ffi::sqlite3_context, + n_arg: i32, + args: *mut *mut ffi::sqlite3_value, +) { + // check the arg size. sqlite3 should already ensure this is only 2 args but we want to double check + if n_arg != 2 { + eprintln!("n_arg expected to be 2, is {n_arg}"); + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + return; + } + + // arg0: Regex + let regex = if let Some(regex) = get_regex_from_arg(ctx, *args.offset(0), 0) { + regex + } else { + return; + }; + + // arg1: value + let value = if let Some(text) = get_text_from_arg(ctx, *args.offset(1)) { + text + } else { + return; + }; + + // if the regex matches the value, set the result int as 1, else as 0 + if regex.is_match(value) { + ffi::sqlite3_result_int(ctx, 1); + } else { + ffi::sqlite3_result_int(ctx, 0); + } +} + +/// Get the regex from the given `arg` at the given `index`. +/// +/// First this will check to see if the value exists in sqlite's `auxdata`. If it does, that regex will be returned. +/// sqlite is able to clean up this data at any point, but rust's [`Arc`] guarantees make sure things don't break. +/// +/// If this value does not exist in `auxdata`, [`try_load_value`] is called and a regex is created from this. If any of +/// those fail, a message is printed and `None` is returned. +/// +/// After this regex is created it is stored in `auxdata` and loaded again. If it fails to load, this means that +/// something inside of sqlite3 went wrong, and we return `None`. +/// +/// If this value is stored correctly, or if it already existed, the arc reference counter is increased and this value is returned. +unsafe fn get_regex_from_arg( + ctx: *mut ffi::sqlite3_context, + arg: *mut ffi::sqlite3_value, + index: i32, +) -> Option> { + // try to get the auxdata for this field + let ptr = ffi::sqlite3_get_auxdata(ctx, index); + if !ptr.is_null() { + // if we have it, turn it into an Arc. + // we need to make sure to call `increment_strong_count` because the returned `Arc` decrement this when it goes out of scope + let ptr = ptr as *const Regex; + Arc::increment_strong_count(ptr); + return Some(Arc::from_raw(ptr)); + } + // get the text for this field + let value = get_text_from_arg(ctx, arg)?; + // try to compile it into a regex + let regex = match Regex::new(value) { + Ok(regex) => Arc::new(regex), + Err(e) => { + error!("Invalid regex {value:?}: {e:?}"); + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + return None; + } + }; + // set the regex as auxdata for the next time around + ffi::sqlite3_set_auxdata( + ctx, + index, + // make sure to call `Arc::clone` here, setting the strong count to 2. + // this will be cleaned up at 2 points: + // - when the returned arc goes out of scope + // - when sqlite decides to clean it up an calls `cleanup_arc_regex_pointer` + Arc::into_raw(Arc::clone(®ex)) as *mut _, + Some(cleanup_arc_regex_pointer), + ); + Some(regex) +} + +/// Get a text reference of the value of `arg`. If this value is not a string value, an error is printed and `None` is +/// returned. +/// +/// The returned `&str` is valid for lifetime `'a` which can be determined by the caller. This lifetime should **not** +/// outlive `ctx`. +unsafe fn get_text_from_arg<'a>( + ctx: *mut ffi::sqlite3_context, + arg: *mut ffi::sqlite3_value, +) -> Option<&'a str> { + let ty = ffi::sqlite3_value_type(arg); + if ty == ffi::SQLITE_TEXT { + let ptr = ffi::sqlite3_value_text(arg); + let len = ffi::sqlite3_value_bytes(arg); + let slice = std::slice::from_raw_parts(ptr.cast(), len as usize); + match std::str::from_utf8(slice) { + Ok(result) => Some(result), + Err(e) => { + log::error!("Incoming text is not valid UTF8: {e:?}",); + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + None + } + } + } else { + None + } +} + +/// Clean up the `Arc` that is stored in the given `ptr`. +unsafe extern "C" fn cleanup_arc_regex_pointer(ptr: *mut std::ffi::c_void) { + Arc::decrement_strong_count(ptr.cast::()); +} + +#[cfg(test)] +mod tests { + use sqlx::{ConnectOptions, Connection, Row}; + use std::str::FromStr; + + async fn test_db() -> crate::SqliteConnection { + let mut conn = crate::SqliteConnectOptions::from_str("sqlite://:memory:") + .unwrap() + .with_regexp() + .connect() + .await + .unwrap(); + sqlx::query("CREATE TABLE test (col TEXT NOT NULL)") + .execute(&mut conn) + .await + .unwrap(); + for i in 0..10 { + sqlx::query("INSERT INTO test VALUES (?)") + .bind(format!("value {}", i)) + .execute(&mut conn) + .await + .unwrap(); + } + conn + } + + #[sqlx::test] + async fn test_regexp_does_not_fail() { + let mut conn = test_db().await; + let result = sqlx::query("SELECT col FROM test WHERE col REGEXP 'foo.*bar'") + .fetch_all(&mut conn) + .await + .expect("Could not execute query"); + assert!(result.is_empty()); + } + + #[sqlx::test] + async fn test_regexp_filters_correctly() { + let mut conn = test_db().await; + + let result = sqlx::query("SELECT col FROM test WHERE col REGEXP '.*2'") + .fetch_all(&mut conn) + .await + .expect("Could not execute query"); + assert_eq!(result.len(), 1); + assert_eq!(result[0].get::(0), String::from("value 2")); + + let result = sqlx::query("SELECT col FROM test WHERE col REGEXP '^3'") + .fetch_all(&mut conn) + .await + .expect("Could not execute query"); + assert!(result.is_empty()); + } + + #[sqlx::test] + async fn test_invalid_regexp_should_fail() { + let mut conn = test_db().await; + let result = sqlx::query("SELECT col from test WHERE col REGEXP '(?:?)'") + .execute(&mut conn) + .await; + assert!(matches!(result, Err(sqlx::Error::Database(_)))); + } +}