Skip to content

Commit

Permalink
Register Scalar/MessageMapContainerTypes as virtual subclasses of
Browse files Browse the repository at this point in the history
MutableMapping instead of inheriting directly in python cpp extension.

This prevents these from using abc.ABCMeta metaclass to avoid deprecation warning:
DeprecationWarning: Type google.protobuf.internal.cpp._message.MessageMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14.
PiperOrigin-RevId: 734302096
  • Loading branch information
anandolee authored and copybara-github committed Mar 6, 2025
1 parent cdbde21 commit 9a0b591
Showing 1 changed file with 51 additions and 27 deletions.
78 changes: 51 additions & 27 deletions python/google/protobuf/pyext/map_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

#include "google/protobuf/pyext/map_container.h"

#include <Python.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
Expand All @@ -30,7 +33,7 @@ namespace python {
class MapReflectionFriend {
public:
// Methods that are in common between the map types.
static PyObject* Contains(PyObject* _self, PyObject* key);
static int Contains(PyObject* _self, PyObject* key);
static Py_ssize_t Length(PyObject* _self);
static PyObject* GetIterator(PyObject* _self);
static PyObject* IterNext(PyObject* _self);
Expand Down Expand Up @@ -328,7 +331,7 @@ PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
Py_RETURN_NONE;
}

PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
int MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
MapContainer* self = GetMap(_self);

const Message* message = self->parent->message;
Expand All @@ -337,14 +340,14 @@ PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
MapKey map_key;

if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
return nullptr;
return -1;
}

if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
map_key)) {
Py_RETURN_TRUE;
return 1;
} else {
Py_RETURN_FALSE;
return 0;
}
}

Expand Down Expand Up @@ -450,11 +453,11 @@ static PyObject* ScalarMapSetdefault(PyObject* self, PyObject* args) {
return nullptr;
}

ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
if (is_present == nullptr) {
int is_present = MapReflectionFriend::Contains(self, key);
if (is_present < 0) {
return nullptr;
}
if (PyObject_IsTrue(is_present.get())) {
if (is_present) {
return MapReflectionFriend::ScalarMapGetItem(self, key);
}

Expand All @@ -476,12 +479,12 @@ static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
return nullptr;
}

ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
if (is_present.get() == nullptr) {
int is_present = MapReflectionFriend::Contains(self, key);
if (is_present < 0) {
return nullptr;
}

if (PyObject_IsTrue(is_present.get())) {
if (is_present) {
return MapReflectionFriend::ScalarMapGetItem(self, key);
} else {
if (default_value != nullptr) {
Expand Down Expand Up @@ -534,8 +537,6 @@ static void ScalarMapDealloc(PyObject* _self) {
}

static PyMethodDef ScalarMapMethods[] = {
{"__contains__", MapReflectionFriend::Contains, METH_O,
"Tests whether a key is a member of the map."},
{"clear", (PyCFunction)Clear, METH_NOARGS,
"Removes all elements from the map."},
{"setdefault", (PyCFunction)ScalarMapSetdefault, METH_VARARGS,
Expand All @@ -561,6 +562,7 @@ static PyType_Slot ScalarMapContainer_Type_slots[] = {
{Py_mp_length, (void*)MapReflectionFriend::Length},
{Py_mp_subscript, (void*)MapReflectionFriend::ScalarMapGetItem},
{Py_mp_ass_subscript, (void*)MapReflectionFriend::ScalarMapSetItem},
{Py_sq_contains, (void*)MapReflectionFriend::Contains},
{Py_tp_methods, (void*)ScalarMapMethods},
{Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
{Py_tp_repr, (void*)MapReflectionFriend::ScalarMapToStr},
Expand Down Expand Up @@ -727,12 +729,12 @@ PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
return nullptr;
}

ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
if (is_present.get() == nullptr) {
int is_present = MapReflectionFriend::Contains(self, key);
if (is_present < 0) {
return nullptr;
}

if (PyObject_IsTrue(is_present.get())) {
if (is_present) {
return MapReflectionFriend::MessageMapGetItem(self, key);
} else {
if (default_value != nullptr) {
Expand All @@ -757,8 +759,6 @@ static void MessageMapDealloc(PyObject* _self) {
}

static PyMethodDef MessageMapMethods[] = {
{"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
"Tests whether the map contains this element."},
{"clear", (PyCFunction)Clear, METH_NOARGS,
"Removes all elements from the map."},
{"setdefault", (PyCFunction)MessageMapSetdefault, METH_VARARGS,
Expand Down Expand Up @@ -786,6 +786,7 @@ static PyType_Slot MessageMapContainer_Type_slots[] = {
{Py_mp_length, (void*)MapReflectionFriend::Length},
{Py_mp_subscript, (void*)MapReflectionFriend::MessageMapGetItem},
{Py_mp_ass_subscript, (void*)MapReflectionFriend::MessageMapSetItem},
{Py_sq_contains, (void*)MapReflectionFriend::Contains},
{Py_tp_methods, (void*)MessageMapMethods},
{Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
{Py_tp_repr, (void*)MapReflectionFriend::MessageMapToStr},
Expand Down Expand Up @@ -910,6 +911,30 @@ PyTypeObject MapIterator_Type = {
nullptr, // tp_init
};

PyTypeObject* Py_AddClassWithRegister(PyType_Spec* spec, PyObject* virtual_base,
const char** methods) {
PyObject* type = PyType_FromSpec(spec);
PyObject* ret1 = PyObject_CallMethod(virtual_base, "register", "O", type);
if (!ret1) {
Py_XDECREF(type);
return nullptr;
}
for (size_t i = 0; methods[i] != nullptr; i++) {
PyObject* method = PyObject_GetAttrString(virtual_base, methods[i]);
if (!method) {
Py_XDECREF(type);
return nullptr;
}
int ret2 = PyObject_SetAttrString(type, methods[i], method);
if (ret2 < 0) {
Py_XDECREF(type);
return nullptr;
}
}

return (PyTypeObject*)type;
}

bool InitMapContainers() {
// ScalarMapContainer_Type derives from our MutableMapping type.
ScopedPyObjectPtr abc(PyImport_ImportModule("collections.abc"));
Expand All @@ -923,21 +948,20 @@ bool InitMapContainers() {
return false;
}

Py_INCREF(mutable_mapping.get());
ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
if (bases == nullptr) {
return false;
}
const char* methods[] = {"keys", "items", "values", "__eq__", "__ne__",
"pop", "popitem", "update", nullptr};

ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
ScalarMapContainer_Type =
reinterpret_cast<PyTypeObject*>(Py_AddClassWithRegister(
&ScalarMapContainer_Type_spec, mutable_mapping.get(), methods));

if (PyType_Ready(&MapIterator_Type) < 0) {
return false;
}

MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
MessageMapContainer_Type =
reinterpret_cast<PyTypeObject*>(Py_AddClassWithRegister(
&MessageMapContainer_Type_spec, mutable_mapping.get(), methods));
return true;
}

Expand Down

0 comments on commit 9a0b591

Please sign in to comment.