Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-104223: Fix issues with inheriting from buffer classes #104227

Merged
merged 14 commits into from
May 8, 2023
Merged
1 change: 1 addition & 0 deletions Include/cpython/memoryobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ typedef struct {
#define _Py_MEMORYVIEW_FORTRAN 0x004 /* Fortran contiguous layout */
#define _Py_MEMORYVIEW_SCALAR 0x008 /* scalar: ndim = 0 */
#define _Py_MEMORYVIEW_PIL 0x010 /* PIL-style layout */
#define _Py_MEMORYVIEW_RESTRICTED 0x020 /* Disallow additional references */
kumaraditya303 marked this conversation as resolved.
Show resolved Hide resolved

typedef struct {
PyObject_VAR_HEAD
Expand Down
3 changes: 2 additions & 1 deletion Include/internal/pycore_memoryobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ extern "C" {
#endif

PyObject *
PyMemoryView_FromObjectAndFlags(PyObject *v, int flags);
_PyMemoryView_FromBufferProc(PyObject *v, int flags,
getbufferproc bufferproc);

#ifdef __cplusplus
}
Expand Down
111 changes: 111 additions & 0 deletions Lib/test/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4579,6 +4579,117 @@ def test_c_buffer(self):
buf.__release_buffer__(mv)
self.assertEqual(buf.references, 0)

def test_inheritance(self):
class A(bytearray):
def __buffer__(self, flags):
return super().__buffer__(flags)

a = A(b"hello")
mv = memoryview(a)
self.assertEqual(mv.tobytes(), b"hello")

def test_inheritance_releasebuffer(self):
rb_call_count = 0
class B(bytearray):
def __buffer__(self, flags):
return super().__buffer__(flags)
def __release_buffer__(self, view):
nonlocal rb_call_count
rb_call_count += 1
super().__release_buffer__(view)

b = B(b"hello")
with memoryview(b) as mv:
self.assertEqual(mv.tobytes(), b"hello")
self.assertEqual(rb_call_count, 0)
self.assertEqual(rb_call_count, 1)

def test_inherit_but_return_something_else(self):
class A(bytearray):
def __buffer__(self, flags):
return memoryview(b"hello")

a = A(b"hello")
with memoryview(a) as mv:
self.assertEqual(mv.tobytes(), b"hello")

rb_call_count = 0
rb_raised = False
class B(bytearray):
def __buffer__(self, flags):
return memoryview(b"hello")
def __release_buffer__(self, view):
nonlocal rb_call_count
rb_call_count += 1
try:
super().__release_buffer__(view)
except ValueError:
nonlocal rb_raised
rb_raised = True

b = B(b"hello")
with memoryview(b) as mv:
self.assertEqual(mv.tobytes(), b"hello")
self.assertEqual(rb_call_count, 0)
self.assertEqual(rb_call_count, 1)
self.assertIs(rb_raised, True)

def test_override_only_release(self):
class C(bytearray):
def __release_buffer__(self, buffer):
super().__release_buffer__(buffer)

c = C(b"hello")
with memoryview(c) as mv:
self.assertEqual(mv.tobytes(), b"hello")

def test_release_saves_reference(self):
smuggled_buffer = None

class C(bytearray):
def __release_buffer__(s, buffer: memoryview):
with self.assertRaises(ValueError):
memoryview(buffer)
with self.assertRaises(ValueError):
buffer.cast("b")
with self.assertRaises(ValueError):
buffer.toreadonly()
with self.assertRaises(ValueError):
buffer[:1]
with self.assertRaises(ValueError):
buffer.__buffer__(0)
nonlocal smuggled_buffer
smuggled_buffer = buffer
self.assertEqual(buffer.tobytes(), b"hello")
super().__release_buffer__(buffer)

c = C(b"hello")
with memoryview(c) as mv:
self.assertEqual(mv.tobytes(), b"hello")
c.clear()
with self.assertRaises(ValueError):
smuggled_buffer.tobytes()

def test_release_saves_reference_no_subclassing(self):
ba = bytearray(b"hello")

class C:
def __buffer__(self, flags):
return memoryview(ba)

def __release_buffer__(self, buffer):
self.buffer = buffer

c = C()
with memoryview(c) as mv:
self.assertEqual(mv.tobytes(), b"hello")
self.assertEqual(c.buffer.tobytes(), b"hello")

with self.assertRaises(BufferError):
ba.clear()
c.buffer.release()
ba.clear()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test in which the class has multiple inheritance? Also tests for when there are two or classes in mro which support buffer protocol?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding some tests. It's hard to come up with a case where multiple bases support the C buffer protocol, because that will almost inevitably lead to:

Traceback (most recent call last):
  File "/Users/jelle/py/cpython/Lib/test/test_buffer.py", line 4707, in test_two_buffer_bases
    class A(bytearray, bytes):
TypeError: multiple bases have instance lay-out conflict

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, maybe try one var length and one pure python type which implements buffer protocol.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, just pushed. Thanks for the review!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding some tests. It's hard to come up with a case where multiple bases support the C buffer protocol, because that will almost inevitably lead to:

I see, you would have to write a custom type in C whose layout doesn't conflicts for that. Probably an object which has just object header and supports buffer protocol.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH, I agree this is an edge case but I try to be extra careful when touching typeobject, hope you don't mind the extra work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to catch this now then right before 3.12 final!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed the bug I outlined above, I'll have to step out for a few hours and fix it after that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it more, that crash isn't actually due to PEP 688: it reproduces without any Python __buffer__ method. So I don't think there's any more interesting cases to cover for PEP 688.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed #104297 for that case.


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions Objects/bytearrayobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ static void
bytearray_releasebuffer(PyByteArrayObject *obj, Py_buffer *view)
{
obj->ob_exports--;
assert(obj->ob_exports >= 0);
}

static int
Expand Down
45 changes: 44 additions & 1 deletion Objects/memoryobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,20 @@ PyTypeObject _PyManagedBuffer_Type = {
return -1; \
}

#define CHECK_RESTRICTED(mv) \
if (((PyMemoryViewObject *)(mv))->flags & _Py_MEMORYVIEW_RESTRICTED) { \
PyErr_SetString(PyExc_ValueError, \
"cannot create new view on restricted memoryview"); \
return NULL; \
}

#define CHECK_RESTRICTED_INT(mv) \
if (((PyMemoryViewObject *)(mv))->flags & _Py_MEMORYVIEW_RESTRICTED) { \
PyErr_SetString(PyExc_ValueError, \
"cannot create new view on restricted memoryview"); \
return -1; \
}

/* See gh-92888. These macros signal that we need to check the memoryview
again due to possible read after frees. */
#define CHECK_RELEASED_AGAIN(mv) CHECK_RELEASED(mv)
Expand Down Expand Up @@ -781,14 +795,15 @@ PyMemoryView_FromBuffer(const Py_buffer *info)
using the given flags.
If the object is a memoryview, the new memoryview must be registered
with the same managed buffer. Otherwise, a new managed buffer is created. */
PyObject *
static PyObject *
PyMemoryView_FromObjectAndFlags(PyObject *v, int flags)
{
_PyManagedBufferObject *mbuf;

if (PyMemoryView_Check(v)) {
PyMemoryViewObject *mv = (PyMemoryViewObject *)v;
CHECK_RELEASED(mv);
CHECK_RESTRICTED(mv);
return mbuf_add_view(mv->mbuf, &mv->view);
}
else if (PyObject_CheckBuffer(v)) {
Expand All @@ -806,6 +821,30 @@ PyMemoryView_FromObjectAndFlags(PyObject *v, int flags)
Py_TYPE(v)->tp_name);
return NULL;
}

/* Create a memoryview from an object that implements the buffer protocol,
using the given flags.
If the object is a memoryview, the new memoryview must be registered
with the same managed buffer. Otherwise, a new managed buffer is created. */
PyObject *
_PyMemoryView_FromBufferProc(PyObject *v, int flags, getbufferproc bufferproc)
{
_PyManagedBufferObject *mbuf = mbuf_alloc();
if (mbuf == NULL)
return NULL;

int res = bufferproc(v, &mbuf->master, flags);
if (res < 0) {
mbuf->master.obj = NULL;
Py_DECREF(mbuf);
return NULL;
}

PyObject *ret = mbuf_add_view(mbuf, NULL);
Py_DECREF(mbuf);
return ret;
}

/* Create a memoryview from an object that implements the buffer protocol.
If the object is a memoryview, the new memoryview must be registered
with the same managed buffer. Otherwise, a new managed buffer is created. */
Expand Down Expand Up @@ -1397,6 +1436,7 @@ memoryview_cast_impl(PyMemoryViewObject *self, PyObject *format,
Py_ssize_t ndim = 1;

CHECK_RELEASED(self);
CHECK_RESTRICTED(self);

if (!MV_C_CONTIGUOUS(self->flags)) {
PyErr_SetString(PyExc_TypeError,
Expand Down Expand Up @@ -1452,6 +1492,7 @@ memoryview_toreadonly_impl(PyMemoryViewObject *self)
/*[clinic end generated code: output=2c7e056f04c99e62 input=dc06d20f19ba236f]*/
{
CHECK_RELEASED(self);
CHECK_RESTRICTED(self);
/* Even if self is already readonly, we still need to create a new
* object for .release() to work correctly.
*/
Expand All @@ -1474,6 +1515,7 @@ memory_getbuf(PyMemoryViewObject *self, Py_buffer *view, int flags)
int baseflags = self->flags;

CHECK_RELEASED_INT(self);
CHECK_RESTRICTED_INT(self);

/* start with complete information */
*view = *base;
Expand Down Expand Up @@ -2535,6 +2577,7 @@ memory_subscript(PyMemoryViewObject *self, PyObject *key)
return memory_item(self, index);
}
else if (PySlice_Check(key)) {
CHECK_RESTRICTED(self);
PyMemoryViewObject *sliced;

sliced = (PyMemoryViewObject *)mbuf_add_view(self->mbuf, view);
Expand Down
Loading