Skip to content

Commit

Permalink
fix(python): Make boolean reads consistent across all read_excel engines (#17448)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Jul 6, 2024
1 parent afce7e5 commit 447146d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 17 deletions.
14 changes: 13 additions & 1 deletion py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Int64,
Null,
String,
UInt8,
)
from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES, NUMERIC_DTYPES
from polars.dependencies import import_optional
Expand Down Expand Up @@ -506,6 +507,7 @@ def _read_spreadsheet(

read_options = (read_options or {}).copy()
engine_options = (engine_options or {}).copy()
schema_overrides = dict(schema_overrides or {})

# normalise some top-level parameters to 'read_options' entries
if engine == "calamine":
Expand Down Expand Up @@ -872,7 +874,7 @@ def _read_spreadsheet_calamine(
elif base_dtype == Duration:
parser_dtypes[name] = "duration"
elif base_dtype == Boolean:
parser_dtypes[name] = "bool"
parser_dtypes[name] = "boolean"

read_options["dtypes"] = parser_dtypes

Expand Down Expand Up @@ -936,11 +938,21 @@ def _read_spreadsheet_xlsx2csv(
if columns:
read_options["columns"] = columns

cast_to_boolean = []
if schema_overrides:
for col, dtype in schema_overrides.items():
if dtype == Boolean:
schema_overrides[col] = UInt8 # type: ignore[index]
cast_to_boolean.append(F.col(col).cast(Boolean))

df = _csv_buffer_to_frame(
csv_buffer,
separator=",",
read_options=read_options,
schema_overrides=schema_overrides,
raise_if_empty=raise_if_empty,
)
if cast_to_boolean:
df = df.with_columns(*cast_to_boolean)

return _reorder_columns(df, columns)
28 changes: 12 additions & 16 deletions py-polars/tests/unit/io/test_spreadsheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from tests.unit.conftest import FLOAT_DTYPES, NUMERIC_DTYPES

if TYPE_CHECKING:
from polars._typing import ExcelSpreadsheetEngine, SchemaDict, SelectorType
from polars._typing import ExcelSpreadsheetEngine, SelectorType

pytestmark = pytest.mark.slow()

Expand Down Expand Up @@ -209,39 +209,35 @@ def test_read_excel_all_sheets(


@pytest.mark.parametrize(
("engine", "schema_overrides"),
[
("xlsx2csv", {"datetime": pl.Datetime}),
("calamine", None),
("openpyxl", None),
],
"engine",
["xlsx2csv", "calamine", "openpyxl"],
)
def test_read_excel_basic_datatypes(
engine: ExcelSpreadsheetEngine,
schema_overrides: SchemaDict | None,
) -> None:
def test_read_excel_basic_datatypes(engine: ExcelSpreadsheetEngine) -> None:
df = pl.DataFrame(
{
"A": [1, 2, 3, 4, 5],
"fruits": ["banana", "banana", "apple", "apple", "banana"],
"floats": [1.1, 1.2, 1.3, 1.4, 1.5],
"datetime": [datetime(2023, 1, x) for x in range(1, 6)],
"nulls": [1, None, None, None, 1],
}
"nulls": [1, None, None, None, 0],
},
)
xls = BytesIO()
df.write_excel(xls, position="C5")

# check if can be read as it was written
schema_overrides = {"datetime": pl.Datetime, "nulls": pl.Boolean}
df_compare = df.with_columns(
pl.col(nm).cast(tp) for nm, tp in schema_overrides.items()
)
for sheet_id, sheet_name in ((None, None), (1, None), (None, "Sheet1")):
df = pl.read_excel(
df_from_excel = pl.read_excel(
xls,
sheet_id=sheet_id,
sheet_name=sheet_name,
engine=engine,
schema_overrides=schema_overrides,
)
assert_frame_equal(df, df)
assert_frame_equal(df_compare, df_from_excel)

# check some additional overrides
# (note: xlsx2csv can't currently convert datetime with trailing '00:00:00' to date)
Expand Down

0 comments on commit 447146d

Please sign in to comment.