Skip to content

Commit

Permalink
fix(python): Fix handling of TextIOWrapper in write_csv (#17328)
Browse files Browse the repository at this point in the history
  • Loading branch information
ruihe774 authored Jul 5, 2024
1 parent 447cbcf commit 221bea8
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 44 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ concurrency:
env:
RUSTFLAGS: -C debuginfo=0 # Do not produce debug symbols to keep memory usage down
RUST_BACKTRACE: 1
PYTHONUTF8: 1

defaults:
run:
Expand Down
5 changes: 1 addition & 4 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import random
from collections import defaultdict
from collections.abc import Sized
from io import BytesIO, StringIO, TextIOWrapper
from io import BytesIO, StringIO
from operator import itemgetter
from pathlib import Path
from typing import (
Expand All @@ -24,7 +24,6 @@
NoReturn,
Sequence,
TypeVar,
cast,
get_args,
overload,
)
Expand Down Expand Up @@ -2695,8 +2694,6 @@ def write_csv(
should_return_buffer = True
elif isinstance(file, (str, os.PathLike)):
file = normalize_filepath(file)
elif isinstance(file, TextIOWrapper):
file = cast(TextIOWrapper, file.buffer)

self._df.write_csv(
file,
Expand Down
93 changes: 53 additions & 40 deletions py-polars/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,36 +51,34 @@ impl PyFileLikeObject {
Cursor::new(buf)
}

/// Same as `PyFileLikeObject::new`, but validates that the underlying
/// Validates that the underlying
/// python object has the `read`, `write`, and `seek` methods requested by the parameters.
/// Will return a `TypeError` if the object lacks any of the requested methods.
pub fn with_requirements(
object: PyObject,
pub fn ensure_requirements(
object: &Bound<PyAny>,
read: bool,
write: bool,
seek: bool,
) -> PyResult<Self> {
Python::with_gil(|py| {
if read && object.getattr(py, "read").is_err() {
return Err(PyErr::new::<PyTypeError, _>(
"Object does not have a .read() method.",
));
}
) -> PyResult<()> {
if read && object.getattr("read").is_err() {
return Err(PyErr::new::<PyTypeError, _>(
"Object does not have a .read() method.",
));
}

if seek && object.getattr(py, "seek").is_err() {
return Err(PyErr::new::<PyTypeError, _>(
"Object does not have a .seek() method.",
));
}
if seek && object.getattr("seek").is_err() {
return Err(PyErr::new::<PyTypeError, _>(
"Object does not have a .seek() method.",
));
}

if write && object.getattr(py, "write").is_err() {
return Err(PyErr::new::<PyTypeError, _>(
"Object does not have a .write() method.",
));
}
if write && object.getattr("write").is_err() {
return Err(PyErr::new::<PyTypeError, _>(
"Object does not have a .write() method.",
));
}

Ok(PyFileLikeObject::new(object))
})
Ok(())
}
}

Expand Down Expand Up @@ -196,7 +194,7 @@ fn get_either_file_and_path(
write: bool,
) -> PyResult<(EitherRustPythonFile, Option<PathBuf>)> {
Python::with_gil(|py| {
let py_f = py_f.bind(py);
let py_f = py_f.into_bound(py);
if let Ok(s) = py_f.extract::<Cow<str>>() {
let file_path = std::path::Path::new(&*s);
let file_path = resolve_homedir(file_path);
Expand All @@ -208,29 +206,24 @@ fn get_either_file_and_path(
Ok((EitherRustPythonFile::Rust(f), Some(file_path)))
} else {
let io = py.import_bound("io").unwrap();
let is_utf8_encoding = |py_f: &Bound<PyAny>| -> PyResult<bool> {
let encoding = py_f.getattr("encoding")?;
let encoding = encoding.extract::<Cow<str>>()?;
Ok(encoding.eq_ignore_ascii_case("utf-8") || encoding.eq_ignore_ascii_case("utf8"))
};
let flush_file = |py_f: &Bound<PyAny>| -> PyResult<()> {
py_f.getattr("flush")?.call0()?;
Ok(())
};
#[cfg(target_family = "unix")]
if let Some(fd) = ((py_f.is_exact_instance(&io.getattr("FileIO").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedReader").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedWriter").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedRandom").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedRWPair").unwrap())
|| (py_f.is_exact_instance(&io.getattr("TextIOWrapper").unwrap())
&& py_f
.getattr("encoding")
.ok()
.filter(|encoding| match encoding.extract::<Cow<str>>() {
Ok(encoding) => {
encoding.eq_ignore_ascii_case("utf-8")
|| encoding.eq_ignore_ascii_case("utf8")
},
Err(_) => false,
})
.is_some()))
&& (!write
|| py_f
.getattr("flush")
.and_then(|flush| flush.call0())
.is_ok()))
&& is_utf8_encoding(&py_f)?))
&& (!write || flush_file(&py_f).is_ok()))
.then(|| {
py_f.getattr("fileno")
.and_then(|fileno| fileno.call0())
Expand All @@ -256,7 +249,27 @@ fn get_either_file_and_path(
Ensure you pass a path to the file instead of a python file object when possible for best \
performance.");
}
let f = PyFileLikeObject::with_requirements(py_f.to_object(py), !write, write, !write)?;
// Unwrap TextIOWrapper
// Subclasses are accepted as well, to support things like pytest.capture.CaptureIO
let py_f = if py_f
.is_instance(&io.getattr("TextIOWrapper").unwrap())
.unwrap_or_default()
{
if !is_utf8_encoding(&py_f)? {
return Err(PyPolarsErr::from(
polars_err!(InvalidOperation: "file encoding is not UTF-8"),
)
.into());
}
if write {
flush_file(&py_f)?;
}
py_f.getattr("buffer")?
} else {
py_f
};
PyFileLikeObject::ensure_requirements(&py_f, !write, write, !write)?;
let f = PyFileLikeObject::new(py_f.to_object(py));
Ok((EitherRustPythonFile::Py(f), None))
}
})
Expand Down
26 changes: 26 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2221,3 +2221,29 @@ def test_projection_applied_on_file_with_no_rows_16606(tmp_path: Path) -> None:

out = pl.scan_csv(path).select(columns).collect().columns
assert out == columns


@pytest.mark.write_disk()
def test_write_csv_to_dangling_file_17328(
    df_no_lists: pl.DataFrame, tmp_path: Path
) -> None:
    """Write a frame as CSV into a freshly opened text handle that is never closed.

    The test passes as long as `write_csv` does not raise.
    """
    tmp_path.mkdir(exist_ok=True)
    target = tmp_path / "dangling.csv"
    # NOTE(review): the handle is intentionally neither bound to a name nor
    # closed here -- presumably that is the "dangling" scenario from #17328.
    df_no_lists.write_csv(target.open("w"))


@pytest.mark.write_disk()
def test_write_csv_raise_on_non_utf8_17328(
    df_no_lists: pl.DataFrame, tmp_path: Path
) -> None:
    """Writing CSV to a text handle whose encoding is not UTF-8 must raise.

    Expects `InvalidOperationError` with the message
    "file encoding is not UTF-8".
    """
    tmp_path.mkdir(exist_ok=True)
    # Opening with "w" creates the file on disk, so this test carries the
    # `write_disk` marker like the other tests that write under `tmp_path`.
    # The handle is managed by `with` so it is closed even though the write
    # is expected to fail.
    with (tmp_path / "dangling.csv").open("w", encoding="gbk") as f:
        with pytest.raises(InvalidOperationError, match="file encoding is not UTF-8"):
            df_no_lists.write_csv(f)


@pytest.mark.write_disk()
def test_write_csv_appending_17328(tmp_path: Path) -> None:
    """CSV output written to an open text handle lands after earlier writes.

    A comment line is written through the handle first; the frame is then
    written through the same handle, and the file must contain both, in order.
    """
    tmp_path.mkdir(exist_ok=True)
    csv_path = tmp_path / "append.csv"
    frame = pl.DataFrame({"col": ["value"]})

    with csv_path.open("w") as handle:
        handle.write("# test\n")
        frame.write_csv(handle)

    with csv_path.open("r") as handle:
        contents = handle.read()

    assert contents == "# test\ncol\nvalue\n"

0 comments on commit 221bea8

Please sign in to comment.