diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 5cc2a1748137..b2972a164424 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -3878,15 +3878,19 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]: min_err_prefix="pandas >= 2.2 requires", ) # note: the catalog (database) should be a part of the connection string - from sqlalchemy.engine import create_engine + from sqlalchemy.engine import Connectable, create_engine from sqlalchemy.orm import Session + sa_object: Connectable if isinstance(connection, str): - engine_sa = create_engine(connection) + sa_object = create_engine(connection) elif isinstance(connection, Session): - engine_sa = connection.connection().engine + sa_object = connection.connection() + elif isinstance(connection, Connectable): + sa_object = connection else: - engine_sa = connection.engine # type: ignore[union-attr] + error_msg = f"unexpected connection type {type(connection)}" + raise TypeError(error_msg) catalog, db_schema, unpacked_table_name = unpack_table_name(table_name) if catalog: @@ -3900,7 +3904,7 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]: ).to_sql( name=unpacked_table_name, schema=db_schema, - con=engine_sa, + con=sa_object, if_exists=if_table_exists, index=False, **(engine_options or {}), diff --git a/py-polars/tests/unit/io/database/test_write.py b/py-polars/tests/unit/io/database/test_write.py index f79060fe729e..1a995e31df64 100644 --- a/py-polars/tests/unit/io/database/test_write.py +++ b/py-polars/tests/unit/io/database/test_write.py @@ -259,3 +259,62 @@ def test_write_database_using_sa_session(tmp_path: str) -> None: ) assert_frame_equal(result, df) + + +@pytest.mark.write_disk() +@pytest.mark.parametrize("pass_connection", [True, False]) +def test_write_database_sa_rollback(tmp_path: str, pass_connection: bool) -> None: + df = pl.DataFrame( + { + "key": ["xx", "yy", "zz"], + "value": [123, None, 789], + "other": [5.5, 7.0, None], + } + ) + table_name = "test_sa_rollback" + test_db_uri = f"sqlite:///{tmp_path}/test_sa_rollback.db" + engine = create_engine(test_db_uri, poolclass=NullPool) + with Session(engine) as session: + if pass_connection: + conn = session.connection() + df.write_database(table_name, conn) + else: + df.write_database(table_name, session) + session.rollback() + + with Session(engine) as session: + count = pl.read_database( + query=f"select count(*) from {table_name}", connection=session + ).item(0, 0) + + assert isinstance(count, int) + assert count == 0 + + +@pytest.mark.write_disk() +@pytest.mark.parametrize("pass_connection", [True, False]) +def test_write_database_sa_commit(tmp_path: str, pass_connection: bool) -> None: + df = pl.DataFrame( + { + "key": ["xx", "yy", "zz"], + "value": [123, None, 789], + "other": [5.5, 7.0, None], + } + ) + table_name = "test_sa_commit" + test_db_uri = f"sqlite:///{tmp_path}/test_sa_commit.db" + engine = create_engine(test_db_uri, poolclass=NullPool) + with Session(engine) as session: + if pass_connection: + conn = session.connection() + df.write_database(table_name, conn) + else: + df.write_database(table_name, session) + session.commit() + + with Session(engine) as session: + result = pl.read_database( + query=f"select * from {table_name}", connection=session + ) + + assert_frame_equal(result, df)