diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 04bc70ecc8f8..7d768e3d400f 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -3802,12 +3802,15 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]: ) # note: the catalog (database) should be a part of the connection string from sqlalchemy.engine import create_engine + from sqlalchemy.orm import Session + + if isinstance(connection, str): + engine_sa = create_engine(connection) + elif isinstance(connection, Session): + engine_sa = connection.connection().engine + else: + engine_sa = connection.engine # type: ignore[union-attr] - engine_sa = ( - create_engine(connection) - if isinstance(connection, str) - else connection.engine # type: ignore[union-attr] - ) catalog, db_schema, unpacked_table_name = unpack_table_name(table_name) if catalog: msg = f"Unexpected three-part table name; provide the database/catalog ({catalog!r}) on the connection URI" diff --git a/py-polars/tests/unit/io/database/test_write.py b/py-polars/tests/unit/io/database/test_write.py index 3b16a420db12..f79060fe729e 100644 --- a/py-polars/tests/unit/io/database/test_write.py +++ b/py-polars/tests/unit/io/database/test_write.py @@ -5,6 +5,8 @@ import pytest from sqlalchemy import create_engine +from sqlalchemy.orm import Session +from sqlalchemy.pool import NullPool import polars as pl from polars.io.database._utils import _open_adbc_connection @@ -233,3 +235,27 @@ def test_write_database_errors( match="unrecognised connection type", ): df.write_database(connection=True, table_name="misc") # type: ignore[arg-type] + + +@pytest.mark.write_disk() +def test_write_database_using_sa_session(tmp_path: str) -> None: + df = pl.DataFrame( + { + "key": ["xx", "yy", "zz"], + "value": [123, None, 789], + "other": [5.5, 7.0, None], + } + ) + table_name = "test_sa_session" + test_db_uri = f"sqlite:///{tmp_path}/test_sa_session.db" + engine = create_engine(test_db_uri, poolclass=NullPool) + with Session(engine) as session: + df.write_database(table_name, session) + session.commit() + + with Session(engine) as session: + result = pl.read_database( + query=f"select * from {table_name}", connection=session + ) + + assert_frame_equal(result, df)