Skip to content

Commit

Permalink
fix(python): Support sa session (#17435)
Browse files Browse the repository at this point in the history
Co-authored-by: Ritchie Vink <[email protected]>
  • Loading branch information
phi-friday and ritchie46 authored Jul 5, 2024
1 parent c390fd7 commit 909fa88
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
13 changes: 8 additions & 5 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3802,12 +3802,15 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
)
# note: the catalog (database) should be a part of the connection string
from sqlalchemy.engine import create_engine
from sqlalchemy.orm import Session

if isinstance(connection, str):
engine_sa = create_engine(connection)
elif isinstance(connection, Session):
engine_sa = connection.connection().engine
else:
engine_sa = connection.engine # type: ignore[union-attr]

engine_sa = (
create_engine(connection)
if isinstance(connection, str)
else connection.engine # type: ignore[union-attr]
)
catalog, db_schema, unpacked_table_name = unpack_table_name(table_name)
if catalog:
msg = f"Unexpected three-part table name; provide the database/catalog ({catalog!r}) on the connection URI"
Expand Down
26 changes: 26 additions & 0 deletions py-polars/tests/unit/io/database/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.pool import NullPool

import polars as pl
from polars.io.database._utils import _open_adbc_connection
Expand Down Expand Up @@ -233,3 +235,27 @@ def test_write_database_errors(
match="unrecognised connection type",
):
df.write_database(connection=True, table_name="misc") # type: ignore[arg-type]


@pytest.mark.write_disk()
def test_write_database_using_sa_session(tmp_path: str) -> None:
df = pl.DataFrame(
{
"key": ["xx", "yy", "zz"],
"value": [123, None, 789],
"other": [5.5, 7.0, None],
}
)
table_name = "test_sa_session"
test_db_uri = f"sqlite:///{tmp_path}/test_sa_session.db"
engine = create_engine(test_db_uri, poolclass=NullPool)
with Session(engine) as session:
df.write_database(table_name, session)
session.commit()

with Session(engine) as session:
result = pl.read_database(
query=f"select * from {table_name}", connection=session
)

assert_frame_equal(result, df)

0 comments on commit 909fa88

Please sign in to comment.