Skip to content

Commit

Permalink
fix(python): Consistent behaviour when "infer_schema_length=0" for `read_excel` (#16840)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Jun 10, 2024
1 parent 92af769 commit 4a45a9d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
11 changes: 9 additions & 2 deletions py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def _read_spreadsheet_openpyxl(
) -> pl.DataFrame:
"""Use the 'openpyxl' library to read data from the given worksheet."""
infer_schema_length = read_options.pop("infer_schema_length", None)
no_inference = infer_schema_length == 0
ws = parser[sheet_name]

# prefer detection of actual table objects; otherwise read
Expand All @@ -766,11 +767,12 @@ def _read_spreadsheet_openpyxl(
header.extend(row_values)
break

dtype = String if no_inference else None
series_data = []
for name, column_data in zip(header, zip(*rows_iter)):
if name:
values = [cell.value for cell in column_data]
if (dtype := (schema_overrides or {}).get(name)) == String:
if no_inference or (dtype := (schema_overrides or {}).get(name)) == String: # type: ignore[assignment]
# note: if we init series with mixed-type data (eg: str/int)
# the non-strings will become null, so we handle the cast here
values = [str(v) if (v is not None) else v for v in values]
Expand Down Expand Up @@ -803,7 +805,11 @@ def _read_spreadsheet_calamine(
msg = f"a more recent version of `fastexcel` is required (>= 0.9; found {fastexcel.__version__})"
raise ModuleUpgradeRequired(msg)

if (schema_overrides := (schema_overrides or {})) and fastexcel_version >= (0, 10):
schema_overrides = schema_overrides or {}
if read_options.get("schema_sample_rows") == 0:
# ref: https://github.com/ToucanToco/fastexcel/issues/236
read_options["dtypes"] = {idx: "string" for idx in range(16384)}
elif schema_overrides and fastexcel_version >= (0, 10):
parser_dtypes = read_options.get("dtypes", {})
for name, dtype in schema_overrides.items():
if name not in parser_dtypes:
Expand All @@ -821,6 +827,7 @@ def _read_spreadsheet_calamine(
parser_dtypes[name] = "duration"
elif base_dtype == Boolean:
parser_dtypes[name] = "bool"

read_options["dtypes"] = parser_dtypes

ws = parser.load_sheet_by_name(name=sheet_name, **read_options)
Expand Down
7 changes: 6 additions & 1 deletion py-polars/tests/unit/io/test_spreadsheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,14 +358,19 @@ def test_read_mixed_dtype_columns(

@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"])
def test_write_excel_bytes(engine: ExcelSpreadsheetEngine) -> None:
df = pl.DataFrame({"A": [1.5, -2, 0, 3.0, -4.5, 5.0]})
df = pl.DataFrame({"colx": [1.5, -2, 0], "coly": ["a", None, "c"]})

excel_bytes = BytesIO()
df.write_excel(excel_bytes)

df_read = pl.read_excel(excel_bytes, engine=engine)
assert_frame_equal(df, df_read)

# also confirm consistent behaviour when 'infer_schema_length=0'
df_read = pl.read_excel(excel_bytes, engine=engine, infer_schema_length=0)
expected = pl.DataFrame({"colx": ["1.5", "-2", "0"], "coly": ["a", None, "c"]})
assert_frame_equal(expected, df_read)


def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> None:
df1 = pl.read_excel(
Expand Down

0 comments on commit 4a45a9d

Please sign in to comment.