Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update build.yaml #267

Merged
merged 5 commits into from
Oct 23, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
code changes
wseaton committed Oct 23, 2024
commit 3544b3d104668ad08f3adf858ed89724deb3ae98
11 changes: 5 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@ use pythonize::pythonize;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use pyo3::wrap_pyfunction;
use pythonize::PythonizeError;

use sqlparser::ast::Statement;
@@ -20,12 +19,12 @@ use visitor::{extract_expressions, extract_relations, mutate_expressions, mutate
/// Available `dialects`: https://github.com/sqlparser-rs/sqlparser-rs/blob/main/src/dialect/mod.rs#L189-L206
#[pyfunction]
#[pyo3(text_signature = "(sql, dialect)")]
fn parse_sql(py: Python, sql: &str, dialect: &str) -> PyResult<PyObject> {
fn parse_sql(py: Python, sql: String, dialect: String) -> PyResult<PyObject> {
let chosen_dialect = dialect_from_str(dialect).unwrap_or_else(|| {
println!("The dialect you chose was not recognized, falling back to 'generic'");
Box::new(GenericDialect {})
});
let parse_result = Parser::parse_sql(&*chosen_dialect, sql);
let parse_result = Parser::parse_sql(&*chosen_dialect, &sql);

let output = match parse_result {
Ok(statements) => pythonize(py, &statements).map_err(|e| {
@@ -40,13 +39,13 @@ fn parse_sql(py: Python, sql: &str, dialect: &str) -> PyResult<PyObject> {
}
};

Ok(output)
Ok(output.into())
}

/// This utility function allows reconstituing a modified AST back into list of SQL queries.
#[pyfunction]
#[pyo3(text_signature = "(ast)")]
fn restore_ast(_py: Python, ast: &PyAny) -> PyResult<Vec<String>> {
fn restore_ast(_py: Python, ast: &Bound<'_, PyAny>) -> PyResult<Vec<String>> {
let parse_result: Result<Vec<Statement>, PythonizeError> = pythonize::depythonize(ast);

let output = match parse_result {
@@ -66,7 +65,7 @@ fn restore_ast(_py: Python, ast: &PyAny) -> PyResult<Vec<String>> {
}

#[pymodule]
fn sqloxide(_py: Python, m: &PyModule) -> PyResult<()> {
fn sqloxide(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(parse_sql, m)?)?;
m.add_function(wrap_pyfunction!(restore_ast, m)?)?;
// TODO: maybe refactor into seperate module
18 changes: 9 additions & 9 deletions src/visitor.rs
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@ use sqlparser::ast::{
};

// Refactored function for handling depythonization
fn depythonize_query(parsed_query: &PyAny) -> Result<Vec<Statement>, PyErr> {
fn depythonize_query(parsed_query: &Bound<'_, PyAny>) -> Result<Vec<Statement>, PyErr> {
match pythonize::depythonize(parsed_query) {
Ok(statements) => Ok(statements),
Err(e) => {
@@ -27,7 +27,7 @@ where
T: Sized + Serialize,
{
match pythonize::pythonize(py, &output) {
Ok(p) => Ok(p),
Ok(p) => Ok(p.into()),
Err(e) => {
let msg = e.to_string();
Err(PyValueError::new_err(format!(
@@ -39,7 +39,7 @@ where

#[pyfunction]
#[pyo3(text_signature = "(parsed_query)")]
pub fn extract_relations(py: Python, parsed_query: &PyAny) -> PyResult<PyObject> {
pub fn extract_relations(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult<PyObject> {
let statements = depythonize_query(parsed_query)?;

let mut relations = Vec::new();
@@ -55,7 +55,7 @@ pub fn extract_relations(py: Python, parsed_query: &PyAny) -> PyResult<PyObject>

#[pyfunction]
#[pyo3(text_signature = "(parsed_query, func)")]
pub fn mutate_relations(_py: Python, parsed_query: &PyAny, func: &PyAny) -> PyResult<Vec<String>> {
pub fn mutate_relations(_py: Python, parsed_query: &Bound<'_, PyAny>, func: &Bound<'_, PyAny>) -> PyResult<Vec<String>> {
let mut statements = depythonize_query(parsed_query)?;

for statement in &mut statements {
@@ -85,8 +85,8 @@ pub fn mutate_relations(_py: Python, parsed_query: &PyAny, func: &PyAny) -> PyRe

#[pyfunction]
#[pyo3(text_signature = "(parsed_query, func)")]
pub fn mutate_expressions(py: Python, parsed_query: &PyAny, func: &PyAny) -> PyResult<Vec<String>> {
let mut statements = depythonize_query(parsed_query)?;
pub fn mutate_expressions(py: Python, parsed_query: &Bound<'_, PyAny>, func: &Bound<'_, PyAny>) -> PyResult<Vec<String>> {
let mut statements: Vec<Statement> = depythonize_query(parsed_query)?;

for statement in &mut statements {
visit_expressions_mut(statement, |expr| {
@@ -110,7 +110,7 @@ pub fn mutate_expressions(py: Python, parsed_query: &PyAny, func: &PyAny) -> PyR
}
};

*expr = match pythonize::depythonize(func_result) {
*expr = match pythonize::depythonize(&func_result) {
Ok(val) => val,
Err(e) => {
let msg = e.to_string();
@@ -132,8 +132,8 @@ pub fn mutate_expressions(py: Python, parsed_query: &PyAny, func: &PyAny) -> PyR

#[pyfunction]
#[pyo3(text_signature = "(parsed_query)")]
pub fn extract_expressions(py: Python, parsed_query: &PyAny) -> PyResult<PyObject> {
let statements = depythonize_query(parsed_query)?;
pub fn extract_expressions(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult<PyObject> {
let statements: Vec<Statement> = depythonize_query(parsed_query)?;

let mut expressions = Vec::new();
for statement in statements {