Skip to content

Commit

Permalink
Improve: Unlock GIL in Str.write_to
Browse files Browse the repository at this point in the history
Closes #105
  • Loading branch information
ashvardanian committed Mar 18, 2024
1 parent ca3a410 commit 0d4af91
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -1272,15 +1272,21 @@ static PyObject *Str_write_to(PyObject *self, PyObject *args, PyObject *kwargs)
}
memcpy(path_buffer, path.start, path.length);

// Unlock the Global Interpreter Lock (GIL) to allow other threads to run
// while the current thread is waiting for the file to be written.
PyThreadState *gil_state = PyEval_SaveThread();
FILE *file_pointer = fopen(path_buffer, "wb");
if (file_pointer == NULL) {
PyEval_RestoreThread(gil_state);
PyErr_SetFromErrnoWithFilename(PyExc_OSError, path_buffer);
free(path_buffer);
PyEval_RestoreThread(gil_state);
return NULL;
}

setbuf(file_pointer, NULL); // Set the stream to unbuffered
int status = fwrite(text.start, 1, text.length, file_pointer);
PyEval_RestoreThread(gil_state);
if (status != text.length) {
PyErr_SetFromErrnoWithFilename(PyExc_OSError, path_buffer);
free(path_buffer);
Expand Down
21 changes: 21 additions & 0 deletions scripts/test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from random import choice, randint
from string import ascii_lowercase
from typing import Optional
import tempfile
import os

import pytest

Expand Down Expand Up @@ -104,6 +106,25 @@ def test_unit_buffer_protocol():
assert "".join([c.decode("utf-8") for c in arr.tolist()]) == "hello"


def test_str_write_to():
native = "line1\nline2\nline3"
big = Str(native) # Assuming Str is your custom class

# Create a temporary file
with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
temp_filename = tmpfile.name # Store the name for later use

try:
big.write_to(temp_filename)
with open(temp_filename, "r") as file:
content = file.read()
assert (
content == native
), "The content of the file does not match the expected output"
finally:
os.remove(temp_filename)


def test_unit_split():
native = "line1\nline2\nline3"
big = Str(native)
Expand Down

0 comments on commit 0d4af91

Please sign in to comment.