Skip to content

Commit

Permalink
fix(python): Fix read_database(…,iter_batches=True) type annotations (
Browse files Browse the repository at this point in the history
  • Loading branch information
iliya-malecki authored Nov 17, 2024
1 parent da38e37 commit 8c0c7c5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 42 deletions.
21 changes: 6 additions & 15 deletions py-polars/polars/io/database/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,17 @@

from polars.io.database._arrow_registry import ArrowDriverProperties

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.expression import Selectable

from polars import DataFrame
from polars._typing import ConnectionOrCursor, Cursor, SchemaDict

try:
from sqlalchemy.sql.expression import Selectable
except ImportError:
Selectable: TypeAlias = Any # type: ignore[no-redef]

from sqlalchemy.sql.elements import TextClause

_INVALID_QUERY_TYPES = {
"ALTER",
"ANALYZE",
Expand Down Expand Up @@ -207,7 +198,7 @@ def _from_arrow(
iter_batches: bool,
schema_overrides: SchemaDict | None,
infer_schema_length: int | None,
) -> DataFrame | Iterable[DataFrame] | None:
) -> DataFrame | Iterator[DataFrame] | None:
"""Return resultset data in Arrow format for frame init."""
from polars import DataFrame

Expand Down Expand Up @@ -253,7 +244,7 @@ def _from_rows(
iter_batches: bool,
schema_overrides: SchemaDict | None,
infer_schema_length: int | None,
) -> DataFrame | Iterable[DataFrame] | None:
) -> DataFrame | Iterator[DataFrame] | None:
"""Return resultset data row-wise for frame init."""
from polars import DataFrame

Expand Down Expand Up @@ -529,7 +520,7 @@ def to_polars(
batch_size: int | None = None,
schema_overrides: SchemaDict | None = None,
infer_schema_length: int | None = N_INFER_DEFAULT,
) -> DataFrame | Iterable[DataFrame]:
) -> DataFrame | Iterator[DataFrame]:
"""
Convert the result set to a DataFrame.
Expand Down
11 changes: 0 additions & 11 deletions py-polars/polars/io/database/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,11 @@
from polars.dependencies import import_optional

if TYPE_CHECKING:
import sys
from collections.abc import Coroutine

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

from polars import DataFrame
from polars._typing import SchemaDict

try:
from sqlalchemy.sql.expression import Selectable
except ImportError:
Selectable: TypeAlias = Any # type: ignore[no-redef]


def _run_async(co: Coroutine[Any, Any, Any]) -> Any:
"""Run asynchronous code as if it was synchronous."""
Expand Down
22 changes: 6 additions & 16 deletions py-polars/polars/io/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,14 @@
from polars.io.database._executor import ConnectionExecutor

if TYPE_CHECKING:
import sys
from collections.abc import Iterable
from collections.abc import Iterator

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.expression import Selectable

from polars import DataFrame
from polars._typing import ConnectionOrCursor, DbReadEngine, SchemaDict

try:
from sqlalchemy.sql.expression import Selectable
except ImportError:
Selectable: TypeAlias = Any # type: ignore[no-redef]

from sqlalchemy.sql.elements import TextClause


@overload
def read_database(
Expand All @@ -51,7 +41,7 @@ def read_database(
schema_overrides: SchemaDict | None = ...,
infer_schema_length: int | None = ...,
execute_options: dict[str, Any] | None = ...,
) -> Iterable[DataFrame]: ...
) -> Iterator[DataFrame]: ...


@overload
Expand All @@ -64,7 +54,7 @@ def read_database(
schema_overrides: SchemaDict | None = ...,
infer_schema_length: int | None = ...,
execute_options: dict[str, Any] | None = ...,
) -> DataFrame | Iterable[DataFrame]: ...
) -> DataFrame | Iterator[DataFrame]: ...


def read_database(
Expand All @@ -76,7 +66,7 @@ def read_database(
schema_overrides: SchemaDict | None = None,
infer_schema_length: int | None = N_INFER_DEFAULT,
execute_options: dict[str, Any] | None = None,
) -> DataFrame | Iterable[DataFrame]:
) -> DataFrame | Iterator[DataFrame]:
"""
Read the results of a SQL query into a DataFrame, given a connection object.
Expand Down

0 comments on commit 8c0c7c5

Please sign in to comment.