Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Reduce scan_csv() (and friends') memory usage when using BytesIO #20649

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
24 changes: 16 additions & 8 deletions crates/polars-python/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,14 +322,20 @@ pub fn get_python_scan_source_input(
write: bool,
) -> PyResult<PythonScanSourceInput> {
Python::with_gil(|py| {
let py_f_0 = py_f;
let py_f = py_f_0.clone_ref(py).into_bound(py);
let py_f = py_f.into_bound(py);

// CPython has some internal tricks that means much of the time
// BytesIO.getvalue() involves no memory copying, unlike
// BytesIO.read(). So we want to handle BytesIO specially in order
// to save memory.
let py_f = read_if_bytesio(py_f);

// If the pyobject is a `bytes` class
if let Ok(b) = py_f.downcast::<PyBytes>() {
return Ok(PythonScanSourceInput::Buffer(MemSlice::from_arc(
b.as_bytes(),
Arc::new(py_f_0),
// We want to specifically keep alive the PyBytes object.
Arc::new(b.clone().unbind()),
)));
}

Expand Down Expand Up @@ -373,15 +379,17 @@ pub fn get_file_like(f: PyObject, truncate: bool) -> PyResult<Box<dyn FileLike>>
Ok(get_either_file(f, truncate)?.into_dyn())
}

/// If the give file-like is a BytesIO, read its contents.
/// If the give file-like is a BytesIO, read its contents in a memory-efficient
/// way.
fn read_if_bytesio(py_f: Bound<PyAny>) -> Bound<PyAny> {
if py_f.getattr("read").is_ok() {
let bytes_io = py_f.py().import("io").unwrap().getattr("BytesIO").unwrap();
if py_f.is_instance(&bytes_io).unwrap() {
// Note that BytesIO has some memory optimizations ensuring that much of
// the time getvalue() doesn't need to copy the underlying data:
let Ok(bytes) = py_f.call_method0("getvalue") else {
return py_f;
};
if bytes.downcast::<PyBytes>().is_ok() || bytes.downcast::<PyString>().is_ok() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am fairly certain the StringIO path wasn't hit in existing usage of this function, and it's wrong because the caller expects strings to be paths... and doesn't match the function name, either.

return bytes.clone();
}
return bytes;
}
py_f
}
Expand Down
1 change: 1 addition & 0 deletions py-polars/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ unsafe impl<A: GlobalAlloc> GlobalAlloc for TracemallocAllocator<A> {
}

unsafe fn realloc(&self, ptr: *mut u8, layout: std::alloc::Layout, new_size: usize) -> *mut u8 {
PyTraceMalloc_Untrack(TRACEMALLOC_DOMAIN, ptr as uintptr_t);
let result = self.wrapped_alloc.realloc(ptr, layout, new_size);
PyTraceMalloc_Track(TRACEMALLOC_DOMAIN, result as uintptr_t, new_size);
result
Expand Down
11 changes: 10 additions & 1 deletion py-polars/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import random
import string
import sys
import time
import tracemalloc
from typing import TYPE_CHECKING, Any, cast

Expand Down Expand Up @@ -205,7 +206,11 @@ def get_peak(self) -> int:
return tracemalloc.get_traced_memory()[1]


@pytest.fixture
# The bizarre syntax is from
# https://github.com/pytest-dev/pytest/issues/1368#issuecomment-2344450259 - we
# need to mark any test using this fixture as slow because we have a sleep
# added to work around a CPython bug, see the end of the function.
@pytest.fixture(params=[pytest.param(0, marks=pytest.mark.slow)])
def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
"""
Provide an API for measuring peak memory usage.
Expand All @@ -231,6 +236,10 @@ def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
try:
yield MemoryUsage()
finally:
# Workaround for https://github.com/python/cpython/issues/128679
time.sleep(1)
gc.collect()

tracemalloc.stop()


Expand Down
28 changes: 28 additions & 0 deletions py-polars/tests/unit/io/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

if TYPE_CHECKING:
from polars._typing import SchemaDict
from tests.unit.conftest import MemoryUsage


@dataclass
Expand Down Expand Up @@ -929,3 +930,30 @@ def test_predicate_stats_eval_nested_binary() -> None:
),
pl.DataFrame({"x": [2]}),
)


@pytest.mark.slow
@pytest.mark.parametrize("streaming", [True, False])
def test_scan_csv_bytesio_memory_usage(
streaming: bool,
memory_usage_without_pyarrow: MemoryUsage,
) -> None:
memory_usage = memory_usage_without_pyarrow

# Create CSV that is ~70-85 MB in size:
f = io.BytesIO()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we test this with a little bit less data. I assume it is noticeable with less.

df = pl.DataFrame({"mydata": pl.int_range(0, 10_000_000, eager=True)})
df.write_csv(f)
assert 70_000_000 < f.tell() < 85_000_000
f.seek(0, 0)

# A lazy scan shouldn't make a full copy of the data:
starting_memory = memory_usage.get_current()
assert (
pl.scan_csv(f)
.filter(pl.col("mydata") == 9_999_999)
.collect(new_streaming=streaming) # type: ignore[call-overload]
.item()
== 9_999_999
)
assert memory_usage.get_peak() - starting_memory < 10_000_000
Loading