diff --git a/crates/polars-python/src/file.rs b/crates/polars-python/src/file.rs
index 741c5c695152..e48111077619 100644
--- a/crates/polars-python/src/file.rs
+++ b/crates/polars-python/src/file.rs
@@ -322,14 +322,20 @@ pub fn get_python_scan_source_input(
     write: bool,
 ) -> PyResult<PythonScanSourceInput> {
     Python::with_gil(|py| {
-        let py_f_0 = py_f;
-        let py_f = py_f_0.clone_ref(py).into_bound(py);
+        let py_f = py_f.into_bound(py);
+
+        // CPython has some internal tricks that mean much of the time
+        // BytesIO.getvalue() involves no memory copying, unlike
+        // BytesIO.read(). So we want to handle BytesIO specially in order
+        // to save memory.
+        let py_f = read_if_bytesio(py_f);
 
         // If the pyobject is a `bytes` class
         if let Ok(b) = py_f.downcast::<PyBytes>() {
             return Ok(PythonScanSourceInput::Buffer(MemSlice::from_arc(
                 b.as_bytes(),
-                Arc::new(py_f_0),
+                // We want to specifically keep alive the PyBytes object.
+                Arc::new(b.clone().unbind()),
             )));
         }
 
@@ -373,15 +379,17 @@ pub fn get_file_like(f: PyObject, truncate: bool) -> PyResult<Box<dyn FileLike>>
     Ok(get_either_file(f, truncate)?.into_dyn())
 }
 
-/// If the give file-like is a BytesIO, read its contents.
+/// If the given file-like is a BytesIO, read its contents in a memory-efficient
+/// way.
 fn read_if_bytesio(py_f: Bound<PyAny>) -> Bound<PyAny> {
-    if py_f.getattr("read").is_ok() {
+    let bytes_io = py_f.py().import("io").unwrap().getattr("BytesIO").unwrap();
+    if py_f.is_instance(&bytes_io).unwrap() {
+        // Note that BytesIO has some memory optimizations ensuring that much of
+        // the time getvalue() doesn't need to copy the underlying data:
        let Ok(bytes) = py_f.call_method0("getvalue") else {
             return py_f;
         };
-        if bytes.downcast::<PyBytes>().is_ok() || bytes.downcast::<PyByteArray>().is_ok() {
-            return bytes.clone();
-        }
+        return bytes;
     }
     py_f
 }
diff --git a/py-polars/src/memory.rs b/py-polars/src/memory.rs
index e7abf2b7d51c..70b0ceb8ac8c 100644
--- a/py-polars/src/memory.rs
+++ b/py-polars/src/memory.rs
@@ -69,6 +69,7 @@ unsafe impl GlobalAlloc for TracemallocAllocator {
     }
 
     unsafe fn realloc(&self, ptr: *mut u8, layout: std::alloc::Layout, new_size: usize) -> *mut u8 {
+        PyTraceMalloc_Untrack(TRACEMALLOC_DOMAIN, ptr as uintptr_t);
         let result = self.wrapped_alloc.realloc(ptr, layout, new_size);
         PyTraceMalloc_Track(TRACEMALLOC_DOMAIN, result as uintptr_t, new_size);
         result
diff --git a/py-polars/tests/unit/conftest.py b/py-polars/tests/unit/conftest.py
index 8197872831b5..11af90824728 100644
--- a/py-polars/tests/unit/conftest.py
+++ b/py-polars/tests/unit/conftest.py
@@ -5,6 +5,7 @@
 import random
 import string
 import sys
+import time
 import tracemalloc
 from typing import TYPE_CHECKING, Any, cast
 
@@ -205,7 +206,11 @@ def get_peak(self) -> int:
         return tracemalloc.get_traced_memory()[1]
 
 
-@pytest.fixture
+# The bizarre syntax is from
+# https://github.com/pytest-dev/pytest/issues/1368#issuecomment-2344450259 - we
+# need to mark any test using this fixture as slow, because we add a sleep to
+# work around a CPython bug; see the end of the function.
+@pytest.fixture(params=[pytest.param(0, marks=pytest.mark.slow)])
 def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
     """
     Provide an API for measuring peak memory usage.
@@ -231,6 +236,10 @@ def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
     try:
         yield MemoryUsage()
     finally:
+        # Workaround for https://github.com/python/cpython/issues/128679
+        time.sleep(1)
+        gc.collect()
+
         tracemalloc.stop()
diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py
index 569c27260513..4f19698488f2 100644
--- a/py-polars/tests/unit/io/test_scan.py
+++ b/py-polars/tests/unit/io/test_scan.py
@@ -16,6 +16,7 @@
 
 if TYPE_CHECKING:
     from polars._typing import SchemaDict
+    from tests.unit.conftest import MemoryUsage
 
 
 @dataclass
@@ -929,3 +930,30 @@ def test_predicate_stats_eval_nested_binary() -> None:
             ),
             pl.DataFrame({"x": [2]}),
         )
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("streaming", [True, False])
+def test_scan_csv_bytesio_memory_usage(
+    streaming: bool,
+    memory_usage_without_pyarrow: MemoryUsage,
+) -> None:
+    memory_usage = memory_usage_without_pyarrow
+
+    # Create a CSV that is ~70-85 MB in size:
+    f = io.BytesIO()
+    df = pl.DataFrame({"mydata": pl.int_range(0, 10_000_000, eager=True)})
+    df.write_csv(f)
+    assert 70_000_000 < f.tell() < 85_000_000
+    f.seek(0, 0)
+
+    # A lazy scan shouldn't make a full copy of the data:
+    starting_memory = memory_usage.get_current()
+    assert (
+        pl.scan_csv(f)
+        .filter(pl.col("mydata") == 9_999_999)
+        .collect(new_streaming=streaming)  # type: ignore[call-overload]
+        .item()
+        == 9_999_999
+    )
+    assert memory_usage.get_peak() - starting_memory < 10_000_000
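
Note on the BytesIO.getvalue() comment in file.rs: the zero-copy behaviour is observable from Python with tracemalloc. The following is a minimal sketch, not part of the patch; the 100 MB payload size is illustrative and the exact savings depend on the CPython version (CPython shares the initial buffer copy-on-write since 3.5):

```python
import io
import tracemalloc

# Allocate the payload before tracing starts so only the BytesIO
# operations below show up in the tracemalloc statistics.
payload = b"x" * 100_000_000  # ~100 MB

tracemalloc.start()

f = io.BytesIO(payload)  # shares the initial buffer copy-on-write
value = f.getvalue()     # typically returns the shared buffer: no copy

current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

# On CPython versions with this optimization, peak stays far below
# the ~100 MB that a full copy of the buffer would require.
print(f"current={current:,} peak={peak:,}")
```

By contrast, consuming a BytesIO through the generic file-like interface can materialize fresh bytes objects, which is why the patch special-cases BytesIO rather than falling back to the read path.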
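
Note on the PyTraceMalloc_Untrack call in memory.rs: if the old block is not untracked, tracemalloc keeps counting it even after realloc has freed it, so a sequence of growing reallocs looks like a leak and inflates the peak that the tests measure. A toy Python model of the bookkeeping, purely illustrative and not the CPython C API:

```python
# Toy model of tracemalloc's pointer -> size table during a realloc
# that moves a block from address 0x1000 to 0x2000.
tracked: dict[int, int] = {}

def track(ptr: int, size: int) -> None:
    tracked[ptr] = size

def untrack(ptr: int) -> None:
    tracked.pop(ptr, None)

# Buggy accounting: only the new block is tracked, so the stale
# entry for the freed block at 0x1000 remains in the table.
track(0x1000, 4096)
track(0x2000, 8192)
assert sum(tracked.values()) == 4096 + 8192  # over-counts by 4 KiB

# Fixed accounting: untrack the old pointer first, as the patch does.
tracked.clear()
track(0x1000, 4096)
untrack(0x1000)
track(0x2000, 8192)
assert sum(tracked.values()) == 8192
```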
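
Note on the fixture change in conftest.py: a pytest fixture cannot carry a mark directly, but a mark attached to its single parameter is applied to every test that requests the fixture, which is what the `params=[pytest.param(0, marks=pytest.mark.slow)]` syntax achieves. A minimal standalone sketch (the fixture and test names here are hypothetical, and the `slow` marker must be registered in the project's pytest configuration, as it is in polars):

```python
import pytest

# The single throwaway parameter exists only to attach the mark;
# every test requesting this fixture is automatically marked slow.
@pytest.fixture(params=[pytest.param(0, marks=pytest.mark.slow)])
def slow_resource() -> int:
    return 0

def test_something(slow_resource: int) -> None:
    # Deselected by `pytest -m "not slow"` and selected by `pytest -m slow`,
    # without the test having to repeat the mark itself.
    assert slow_resource == 0
```

One side effect is that the parameter value shows up in test IDs (e.g. `test_something[0]`), which is the "bizarre syntax" trade-off the conftest.py comment refers to.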