diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index aa6d937d9cbc31..f8291d8689dd95 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -78,6 +78,11 @@ Python/traceback.c @iritkatriel **/*importlib/resources/* @jaraco @warsaw @FFY00 **/importlib/metadata/* @jaraco @warsaw +# Subinterpreters +Lib/test/support/interpreters/** @ericsnowcurrently +Modules/_xx*interp*module.c @ericsnowcurrently +Lib/test/test_interpreters/** @ericsnowcurrently + # Dates and times **/*datetime* @pganssle @abalkin **/*str*time* @pganssle @abalkin @@ -148,7 +153,15 @@ Doc/c-api/stable.rst @encukou **/*itertools* @rhettinger **/*collections* @rhettinger **/*random* @rhettinger -**/*queue* @rhettinger +Doc/**/*queue* @rhettinger +PCbuild/**/*queue* @rhettinger +Modules/_queuemodule.c @rhettinger +Lib/*queue*.py @rhettinger +Lib/asyncio/*queue*.py @rhettinger +Lib/multiprocessing/*queue*.py @rhettinger +Lib/test/*queue*.py @rhettinger +Lib/test_asyncio/*queue*.py @rhettinger +Lib/test_multiprocessing/*queue*.py @rhettinger **/*bisect* @rhettinger **/*heapq* @rhettinger **/*functools* @rhettinger diff --git a/Include/internal/pycore_crossinterp.h b/Include/internal/pycore_crossinterp.h index 2e6d09a49f95d3..ce95979f8d343b 100644 --- a/Include/internal/pycore_crossinterp.h +++ b/Include/internal/pycore_crossinterp.h @@ -11,6 +11,13 @@ extern "C" { #include "pycore_lock.h" // PyMutex #include "pycore_pyerrors.h" +/**************/ +/* exceptions */ +/**************/ + +PyAPI_DATA(PyObject *) PyExc_InterpreterError; +PyAPI_DATA(PyObject *) PyExc_InterpreterNotFoundError; + /***************************/ /* cross-interpreter calls */ @@ -160,6 +167,9 @@ struct _xi_state { extern PyStatus _PyXI_Init(PyInterpreterState *interp); extern void _PyXI_Fini(PyInterpreterState *interp); +extern PyStatus _PyXI_InitTypes(PyInterpreterState *interp); +extern void _PyXI_FiniTypes(PyInterpreterState *interp); + /***************************/ /* short-term data sharing */ diff --git a/Include/internal/pycore_interp.h b/Include/internal/pycore_interp.h index 2a683196eeced3..04d7a6a615e370 100644 --- a/Include/internal/pycore_interp.h +++ b/Include/internal/pycore_interp.h @@ -250,9 +250,9 @@ _PyInterpreterState_SetFinalizing(PyInterpreterState *interp, PyThreadState *tst // Export for the _xxinterpchannels module. PyAPI_FUNC(PyInterpreterState *) _PyInterpreterState_LookUpID(int64_t); -extern int _PyInterpreterState_IDInitref(PyInterpreterState *); -extern int _PyInterpreterState_IDIncref(PyInterpreterState *); -extern void _PyInterpreterState_IDDecref(PyInterpreterState *); +PyAPI_FUNC(int) _PyInterpreterState_IDInitref(PyInterpreterState *); +PyAPI_FUNC(int) _PyInterpreterState_IDIncref(PyInterpreterState *); +PyAPI_FUNC(void) _PyInterpreterState_IDDecref(PyInterpreterState *); extern const PyConfig* _PyInterpreterState_GetConfig(PyInterpreterState *interp); diff --git a/Lib/test/support/interpreters/__init__.py b/Lib/test/support/interpreters/__init__.py new file mode 100644 index 00000000000000..2d6376deb5907e --- /dev/null +++ b/Lib/test/support/interpreters/__init__.py @@ -0,0 +1,160 @@ +"""Subinterpreters High Level Module.""" + +import threading +import weakref +import _xxsubinterpreters as _interpreters + +# aliases: +from _xxsubinterpreters import ( + InterpreterError, InterpreterNotFoundError, + is_shareable, +) + + +__all__ = [ + 'get_current', 'get_main', 'create', 'list_all', 'is_shareable', + 'Interpreter', + 'InterpreterError', 'InterpreterNotFoundError', 'ExecFailure', + 'create_queue', 'Queue', 'QueueEmpty', 'QueueFull', +] + + +_queuemod = None + +def __getattr__(name): + if name in ('Queue', 'QueueEmpty', 'QueueFull', 'create_queue'): + global create_queue, Queue, QueueEmpty, QueueFull + ns = globals() + from .queues import ( + create as create_queue, + Queue, QueueEmpty, QueueFull, + ) + return ns[name] + else: + raise AttributeError(name) + + +class ExecFailure(RuntimeError): + + def __init__(self, excinfo): + msg = excinfo.formatted + if not msg: + if excinfo.type and snapshot.msg: + msg = f'{snapshot.type.__name__}: {snapshot.msg}' + else: + msg = snapshot.type.__name__ or snapshot.msg + super().__init__(msg) + self.snapshot = excinfo + + +def create(): + """Return a new (idle) Python interpreter.""" + id = _interpreters.create(isolated=True) + return Interpreter(id) + + +def list_all(): + """Return all existing interpreters.""" + return [Interpreter(id) for id in _interpreters.list_all()] + + +def get_current(): + """Return the currently running interpreter.""" + id = _interpreters.get_current() + return Interpreter(id) + + +def get_main(): + """Return the main interpreter.""" + id = _interpreters.get_main() + return Interpreter(id) + + +_known = weakref.WeakValueDictionary() + +class Interpreter: + """A single Python interpreter.""" + + def __new__(cls, id, /): + # There is only one instance for any given ID. + if not isinstance(id, int): + raise TypeError(f'id must be an int, got {id!r}') + id = int(id) + try: + self = _known[id] + assert hasattr(self, '_ownsref') + except KeyError: + # This may raise InterpreterNotFoundError: + _interpreters._incref(id) + try: + self = super().__new__(cls) + self._id = id + self._ownsref = True + except BaseException: + _interpreters._deccref(id) + raise + _known[id] = self + return self + + def __repr__(self): + return f'{type(self).__name__}({self.id})' + + def __hash__(self): + return hash(self._id) + + def __del__(self): + self._decref() + + def _decref(self): + if not self._ownsref: + return + self._ownsref = False + try: + _interpreters._decref(self.id) + except InterpreterNotFoundError: + pass + + @property + def id(self): + return self._id + + def is_running(self): + """Return whether or not the identified interpreter is running.""" + return _interpreters.is_running(self._id) + + def close(self): + """Finalize and destroy the interpreter. + + Attempting to destroy the current interpreter results + in a RuntimeError. + """ + return _interpreters.destroy(self._id) + + def exec_sync(self, code, /, channels=None): + """Run the given source code in the interpreter. + + This is essentially the same as calling the builtin "exec" + with this interpreter, using the __dict__ of its __main__ + module as both globals and locals. + + There is no return value. + + If the code raises an unhandled exception then an ExecFailure + is raised, which summarizes the unhandled exception. The actual + exception is discarded because objects cannot be shared between + interpreters. + + This blocks the current Python thread until done. During + that time, the previous interpreter is allowed to run + in other threads. + """ + excinfo = _interpreters.exec(self._id, code, channels) + if excinfo is not None: + raise ExecFailure(excinfo) + + def run(self, code, /, channels=None): + def task(): + self.exec_sync(code, channels=channels) + t = threading.Thread(target=task) + t.start() + return t diff --git a/Lib/test/support/interpreters.py b/Lib/test/support/interpreters/channels.py similarity index 56% rename from Lib/test/support/interpreters.py rename to Lib/test/support/interpreters/channels.py index 089fe7ef56df78..75a5a60f54f926 100644 --- a/Lib/test/support/interpreters.py +++ b/Lib/test/support/interpreters/channels.py @@ -1,11 +1,9 @@ -"""Subinterpreters High Level Module.""" +"""Cross-interpreter Channels High Level Module.""" import time -import _xxsubinterpreters as _interpreters import _xxinterpchannels as _channels # aliases: -from _xxsubinterpreters import is_shareable from _xxinterpchannels import ( ChannelError, ChannelNotFoundError, ChannelClosedError, ChannelEmptyError, ChannelNotEmptyError, @@ -13,123 +11,13 @@ __all__ = [ - 'Interpreter', 'get_current', 'get_main', 'create', 'list_all', - 'RunFailedError', + 'create', 'list_all', 'SendChannel', 'RecvChannel', - 'create_channel', 'list_all_channels', 'is_shareable', - 'ChannelError', 'ChannelNotFoundError', - 'ChannelEmptyError', - ] + 'ChannelError', 'ChannelNotFoundError', 'ChannelEmptyError', +] -class RunFailedError(RuntimeError): - - def __init__(self, excinfo): - msg = excinfo.formatted - if not msg: - if excinfo.type and snapshot.msg: - msg = f'{snapshot.type.__name__}: {snapshot.msg}' - else: - msg = snapshot.type.__name__ or snapshot.msg - super().__init__(msg) - self.snapshot = excinfo - - -def create(*, isolated=True): - """Return a new (idle) Python interpreter.""" - id = _interpreters.create(isolated=isolated) - return Interpreter(id, isolated=isolated) - - -def list_all(): - """Return all existing interpreters.""" - return [Interpreter(id) for id in _interpreters.list_all()] - - -def get_current(): - """Return the currently running interpreter.""" - id = _interpreters.get_current() - return Interpreter(id) - - -def get_main(): - """Return the main interpreter.""" - id = _interpreters.get_main() - return Interpreter(id) - - -class Interpreter: - """A single Python interpreter.""" - - def __init__(self, id, *, isolated=None): - if not isinstance(id, (int, _interpreters.InterpreterID)): - raise TypeError(f'id must be an int, got {id!r}') - self._id = id - self._isolated = isolated - - def __repr__(self): - data = dict(id=int(self._id), isolated=self._isolated) - kwargs = (f'{k}={v!r}' for k, v in data.items()) - return f'{type(self).__name__}({", ".join(kwargs)})' - - def __hash__(self): - return hash(self._id) - - def __eq__(self, other): - if not isinstance(other, Interpreter): - return NotImplemented - else: - return other._id == self._id - - @property - def id(self): - return self._id - - @property - def isolated(self): - if self._isolated is None: - # XXX The low-level function has not been added yet. - # See bpo-.... - self._isolated = _interpreters.is_isolated(self._id) - return self._isolated - - def is_running(self): - """Return whether or not the identified interpreter is running.""" - return _interpreters.is_running(self._id) - - def close(self): - """Finalize and destroy the interpreter. - - Attempting to destroy the current interpreter results - in a RuntimeError. - """ - return _interpreters.destroy(self._id) - - # XXX Rename "run" to "exec"? - def run(self, src_str, /, channels=None): - """Run the given source code in the interpreter. - - This is essentially the same as calling the builtin "exec" - with this interpreter, using the __dict__ of its __main__ - module as both globals and locals. - - There is no return value. - - If the code raises an unhandled exception then a RunFailedError - is raised, which summarizes the unhandled exception. The actual - exception is discarded because objects cannot be shared between - interpreters. - - This blocks the current Python thread until done. During - that time, the previous interpreter is allowed to run - in other threads. - """ - excinfo = _interpreters.exec(self._id, src_str, channels) - if excinfo is not None: - raise RunFailedError(excinfo) - - -def create_channel(): +def create(): """Return (recv, send) for a new cross-interpreter channel. The channel may be used to pass data safely between interpreters. @@ -139,7 +27,7 @@ def create_channel(): return recv, send -def list_all_channels(): +def list_all(): """Return a list of (recv, send) for all open channels.""" return [(RecvChannel(cid), SendChannel(cid)) for cid in _channels.list_all()] diff --git a/Lib/test/support/interpreters/queues.py b/Lib/test/support/interpreters/queues.py new file mode 100644 index 00000000000000..ed6b0d551dd890 --- /dev/null +++ b/Lib/test/support/interpreters/queues.py @@ -0,0 +1,156 @@ +"""Cross-interpreter Queues High Level Module.""" + +import queue +import time +import weakref +import _xxinterpchannels as _channels +import _xxinterpchannels as _queues + +# aliases: +from _xxinterpchannels import ( + ChannelError as QueueError, + ChannelNotFoundError as QueueNotFoundError, +) + +__all__ = [ + 'create', 'list_all', + 'Queue', + 'QueueError', 'QueueNotFoundError', 'QueueEmpty', 'QueueFull', +] + + +def create(maxsize=0): + """Return a new cross-interpreter queue. + + The queue may be used to pass data safely between interpreters. + """ + # XXX honor maxsize + qid = _queues.create() + return Queue._with_maxsize(qid, maxsize) + + +def list_all(): + """Return a list of all open queues.""" + return [Queue(qid) + for qid in _queues.list_all()] + + +class QueueEmpty(queue.Empty): + """Raised from get_nowait() when the queue is empty. + + It is also raised from get() if it times out. + """ + + +class QueueFull(queue.Full): + """Raised from put_nowait() when the queue is full. + + It is also raised from put() if it times out. + """ + + +_known_queues = weakref.WeakValueDictionary() + +class Queue: + """A cross-interpreter queue.""" + + @classmethod + def _with_maxsize(cls, id, maxsize): + if not isinstance(maxsize, int): + raise TypeError(f'maxsize must be an int, got {maxsize!r}') + elif maxsize < 0: + maxsize = 0 + else: + maxsize = int(maxsize) + self = cls(id) + self._maxsize = maxsize + return self + + def __new__(cls, id, /): + # There is only one instance for any given ID. + if isinstance(id, int): + id = _channels._channel_id(id, force=False) + elif not isinstance(id, _channels.ChannelID): + raise TypeError(f'id must be an int, got {id!r}') + key = int(id) + try: + self = _known_queues[key] + except KeyError: + self = super().__new__(cls) + self._id = id + self._maxsize = 0 + _known_queues[key] = self + return self + + def __repr__(self): + return f'{type(self).__name__}({self.id})' + + def __hash__(self): + return hash(self._id) + + @property + def id(self): + return int(self._id) + + @property + def maxsize(self): + return self._maxsize + + @property + def _info(self): + return _channels.get_info(self._id) + + def empty(self): + return self._info.count == 0 + + def full(self): + if self._maxsize <= 0: + return False + return self._info.count >= self._maxsize + + def qsize(self): + return self._info.count + + def put(self, obj, timeout=None): + # XXX block if full + _channels.send(self._id, obj, blocking=False) + + def put_nowait(self, obj): + # XXX raise QueueFull if full + return _channels.send(self._id, obj, blocking=False) + + def get(self, timeout=None, *, + _sentinel=object(), + _delay=10 / 1000, # 10 milliseconds + ): + """Return the next object from the queue. + + This blocks while the queue is empty. + """ + if timeout is not None: + timeout = int(timeout) + if timeout < 0: + raise ValueError(f'timeout value must be non-negative') + end = time.time() + timeout + obj = _channels.recv(self._id, _sentinel) + while obj is _sentinel: + time.sleep(_delay) + if timeout is not None and time.time() >= end: + raise QueueEmpty + obj = _channels.recv(self._id, _sentinel) + return obj + + def get_nowait(self, *, _sentinel=object()): + """Return the next object from the channel. + + If the queue is empty then raise QueueEmpty. Otherwise this + is the same as get(). + """ + obj = _channels.recv(self._id, _sentinel) + if obj is _sentinel: + raise QueueEmpty + return obj + + +# XXX add this: +#_channels._register_queue_type(Queue) diff --git a/Lib/test/test__xxinterpchannels.py b/Lib/test/test__xxinterpchannels.py index 2b75e2f1916c82..13c8a10296e502 100644 --- a/Lib/test/test__xxinterpchannels.py +++ b/Lib/test/test__xxinterpchannels.py @@ -79,8 +79,7 @@ def __new__(cls, name=None, id=None): name = 'interp' elif name == 'main': raise ValueError('name mismatch (unexpected "main")') - if not isinstance(id, interpreters.InterpreterID): - id = interpreters.InterpreterID(id) + assert isinstance(id, int), repr(id) elif not name or name == 'main': name = 'main' id = main diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index 64a9db95e5eaf5..260ab64b07cb2d 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -15,6 +15,7 @@ interpreters = import_helper.import_module('_xxsubinterpreters') +from _xxsubinterpreters import InterpreterNotFoundError ################################## @@ -266,7 +267,7 @@ def test_main(self): main = interpreters.get_main() cur = interpreters.get_current() self.assertEqual(cur, main) - self.assertIsInstance(cur, interpreters.InterpreterID) + self.assertIsInstance(cur, int) def test_subinterpreter(self): main = interpreters.get_main() @@ -275,7 +276,7 @@ def test_subinterpreter(self): import _xxsubinterpreters as _interpreters cur = _interpreters.get_current() print(cur) - assert isinstance(cur, _interpreters.InterpreterID) + assert isinstance(cur, int) """)) cur = int(out.strip()) _, expected = interpreters.list_all() @@ -289,7 +290,7 @@ def test_from_main(self): [expected] = interpreters.list_all() main = interpreters.get_main() self.assertEqual(main, expected) - self.assertIsInstance(main, interpreters.InterpreterID) + self.assertIsInstance(main, int) def test_from_subinterpreter(self): [expected] = interpreters.list_all() @@ -298,7 +299,7 @@ def test_from_subinterpreter(self): import _xxsubinterpreters as _interpreters main = _interpreters.get_main() print(main) - assert isinstance(main, _interpreters.InterpreterID) + assert isinstance(main, int) """)) main = int(out.strip()) self.assertEqual(main, expected) @@ -333,11 +334,11 @@ def test_from_subinterpreter(self): def test_already_destroyed(self): interp = interpreters.create() interpreters.destroy(interp) - with self.assertRaises(RuntimeError): + with self.assertRaises(InterpreterNotFoundError): interpreters.is_running(interp) def test_does_not_exist(self): - with self.assertRaises(RuntimeError): + with self.assertRaises(InterpreterNotFoundError): interpreters.is_running(1_000_000) def test_bad_id(self): @@ -345,70 +346,11 @@ def test_bad_id(self): interpreters.is_running(-1) -class InterpreterIDTests(TestBase): - - def test_with_int(self): - id = interpreters.InterpreterID(10, force=True) - - self.assertEqual(int(id), 10) - - def test_coerce_id(self): - class Int(str): - def __index__(self): - return 10 - - id = interpreters.InterpreterID(Int(), force=True) - self.assertEqual(int(id), 10) - - def test_bad_id(self): - self.assertRaises(TypeError, interpreters.InterpreterID, object()) - self.assertRaises(TypeError, interpreters.InterpreterID, 10.0) - self.assertRaises(TypeError, interpreters.InterpreterID, '10') - self.assertRaises(TypeError, interpreters.InterpreterID, b'10') - self.assertRaises(ValueError, interpreters.InterpreterID, -1) - self.assertRaises(OverflowError, interpreters.InterpreterID, 2**64) - - def test_does_not_exist(self): - id = interpreters.create() - with self.assertRaises(RuntimeError): - interpreters.InterpreterID(int(id) + 1) # unforced - - def test_str(self): - id = interpreters.InterpreterID(10, force=True) - self.assertEqual(str(id), '10') - - def test_repr(self): - id = interpreters.InterpreterID(10, force=True) - self.assertEqual(repr(id), 'InterpreterID(10)') - - def test_equality(self): - id1 = interpreters.create() - id2 = interpreters.InterpreterID(int(id1)) - id3 = interpreters.create() - - self.assertTrue(id1 == id1) - self.assertTrue(id1 == id2) - self.assertTrue(id1 == int(id1)) - self.assertTrue(int(id1) == id1) - self.assertTrue(id1 == float(int(id1))) - self.assertTrue(float(int(id1)) == id1) - self.assertFalse(id1 == float(int(id1)) + 0.1) - self.assertFalse(id1 == str(int(id1))) - self.assertFalse(id1 == 2**1000) - self.assertFalse(id1 == float('inf')) - self.assertFalse(id1 == 'spam') - self.assertFalse(id1 == id3) - - self.assertFalse(id1 != id1) - self.assertFalse(id1 != id2) - self.assertTrue(id1 != id3) - - class CreateTests(TestBase): def test_in_main(self): id = interpreters.create() - self.assertIsInstance(id, interpreters.InterpreterID) + self.assertIsInstance(id, int) self.assertIn(id, interpreters.list_all()) @@ -444,7 +386,7 @@ def test_in_subinterpreter(self): import _xxsubinterpreters as _interpreters id = _interpreters.create() print(id) - assert isinstance(id, _interpreters.InterpreterID) + assert isinstance(id, int) """)) id2 = int(out.strip()) @@ -536,11 +478,11 @@ def f(): def test_already_destroyed(self): id = interpreters.create() interpreters.destroy(id) - with self.assertRaises(RuntimeError): + with self.assertRaises(InterpreterNotFoundError): interpreters.destroy(id) def test_does_not_exist(self): - with self.assertRaises(RuntimeError): + with self.assertRaises(InterpreterNotFoundError): interpreters.destroy(1_000_000) def test_bad_id(self): @@ -741,7 +683,7 @@ def test_does_not_exist(self): id = 0 while id in interpreters.list_all(): id += 1 - with self.assertRaises(RuntimeError): + with self.assertRaises(InterpreterNotFoundError): interpreters.run_string(id, 'print("spam")') def test_error_id(self): diff --git a/Lib/test/test_capi/test_misc.py b/Lib/test/test_capi/test_misc.py index 3d86ae37190475..e6b532e858c8f9 100644 --- a/Lib/test/test_capi/test_misc.py +++ b/Lib/test/test_capi/test_misc.py @@ -1527,6 +1527,7 @@ def test_isolated_subinterpreter(self): maxtext = 250 main_interpid = 0 interpid = _interpreters.create() + self.addCleanup(lambda: _interpreters.destroy(interpid)) _interpreters.run_string(interpid, f"""if True: import json import os @@ -2020,6 +2021,137 @@ def test_module_state_shared_in_global(self): self.assertEqual(main_attr_id, subinterp_attr_id) +@requires_subinterpreters +class InterpreterIDTests(unittest.TestCase): + + InterpreterID = _testcapi.get_interpreterid_type() + + def new_interpreter(self): + def ensure_destroyed(interpid): + try: + _interpreters.destroy(interpid) + except _interpreters.InterpreterNotFoundError: + pass + id = _interpreters.create() + self.addCleanup(lambda: ensure_destroyed(id)) + return id + + def test_with_int(self): + id = self.InterpreterID(10, force=True) + + self.assertEqual(int(id), 10) + + def test_coerce_id(self): + class Int(str): + def __index__(self): + return 10 + + id = self.InterpreterID(Int(), force=True) + self.assertEqual(int(id), 10) + + def test_bad_id(self): + for badid in [ + object(), + 10.0, + '10', + b'10', + ]: + with self.subTest(badid): + with self.assertRaises(TypeError): + self.InterpreterID(badid) + + badid = -1 + with self.subTest(badid): + with self.assertRaises(ValueError): + self.InterpreterID(badid) + + badid = 2**64 + with self.subTest(badid): + with self.assertRaises(OverflowError): + self.InterpreterID(badid) + + def test_exists(self): + id = self.new_interpreter() + with self.assertRaises(_interpreters.InterpreterNotFoundError): + self.InterpreterID(int(id) + 1) # unforced + + def test_does_not_exist(self): + id = self.new_interpreter() + with self.assertRaises(_interpreters.InterpreterNotFoundError): + self.InterpreterID(int(id) + 1) # unforced + + def test_destroyed(self): + id = _interpreters.create() + _interpreters.destroy(id) + with self.assertRaises(_interpreters.InterpreterNotFoundError): + self.InterpreterID(id) # unforced + + def test_str(self): + id = self.InterpreterID(10, force=True) + self.assertEqual(str(id), '10') + + def test_repr(self): + id = self.InterpreterID(10, force=True) + self.assertEqual(repr(id), 'InterpreterID(10)') + + def test_equality(self): + id1 = self.new_interpreter() + id2 = self.InterpreterID(id1) + id3 = self.InterpreterID( + self.new_interpreter()) + + self.assertTrue(id2 == id2) # identity + self.assertTrue(id2 == id1) # int-equivalent + self.assertTrue(id1 == id2) # reversed + self.assertTrue(id2 == int(id2)) + self.assertTrue(id2 == float(int(id2))) + self.assertTrue(float(int(id2)) == id2) + self.assertFalse(id2 == float(int(id2)) + 0.1) + self.assertFalse(id2 == str(int(id2))) + self.assertFalse(id2 == 2**1000) + self.assertFalse(id2 == float('inf')) + self.assertFalse(id2 == 'spam') + self.assertFalse(id2 == id3) + + self.assertFalse(id2 != id2) + self.assertFalse(id2 != id1) + self.assertFalse(id1 != id2) + self.assertTrue(id2 != id3) + + def test_linked_lifecycle(self): + id1 = _interpreters.create() + _testcapi.unlink_interpreter_refcount(id1) + self.assertEqual( + _testinternalcapi.get_interpreter_refcount(id1), + 0) + + id2 = self.InterpreterID(id1) + self.assertEqual( + _testinternalcapi.get_interpreter_refcount(id1), + 1) + + # The interpreter isn't linked to ID objects, so it isn't destroyed. + del id2 + self.assertEqual( + _testinternalcapi.get_interpreter_refcount(id1), + 0) + + _testcapi.link_interpreter_refcount(id1) + self.assertEqual( + _testinternalcapi.get_interpreter_refcount(id1), + 0) + + id3 = self.InterpreterID(id1) + self.assertEqual( + _testinternalcapi.get_interpreter_refcount(id1), + 1) + + # The interpreter is linked now so is destroyed. + del id3 + with self.assertRaises(_interpreters.InterpreterNotFoundError): + _testinternalcapi.get_interpreter_refcount(id1) + + class BuiltinStaticTypesTests(unittest.TestCase): TYPES = [ diff --git a/Lib/test/test_import/__init__.py b/Lib/test/test_import/__init__.py index bbfbb57b1d8299..48c0a43f29e27f 100644 --- a/Lib/test/test_import/__init__.py +++ b/Lib/test/test_import/__init__.py @@ -1971,6 +1971,7 @@ def test_disallowed_reimport(self): print(_testsinglephase) ''') interpid = _interpreters.create() + self.addCleanup(lambda: _interpreters.destroy(interpid)) excsnap = _interpreters.run_string(interpid, script) self.assertIsNot(excsnap, None) @@ -2105,12 +2106,18 @@ def re_load(self, name, mod): def add_subinterpreter(self): interpid = _interpreters.create(isolated=False) - _interpreters.run_string(interpid, textwrap.dedent(''' + def ensure_destroyed(): + try: + _interpreters.destroy(interpid) + except _interpreters.InterpreterNotFoundError: + pass + self.addCleanup(ensure_destroyed) + _interpreters.exec(interpid, textwrap.dedent(''' import sys import _testinternalcapi ''')) def clean_up(): - _interpreters.run_string(interpid, textwrap.dedent(f''' + _interpreters.exec(interpid, textwrap.dedent(f''' name = {self.NAME!r} if name in sys.modules: sys.modules.pop(name)._clear_globals() diff --git a/Lib/test/test_importlib/test_util.py b/Lib/test/test_importlib/test_util.py index 914176559806f4..fe5e7b31d9c32b 100644 --- a/Lib/test/test_importlib/test_util.py +++ b/Lib/test/test_importlib/test_util.py @@ -657,14 +657,26 @@ class IncompatibleExtensionModuleRestrictionsTests(unittest.TestCase): def run_with_own_gil(self, script): interpid = _interpreters.create(isolated=True) - excsnap = _interpreters.run_string(interpid, script) + def ensure_destroyed(): + try: + _interpreters.destroy(interpid) + except _interpreters.InterpreterNotFoundError: + pass + self.addCleanup(ensure_destroyed) + excsnap = _interpreters.exec(interpid, script) if excsnap is not None: if excsnap.type.__name__ == 'ImportError': raise ImportError(excsnap.msg) def run_with_shared_gil(self, script): interpid = _interpreters.create(isolated=False) - excsnap = _interpreters.run_string(interpid, script) + def ensure_destroyed(): + try: + _interpreters.destroy(interpid) + except _interpreters.InterpreterNotFoundError: + pass + self.addCleanup(ensure_destroyed) + excsnap = _interpreters.exec(interpid, script) if excsnap is not None: if excsnap.type.__name__ == 'ImportError': raise ImportError(excsnap.msg) diff --git a/Lib/test/test_interpreters.py b/Lib/test/test_interpreters.py deleted file mode 100644 index 5663706c0ccfb7..00000000000000 --- a/Lib/test/test_interpreters.py +++ /dev/null @@ -1,1136 +0,0 @@ -import contextlib -import json -import os -import os.path -import sys -import threading -from textwrap import dedent -import unittest -import time - -from test import support -from test.support import import_helper -from test.support import threading_helper -from test.support import os_helper -_interpreters = import_helper.import_module('_xxsubinterpreters') -_channels = import_helper.import_module('_xxinterpchannels') -from test.support import interpreters - - -def _captured_script(script): - r, w = os.pipe() - indented = script.replace('\n', '\n ') - wrapped = dedent(f""" - import contextlib - with open({w}, 'w', encoding='utf-8') as spipe: - with contextlib.redirect_stdout(spipe): - {indented} - """) - return wrapped, open(r, encoding='utf-8') - - -def clean_up_interpreters(): - for interp in interpreters.list_all(): - if interp.id == 0: # main - continue - try: - interp.close() - except RuntimeError: - pass # already destroyed - - -def _run_output(interp, request, channels=None): - script, rpipe = _captured_script(request) - with rpipe: - interp.run(script, channels=channels) - return rpipe.read() - - -@contextlib.contextmanager -def _running(interp): - r, w = os.pipe() - def run(): - interp.run(dedent(f""" - # wait for "signal" - with open({r}) as rpipe: - rpipe.read() - """)) - - t = threading.Thread(target=run) - t.start() - - yield - - with open(w, 'w') as spipe: - spipe.write('done') - t.join() - - -class TestBase(unittest.TestCase): - - def pipe(self): - def ensure_closed(fd): - try: - os.close(fd) - except OSError: - pass - r, w = os.pipe() - self.addCleanup(lambda: ensure_closed(r)) - self.addCleanup(lambda: ensure_closed(w)) - return r, w - - def tearDown(self): - clean_up_interpreters() - - -class CreateTests(TestBase): - - def test_in_main(self): - interp = interpreters.create() - self.assertIsInstance(interp, interpreters.Interpreter) - self.assertIn(interp, interpreters.list_all()) - - def test_in_thread(self): - lock = threading.Lock() - interp = None - def f(): - nonlocal interp - interp = interpreters.create() - lock.acquire() - lock.release() - t = threading.Thread(target=f) - with lock: - t.start() - t.join() - self.assertIn(interp, interpreters.list_all()) - - def test_in_subinterpreter(self): - main, = interpreters.list_all() - interp = interpreters.create() - out = _run_output(interp, dedent(""" - from test.support import interpreters - interp = interpreters.create() - print(interp.id) - """)) - interp2 = interpreters.Interpreter(int(out)) - self.assertEqual(interpreters.list_all(), [main, interp, interp2]) - - def test_after_destroy_all(self): - before = set(interpreters.list_all()) - # Create 3 subinterpreters. - interp_lst = [] - for _ in range(3): - interps = interpreters.create() - interp_lst.append(interps) - # Now destroy them. - for interp in interp_lst: - interp.close() - # Finally, create another. - interp = interpreters.create() - self.assertEqual(set(interpreters.list_all()), before | {interp}) - - def test_after_destroy_some(self): - before = set(interpreters.list_all()) - # Create 3 subinterpreters. - interp1 = interpreters.create() - interp2 = interpreters.create() - interp3 = interpreters.create() - # Now destroy 2 of them. - interp1.close() - interp2.close() - # Finally, create another. - interp = interpreters.create() - self.assertEqual(set(interpreters.list_all()), before | {interp3, interp}) - - -class GetCurrentTests(TestBase): - - def test_main(self): - main = interpreters.get_main() - current = interpreters.get_current() - self.assertEqual(current, main) - - def test_subinterpreter(self): - main = _interpreters.get_main() - interp = interpreters.create() - out = _run_output(interp, dedent(""" - from test.support import interpreters - cur = interpreters.get_current() - print(cur.id) - """)) - current = interpreters.Interpreter(int(out)) - self.assertNotEqual(current, main) - - -class ListAllTests(TestBase): - - def test_initial(self): - interps = interpreters.list_all() - self.assertEqual(1, len(interps)) - - def test_after_creating(self): - main = interpreters.get_current() - first = interpreters.create() - second = interpreters.create() - - ids = [] - for interp in interpreters.list_all(): - ids.append(interp.id) - - self.assertEqual(ids, [main.id, first.id, second.id]) - - def test_after_destroying(self): - main = interpreters.get_current() - first = interpreters.create() - second = interpreters.create() - first.close() - - ids = [] - for interp in interpreters.list_all(): - ids.append(interp.id) - - self.assertEqual(ids, [main.id, second.id]) - - -class TestInterpreterAttrs(TestBase): - - def test_id_type(self): - main = interpreters.get_main() - current = interpreters.get_current() - interp = interpreters.create() - self.assertIsInstance(main.id, _interpreters.InterpreterID) - self.assertIsInstance(current.id, _interpreters.InterpreterID) - self.assertIsInstance(interp.id, _interpreters.InterpreterID) - - def test_main_id(self): - main = interpreters.get_main() - self.assertEqual(main.id, 0) - - def test_custom_id(self): - interp = interpreters.Interpreter(1) - self.assertEqual(interp.id, 1) - - with self.assertRaises(TypeError): - interpreters.Interpreter('1') - - def test_id_readonly(self): - interp = interpreters.Interpreter(1) - with self.assertRaises(AttributeError): - interp.id = 2 - - @unittest.skip('not ready yet (see bpo-32604)') - def test_main_isolated(self): - main = interpreters.get_main() - self.assertFalse(main.isolated) - - @unittest.skip('not ready yet (see bpo-32604)') - def test_subinterpreter_isolated_default(self): - interp = interpreters.create() - self.assertFalse(interp.isolated) - - def test_subinterpreter_isolated_explicit(self): - interp1 = interpreters.create(isolated=True) - interp2 = interpreters.create(isolated=False) - self.assertTrue(interp1.isolated) - self.assertFalse(interp2.isolated) - - @unittest.skip('not ready yet (see bpo-32604)') - def test_custom_isolated_default(self): - interp = interpreters.Interpreter(1) - self.assertFalse(interp.isolated) - - def test_custom_isolated_explicit(self): - interp1 = interpreters.Interpreter(1, isolated=True) - interp2 = interpreters.Interpreter(1, isolated=False) - self.assertTrue(interp1.isolated) - self.assertFalse(interp2.isolated) - - def test_isolated_readonly(self): - interp = interpreters.Interpreter(1) - with self.assertRaises(AttributeError): - interp.isolated = True - - def test_equality(self): - interp1 = interpreters.create() - interp2 = interpreters.create() - self.assertEqual(interp1, interp1) - self.assertNotEqual(interp1, interp2) - - -class TestInterpreterIsRunning(TestBase): - - def test_main(self): - main = interpreters.get_main() - self.assertTrue(main.is_running()) - - @unittest.skip('Fails on FreeBSD') - def test_subinterpreter(self): - interp = interpreters.create() - self.assertFalse(interp.is_running()) - - with _running(interp): - self.assertTrue(interp.is_running()) - self.assertFalse(interp.is_running()) - - def test_finished(self): - r, w = self.pipe() - interp = interpreters.create() - interp.run(f"""if True: - import os - os.write({w}, b'x') - """) - self.assertFalse(interp.is_running()) - self.assertEqual(os.read(r, 1), b'x') - - def test_from_subinterpreter(self): - interp = interpreters.create() - out = _run_output(interp, dedent(f""" - import _xxsubinterpreters as _interpreters - if _interpreters.is_running({interp.id}): - print(True) - else: - print(False) - """)) - self.assertEqual(out.strip(), 'True') - - def test_already_destroyed(self): - interp = interpreters.create() - interp.close() - with self.assertRaises(RuntimeError): - interp.is_running() - - def test_does_not_exist(self): - interp = interpreters.Interpreter(1_000_000) - with self.assertRaises(RuntimeError): - interp.is_running() - - def test_bad_id(self): - interp = interpreters.Interpreter(-1) - with self.assertRaises(ValueError): - interp.is_running() - - def test_with_only_background_threads(self): - r_interp, w_interp = self.pipe() - r_thread, w_thread = self.pipe() - - DONE = b'D' - FINISHED = b'F' - - interp = interpreters.create() - interp.run(f"""if True: - import os - import threading - - def task(): - v = os.read({r_thread}, 1) - assert v == {DONE!r} - os.write({w_interp}, {FINISHED!r}) - t = threading.Thread(target=task) - t.start() - """) - self.assertFalse(interp.is_running()) - - os.write(w_thread, DONE) - interp.run('t.join()') - self.assertEqual(os.read(r_interp, 1), FINISHED) - - -class TestInterpreterClose(TestBase): - - def test_basic(self): - main = interpreters.get_main() - interp1 = interpreters.create() - interp2 = interpreters.create() - interp3 = interpreters.create() - self.assertEqual(set(interpreters.list_all()), - {main, interp1, interp2, interp3}) - interp2.close() - self.assertEqual(set(interpreters.list_all()), - {main, interp1, interp3}) - - def test_all(self): - before = set(interpreters.list_all()) - interps = set() - for _ in range(3): - interp = interpreters.create() - interps.add(interp) - self.assertEqual(set(interpreters.list_all()), before | interps) - for interp in interps: - interp.close() - self.assertEqual(set(interpreters.list_all()), before) - - def test_main(self): - main, = interpreters.list_all() - with self.assertRaises(RuntimeError): - main.close() - - def f(): - with self.assertRaises(RuntimeError): - main.close() - - t = threading.Thread(target=f) - t.start() - t.join() - - def test_already_destroyed(self): - interp = interpreters.create() - interp.close() - with self.assertRaises(RuntimeError): - interp.close() - - def test_does_not_exist(self): - interp = interpreters.Interpreter(1_000_000) - with self.assertRaises(RuntimeError): - interp.close() - - def test_bad_id(self): - interp = interpreters.Interpreter(-1) - with self.assertRaises(ValueError): - interp.close() - - def test_from_current(self): - main, = interpreters.list_all() - interp = interpreters.create() - out = _run_output(interp, dedent(f""" - from test.support import interpreters - interp = interpreters.Interpreter({int(interp.id)}) - try: - interp.close() - except RuntimeError: - print('failed') - """)) - self.assertEqual(out.strip(), 'failed') - self.assertEqual(set(interpreters.list_all()), {main, interp}) - - def test_from_sibling(self): - main, = interpreters.list_all() - interp1 = interpreters.create() - interp2 = interpreters.create() - self.assertEqual(set(interpreters.list_all()), - {main, interp1, interp2}) - interp1.run(dedent(f""" - from test.support import interpreters - interp2 = interpreters.Interpreter(int({interp2.id})) - interp2.close() - interp3 = interpreters.create() - interp3.close() - """)) - self.assertEqual(set(interpreters.list_all()), {main, interp1}) - - def test_from_other_thread(self): - interp = interpreters.create() - def f(): - interp.close() - - t = threading.Thread(target=f) - t.start() - t.join() - - @unittest.skip('Fails on FreeBSD') - def test_still_running(self): - main, = interpreters.list_all() - interp = interpreters.create() - with _running(interp): - with self.assertRaises(RuntimeError): - interp.close() - self.assertTrue(interp.is_running()) - - def test_subthreads_still_running(self): - r_interp, w_interp = self.pipe() - r_thread, w_thread = self.pipe() - - FINISHED = b'F' - - interp = interpreters.create() - interp.run(f"""if True: - import os - import threading - import time - - done = False - - def notify_fini(): - global done - done = True - t.join() - threading._register_atexit(notify_fini) - - def task(): - while not done: - time.sleep(0.1) - os.write({w_interp}, {FINISHED!r}) - t = threading.Thread(target=task) - t.start() - """) - interp.close() - - self.assertEqual(os.read(r_interp, 1), FINISHED) - - -class TestInterpreterRun(TestBase): - - def test_success(self): - interp = interpreters.create() - script, file = _captured_script('print("it worked!", end="")') - with file: - interp.run(script) - out = file.read() - - self.assertEqual(out, 'it worked!') - - def test_failure(self): - interp = interpreters.create() - with self.assertRaises(interpreters.RunFailedError): - interp.run('raise Exception') - - def test_in_thread(self): - interp = interpreters.create() - script, file = _captured_script('print("it worked!", end="")') - with file: - def f(): - interp.run(script) - - t = threading.Thread(target=f) - t.start() - t.join() - out = file.read() - - self.assertEqual(out, 'it worked!') - - @support.requires_fork() - def test_fork(self): - interp = interpreters.create() - import tempfile - with tempfile.NamedTemporaryFile('w+', encoding='utf-8') as file: - file.write('') - file.flush() - - expected = 'spam spam spam spam spam' - script = dedent(f""" - import os - try: - os.fork() - except RuntimeError: - with open('{file.name}', 'w', encoding='utf-8') as out: - out.write('{expected}') - """) - interp.run(script) - - file.seek(0) - content = file.read() - self.assertEqual(content, expected) - - @unittest.skip('Fails on FreeBSD') - def test_already_running(self): - interp = interpreters.create() - with _running(interp): - with self.assertRaises(RuntimeError): - interp.run('print("spam")') - - def test_does_not_exist(self): - interp = interpreters.Interpreter(1_000_000) - with self.assertRaises(RuntimeError): - interp.run('print("spam")') - - def test_bad_id(self): - interp = interpreters.Interpreter(-1) - with self.assertRaises(ValueError): - interp.run('print("spam")') - - def test_bad_script(self): - interp = interpreters.create() - with self.assertRaises(TypeError): - interp.run(10) - - def test_bytes_for_script(self): - interp = interpreters.create() - with self.assertRaises(TypeError): - interp.run(b'print("spam")') - - def test_with_background_threads_still_running(self): - r_interp, w_interp = self.pipe() - r_thread, w_thread = self.pipe() - - RAN = b'R' - DONE = b'D' - FINISHED = b'F' - - interp = interpreters.create() - interp.run(f"""if True: - import os - import threading - - def task(): - v = os.read({r_thread}, 1) - assert v == {DONE!r} - os.write({w_interp}, {FINISHED!r}) - t = threading.Thread(target=task) - t.start() - os.write({w_interp}, {RAN!r}) - """) - interp.run(f"""if True: - os.write({w_interp}, {RAN!r}) - """) - - os.write(w_thread, DONE) - interp.run('t.join()') - self.assertEqual(os.read(r_interp, 1), RAN) - self.assertEqual(os.read(r_interp, 1), RAN) - self.assertEqual(os.read(r_interp, 1), FINISHED) - - # test_xxsubinterpreters covers the remaining Interpreter.run() behavior. - - -class StressTests(TestBase): - - # In these tests we generally want a lot of interpreters, - # but not so many that any test takes too long. - - @support.requires_resource('cpu') - def test_create_many_sequential(self): - alive = [] - for _ in range(100): - interp = interpreters.create() - alive.append(interp) - - @support.requires_resource('cpu') - def test_create_many_threaded(self): - alive = [] - def task(): - interp = interpreters.create() - alive.append(interp) - threads = (threading.Thread(target=task) for _ in range(200)) - with threading_helper.start_threads(threads): - pass - - -class StartupTests(TestBase): - - # We want to ensure the initial state of subinterpreters - # matches expectations. - - _subtest_count = 0 - - @contextlib.contextmanager - def subTest(self, *args): - with super().subTest(*args) as ctx: - self._subtest_count += 1 - try: - yield ctx - finally: - if self._debugged_in_subtest: - if self._subtest_count == 1: - # The first subtest adds a leading newline, so we - # compensate here by not printing a trailing newline. - print('### end subtest debug ###', end='') - else: - print('### end subtest debug ###') - self._debugged_in_subtest = False - - def debug(self, msg, *, header=None): - if header: - self._debug(f'--- {header} ---') - if msg: - if msg.endswith(os.linesep): - self._debug(msg[:-len(os.linesep)]) - else: - self._debug(msg) - self._debug('') - self._debug('------') - else: - self._debug(msg) - - _debugged = False - _debugged_in_subtest = False - def _debug(self, msg): - if not self._debugged: - print() - self._debugged = True - if self._subtest is not None: - if True: - if not self._debugged_in_subtest: - self._debugged_in_subtest = True - print('### start subtest debug ###') - print(msg) - else: - print(msg) - - def create_temp_dir(self): - import tempfile - tmp = tempfile.mkdtemp(prefix='test_interpreters_') - tmp = os.path.realpath(tmp) - self.addCleanup(os_helper.rmtree, tmp) - return tmp - - def write_script(self, *path, text): - filename = os.path.join(*path) - dirname = os.path.dirname(filename) - if dirname: - os.makedirs(dirname, exist_ok=True) - with open(filename, 'w', encoding='utf-8') as outfile: - outfile.write(dedent(text)) - return filename - - @support.requires_subprocess() - def run_python(self, argv, *, cwd=None): - # This method is inspired by - # EmbeddingTestsMixin.run_embedded_interpreter() in test_embed.py. - import shlex - import subprocess - if isinstance(argv, str): - argv = shlex.split(argv) - argv = [sys.executable, *argv] - try: - proc = subprocess.run( - argv, - cwd=cwd, - capture_output=True, - text=True, - ) - except Exception as exc: - self.debug(f'# cmd: {shlex.join(argv)}') - if isinstance(exc, FileNotFoundError) and not exc.filename: - if os.path.exists(argv[0]): - exists = 'exists' - else: - exists = 'does not exist' - self.debug(f'{argv[0]} {exists}') - raise # re-raise - assert proc.stderr == '' or proc.returncode != 0, proc.stderr - if proc.returncode != 0 and support.verbose: - self.debug(f'# python3 {shlex.join(argv[1:])} failed:') - self.debug(proc.stdout, header='stdout') - self.debug(proc.stderr, header='stderr') - self.assertEqual(proc.returncode, 0) - self.assertEqual(proc.stderr, '') - return proc.stdout - - def test_sys_path_0(self): - # The main interpreter's sys.path[0] should be used by subinterpreters. - script = ''' - import sys - from test.support import interpreters - - orig = sys.path[0] - - interp = interpreters.create() - interp.run(f"""if True: - import json - import sys - print(json.dumps({{ - 'main': {orig!r}, - 'sub': sys.path[0], - }}, indent=4), flush=True) - """) - ''' - # / - # pkg/ - # __init__.py - # __main__.py - # script.py - # script.py - cwd = self.create_temp_dir() - self.write_script(cwd, 'pkg', '__init__.py', text='') - self.write_script(cwd, 'pkg', '__main__.py', text=script) - self.write_script(cwd, 'pkg', 'script.py', text=script) - self.write_script(cwd, 'script.py', text=script) - - cases = [ - ('script.py', cwd), - ('-m script', cwd), - ('-m pkg', cwd), - ('-m pkg.script', cwd), - ('-c "import script"', ''), - ] - for argv, expected in cases: - with self.subTest(f'python3 {argv}'): - out = self.run_python(argv, cwd=cwd) - data = json.loads(out) - sp0_main, sp0_sub = data['main'], data['sub'] - self.assertEqual(sp0_sub, sp0_main) - self.assertEqual(sp0_sub, expected) - # XXX Also check them all with the -P cmdline flag? - - -class FinalizationTests(TestBase): - - def test_gh_109793(self): - import subprocess - argv = [sys.executable, '-c', '''if True: - import _xxsubinterpreters as _interpreters - interpid = _interpreters.create() - raise Exception - '''] - proc = subprocess.run(argv, capture_output=True, text=True) - self.assertIn('Traceback', proc.stderr) - if proc.returncode == 0 and support.verbose: - print() - print("--- cmd unexpected succeeded ---") - print(f"stdout:\n{proc.stdout}") - print(f"stderr:\n{proc.stderr}") - print("------") - self.assertEqual(proc.returncode, 1) - - -class TestIsShareable(TestBase): - - def test_default_shareables(self): - shareables = [ - # singletons - None, - # builtin objects - b'spam', - 'spam', - 10, - -10, - True, - False, - 100.0, - (), - (1, ('spam', 'eggs'), True), - ] - for obj in shareables: - with self.subTest(obj): - shareable = interpreters.is_shareable(obj) - self.assertTrue(shareable) - - def test_not_shareable(self): - class Cheese: - def __init__(self, name): - self.name = name - def __str__(self): - return self.name - - class SubBytes(bytes): - """A subclass of a shareable type.""" - - not_shareables = [ - # singletons - NotImplemented, - ..., - # builtin types and objects - type, - object, - object(), - Exception(), - # user-defined types and objects - Cheese, - Cheese('Wensleydale'), - SubBytes(b'spam'), - ] - for obj in not_shareables: - with self.subTest(repr(obj)): - self.assertFalse( - interpreters.is_shareable(obj)) - - -class TestChannels(TestBase): - - def test_create(self): - r, s = interpreters.create_channel() - self.assertIsInstance(r, interpreters.RecvChannel) - self.assertIsInstance(s, interpreters.SendChannel) - - def test_list_all(self): - self.assertEqual(interpreters.list_all_channels(), []) - created = set() - for _ in range(3): - ch = interpreters.create_channel() - created.add(ch) - after = set(interpreters.list_all_channels()) - self.assertEqual(after, created) - - def test_shareable(self): - rch, sch = interpreters.create_channel() - - self.assertTrue( - interpreters.is_shareable(rch)) - self.assertTrue( - interpreters.is_shareable(sch)) - - sch.send_nowait(rch) - sch.send_nowait(sch) - rch2 = rch.recv() - sch2 = rch.recv() - - self.assertEqual(rch2, rch) - self.assertEqual(sch2, sch) - - def test_is_closed(self): - rch, sch = interpreters.create_channel() - rbefore = rch.is_closed - sbefore = sch.is_closed - rch.close() - rafter = rch.is_closed - safter = sch.is_closed - - self.assertFalse(rbefore) - self.assertFalse(sbefore) - self.assertTrue(rafter) - self.assertTrue(safter) - - -class TestRecvChannelAttrs(TestBase): - - def test_id_type(self): - rch, _ = interpreters.create_channel() - self.assertIsInstance(rch.id, _channels.ChannelID) - - def test_custom_id(self): - rch = interpreters.RecvChannel(1) - self.assertEqual(rch.id, 1) - - with self.assertRaises(TypeError): - interpreters.RecvChannel('1') - - def test_id_readonly(self): - rch = interpreters.RecvChannel(1) - with self.assertRaises(AttributeError): - rch.id = 2 - - def test_equality(self): - ch1, _ = interpreters.create_channel() - ch2, _ = interpreters.create_channel() - self.assertEqual(ch1, ch1) - self.assertNotEqual(ch1, ch2) - - -class TestSendChannelAttrs(TestBase): - - def test_id_type(self): - _, sch = interpreters.create_channel() - self.assertIsInstance(sch.id, _channels.ChannelID) - - def test_custom_id(self): - sch = interpreters.SendChannel(1) - self.assertEqual(sch.id, 1) - - with self.assertRaises(TypeError): - interpreters.SendChannel('1') - - def test_id_readonly(self): - sch = interpreters.SendChannel(1) - with self.assertRaises(AttributeError): - sch.id = 2 - - def test_equality(self): - _, ch1 = interpreters.create_channel() - _, ch2 = interpreters.create_channel() - self.assertEqual(ch1, ch1) - self.assertNotEqual(ch1, ch2) - - -class TestSendRecv(TestBase): - - def test_send_recv_main(self): - r, s = interpreters.create_channel() - orig = b'spam' - s.send_nowait(orig) - obj = r.recv() - - self.assertEqual(obj, orig) - self.assertIsNot(obj, orig) - - def test_send_recv_same_interpreter(self): - interp = interpreters.create() - interp.run(dedent(""" - from test.support import interpreters - r, s = interpreters.create_channel() - orig = b'spam' - s.send_nowait(orig) - obj = r.recv() - assert obj == orig, 'expected: obj == orig' - assert obj is not orig, 'expected: obj is not orig' - """)) - - @unittest.skip('broken (see BPO-...)') - def test_send_recv_different_interpreters(self): - r1, s1 = interpreters.create_channel() - r2, s2 = interpreters.create_channel() - orig1 = b'spam' - s1.send_nowait(orig1) - out = _run_output( - interpreters.create(), - dedent(f""" - obj1 = r.recv() - assert obj1 == b'spam', 'expected: obj1 == orig1' - # When going to another interpreter we get a copy. - assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1' - orig2 = b'eggs' - print(id(orig2)) - s.send_nowait(orig2) - """), - channels=dict(r=r1, s=s2), - ) - obj2 = r2.recv() - - self.assertEqual(obj2, b'eggs') - self.assertNotEqual(id(obj2), int(out)) - - def test_send_recv_different_threads(self): - r, s = interpreters.create_channel() - - def f(): - while True: - try: - obj = r.recv() - break - except interpreters.ChannelEmptyError: - time.sleep(0.1) - s.send(obj) - t = threading.Thread(target=f) - t.start() - - orig = b'spam' - s.send(orig) - obj = r.recv() - t.join() - - self.assertEqual(obj, orig) - self.assertIsNot(obj, orig) - - def test_send_recv_nowait_main(self): - r, s = interpreters.create_channel() - orig = b'spam' - s.send_nowait(orig) - obj = r.recv_nowait() - - self.assertEqual(obj, orig) - self.assertIsNot(obj, orig) - - def test_send_recv_nowait_main_with_default(self): - r, _ = interpreters.create_channel() - obj = r.recv_nowait(None) - - self.assertIsNone(obj) - - def test_send_recv_nowait_same_interpreter(self): - interp = interpreters.create() - interp.run(dedent(""" - from test.support import interpreters - r, s = interpreters.create_channel() - orig = b'spam' - s.send_nowait(orig) - obj = r.recv_nowait() - assert obj == orig, 'expected: obj == orig' - # When going back to the same interpreter we get the same object. - assert obj is not orig, 'expected: obj is not orig' - """)) - - @unittest.skip('broken (see BPO-...)') - def test_send_recv_nowait_different_interpreters(self): - r1, s1 = interpreters.create_channel() - r2, s2 = interpreters.create_channel() - orig1 = b'spam' - s1.send_nowait(orig1) - out = _run_output( - interpreters.create(), - dedent(f""" - obj1 = r.recv_nowait() - assert obj1 == b'spam', 'expected: obj1 == orig1' - # When going to another interpreter we get a copy. - assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1' - orig2 = b'eggs' - print(id(orig2)) - s.send_nowait(orig2) - """), - channels=dict(r=r1, s=s2), - ) - obj2 = r2.recv_nowait() - - self.assertEqual(obj2, b'eggs') - self.assertNotEqual(id(obj2), int(out)) - - def test_recv_timeout(self): - r, _ = interpreters.create_channel() - with self.assertRaises(TimeoutError): - r.recv(timeout=1) - - def test_recv_channel_does_not_exist(self): - ch = interpreters.RecvChannel(1_000_000) - with self.assertRaises(interpreters.ChannelNotFoundError): - ch.recv() - - def test_send_channel_does_not_exist(self): - ch = interpreters.SendChannel(1_000_000) - with self.assertRaises(interpreters.ChannelNotFoundError): - ch.send(b'spam') - - def test_recv_nowait_channel_does_not_exist(self): - ch = interpreters.RecvChannel(1_000_000) - with self.assertRaises(interpreters.ChannelNotFoundError): - ch.recv_nowait() - - def test_send_nowait_channel_does_not_exist(self): - ch = interpreters.SendChannel(1_000_000) - with self.assertRaises(interpreters.ChannelNotFoundError): - ch.send_nowait(b'spam') - - def test_recv_nowait_empty(self): - ch, _ = interpreters.create_channel() - with self.assertRaises(interpreters.ChannelEmptyError): - ch.recv_nowait() - - def test_recv_nowait_default(self): - default = object() - rch, sch = interpreters.create_channel() - obj1 = rch.recv_nowait(default) - sch.send_nowait(None) - sch.send_nowait(1) - sch.send_nowait(b'spam') - sch.send_nowait(b'eggs') - obj2 = rch.recv_nowait(default) - obj3 = rch.recv_nowait(default) - obj4 = rch.recv_nowait() - obj5 = rch.recv_nowait(default) - obj6 = rch.recv_nowait(default) - - self.assertIs(obj1, default) - self.assertIs(obj2, None) - self.assertEqual(obj3, 1) - self.assertEqual(obj4, b'spam') - self.assertEqual(obj5, b'eggs') - self.assertIs(obj6, default) - - def test_send_buffer(self): - buf = bytearray(b'spamspamspam') - obj = None - rch, sch = interpreters.create_channel() - - def f(): - nonlocal obj - while True: - try: - obj = rch.recv() - break - except interpreters.ChannelEmptyError: - time.sleep(0.1) - t = threading.Thread(target=f) - t.start() - - sch.send_buffer(buf) - t.join() - - self.assertIsNot(obj, buf) - self.assertIsInstance(obj, memoryview) - self.assertEqual(obj, buf) - - buf[4:8] = b'eggs' - self.assertEqual(obj, buf) - obj[4:8] = b'ham.' - self.assertEqual(obj, buf) - - def test_send_buffer_nowait(self): - buf = bytearray(b'spamspamspam') - rch, sch = interpreters.create_channel() - sch.send_buffer_nowait(buf) - obj = rch.recv() - - self.assertIsNot(obj, buf) - self.assertIsInstance(obj, memoryview) - self.assertEqual(obj, buf) - - buf[4:8] = b'eggs' - self.assertEqual(obj, buf) - obj[4:8] = b'ham.' - self.assertEqual(obj, buf) diff --git a/Lib/test/test_interpreters/__init__.py b/Lib/test/test_interpreters/__init__.py new file mode 100644 index 00000000000000..4b16ecc31156a5 --- /dev/null +++ b/Lib/test/test_interpreters/__init__.py @@ -0,0 +1,5 @@ +import os +from test.support import load_package_tests + +def load_tests(*args): + return load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_interpreters/__main__.py b/Lib/test/test_interpreters/__main__.py new file mode 100644 index 00000000000000..8641229877b2be --- /dev/null +++ b/Lib/test/test_interpreters/__main__.py @@ -0,0 +1,4 @@ +from . import load_tests +import unittest + +nittest.main() diff --git a/Lib/test/test_interpreters/test_api.py b/Lib/test/test_interpreters/test_api.py new file mode 100644 index 00000000000000..e4ae9d005b5282 --- /dev/null +++ b/Lib/test/test_interpreters/test_api.py @@ -0,0 +1,642 @@ +import os +import threading +from textwrap import dedent +import unittest + +from test import support +from test.support import import_helper +# Raise SkipTest if subinterpreters not supported. +import_helper.import_module('_xxsubinterpreters') +from test.support import interpreters +from test.support.interpreters import InterpreterNotFoundError +from .utils import _captured_script, _run_output, _running, TestBase + + +class ModuleTests(TestBase): + + def test_queue_aliases(self): + first = [ + interpreters.create_queue, + interpreters.Queue, + interpreters.QueueEmpty, + interpreters.QueueFull, + ] + second = [ + interpreters.create_queue, + interpreters.Queue, + interpreters.QueueEmpty, + interpreters.QueueFull, + ] + self.assertEqual(second, first) + + +class CreateTests(TestBase): + + def test_in_main(self): + interp = interpreters.create() + self.assertIsInstance(interp, interpreters.Interpreter) + self.assertIn(interp, interpreters.list_all()) + + def test_in_thread(self): + lock = threading.Lock() + interp = None + def f(): + nonlocal interp + interp = interpreters.create() + lock.acquire() + lock.release() + t = threading.Thread(target=f) + with lock: + t.start() + t.join() + self.assertIn(interp, interpreters.list_all()) + + def test_in_subinterpreter(self): + main, = interpreters.list_all() + interp = interpreters.create() + out = _run_output(interp, dedent(""" + from test.support import interpreters + interp = interpreters.create() + print(interp.id) + """)) + interp2 = interpreters.Interpreter(int(out)) + self.assertEqual(interpreters.list_all(), [main, interp, interp2]) + + def test_after_destroy_all(self): + before = set(interpreters.list_all()) + # Create 3 subinterpreters. + interp_lst = [] + for _ in range(3): + interps = interpreters.create() + interp_lst.append(interps) + # Now destroy them. + for interp in interp_lst: + interp.close() + # Finally, create another. + interp = interpreters.create() + self.assertEqual(set(interpreters.list_all()), before | {interp}) + + def test_after_destroy_some(self): + before = set(interpreters.list_all()) + # Create 3 subinterpreters. + interp1 = interpreters.create() + interp2 = interpreters.create() + interp3 = interpreters.create() + # Now destroy 2 of them. + interp1.close() + interp2.close() + # Finally, create another. + interp = interpreters.create() + self.assertEqual(set(interpreters.list_all()), before | {interp3, interp}) + + +class GetMainTests(TestBase): + + def test_id(self): + main = interpreters.get_main() + self.assertEqual(main.id, 0) + + def test_current(self): + main = interpreters.get_main() + current = interpreters.get_current() + self.assertIs(main, current) + + def test_idempotent(self): + main1 = interpreters.get_main() + main2 = interpreters.get_main() + self.assertIs(main1, main2) + + +class GetCurrentTests(TestBase): + + def test_main(self): + main = interpreters.get_main() + current = interpreters.get_current() + self.assertEqual(current, main) + + def test_subinterpreter(self): + main = interpreters.get_main() + interp = interpreters.create() + out = _run_output(interp, dedent(""" + from test.support import interpreters + cur = interpreters.get_current() + print(cur.id) + """)) + current = interpreters.Interpreter(int(out)) + self.assertEqual(current, interp) + self.assertNotEqual(current, main) + + def test_idempotent(self): + with self.subTest('main'): + cur1 = interpreters.get_current() + cur2 = interpreters.get_current() + self.assertIs(cur1, cur2) + + with self.subTest('subinterpreter'): + interp = interpreters.create() + out = _run_output(interp, dedent(""" + from test.support import interpreters + cur = interpreters.get_current() + print(id(cur)) + cur = interpreters.get_current() + print(id(cur)) + """)) + objid1, objid2 = (int(v) for v in out.splitlines()) + self.assertEqual(objid1, objid2) + + with self.subTest('per-interpreter'): + interp = interpreters.create() + out = _run_output(interp, dedent(""" + from test.support import interpreters + cur = interpreters.get_current() + print(id(cur)) + """)) + id1 = int(out) + id2 = id(interp) + self.assertNotEqual(id1, id2) + + +class ListAllTests(TestBase): + + def test_initial(self): + interps = interpreters.list_all() + self.assertEqual(1, len(interps)) + + def test_after_creating(self): + main = interpreters.get_current() + first = interpreters.create() + second = interpreters.create() + + ids = [] + for interp in interpreters.list_all(): + ids.append(interp.id) + + self.assertEqual(ids, [main.id, first.id, second.id]) + + def test_after_destroying(self): + main = interpreters.get_current() + first = interpreters.create() + second = interpreters.create() + first.close() + + ids = [] + for interp in interpreters.list_all(): + ids.append(interp.id) + + self.assertEqual(ids, [main.id, second.id]) + + def test_idempotent(self): + main = interpreters.get_current() + first = interpreters.create() + second = interpreters.create() + expected = [main, first, second] + + actual = interpreters.list_all() + + self.assertEqual(actual, expected) + for interp1, interp2 in zip(actual, expected): + self.assertIs(interp1, interp2) + + +class InterpreterObjectTests(TestBase): + + def test_init_int(self): + interpid = interpreters.get_current().id + interp = interpreters.Interpreter(interpid) + self.assertEqual(interp.id, interpid) + + def test_init_interpreter_id(self): + interpid = interpreters.get_current()._id + interp = interpreters.Interpreter(interpid) + self.assertEqual(interp._id, interpid) + + def test_init_unsupported(self): + actualid = interpreters.get_current().id + for interpid in [ + str(actualid), + float(actualid), + object(), + None, + '', + ]: + with self.subTest(repr(interpid)): + with self.assertRaises(TypeError): + interpreters.Interpreter(interpid) + + def test_idempotent(self): + main = interpreters.get_main() + interp = interpreters.Interpreter(main.id) + self.assertIs(interp, main) + + def test_init_does_not_exist(self): + with self.assertRaises(InterpreterNotFoundError): + interpreters.Interpreter(1_000_000) + + def test_init_bad_id(self): + with self.assertRaises(ValueError): + interpreters.Interpreter(-1) + + def test_id_type(self): + main = interpreters.get_main() + current = interpreters.get_current() + interp = interpreters.create() + self.assertIsInstance(main.id, int) + self.assertIsInstance(current.id, int) + self.assertIsInstance(interp.id, int) + + def test_id_readonly(self): + interp = interpreters.create() + with self.assertRaises(AttributeError): + interp.id = 1_000_000 + + def test_hashable(self): + interp = interpreters.create() + expected = hash(interp.id) + actual = hash(interp) + self.assertEqual(actual, expected) + + def test_equality(self): + interp1 = interpreters.create() + interp2 = interpreters.create() + self.assertEqual(interp1, interp1) + self.assertNotEqual(interp1, interp2) + + +class TestInterpreterIsRunning(TestBase): + + def test_main(self): + main = interpreters.get_main() + self.assertTrue(main.is_running()) + + @unittest.skip('Fails on FreeBSD') + def test_subinterpreter(self): + interp = interpreters.create() + self.assertFalse(interp.is_running()) + + with _running(interp): + self.assertTrue(interp.is_running()) + self.assertFalse(interp.is_running()) + + def test_finished(self): + r, w = self.pipe() + interp = interpreters.create() + interp.exec_sync(f"""if True: + import os + os.write({w}, b'x') + """) + self.assertFalse(interp.is_running()) + self.assertEqual(os.read(r, 1), b'x') + + def test_from_subinterpreter(self): + interp = interpreters.create() + out = _run_output(interp, dedent(f""" + import _xxsubinterpreters as _interpreters + if _interpreters.is_running({interp.id}): + print(True) + else: + print(False) + """)) + self.assertEqual(out.strip(), 'True') + + def test_already_destroyed(self): + interp = interpreters.create() + interp.close() + with self.assertRaises(InterpreterNotFoundError): + interp.is_running() + + def test_with_only_background_threads(self): + r_interp, w_interp = self.pipe() + r_thread, w_thread = self.pipe() + + DONE = b'D' + FINISHED = b'F' + + interp = interpreters.create() + interp.exec_sync(f"""if True: + import os + import threading + + def task(): + v = os.read({r_thread}, 1) + assert v == {DONE!r} + os.write({w_interp}, {FINISHED!r}) + t = threading.Thread(target=task) + t.start() + """) + self.assertFalse(interp.is_running()) + + os.write(w_thread, DONE) + interp.exec_sync('t.join()') + self.assertEqual(os.read(r_interp, 1), FINISHED) + + +class TestInterpreterClose(TestBase): + + def test_basic(self): + main = interpreters.get_main() + interp1 = interpreters.create() + interp2 = interpreters.create() + interp3 = interpreters.create() + self.assertEqual(set(interpreters.list_all()), + {main, interp1, interp2, interp3}) + interp2.close() + self.assertEqual(set(interpreters.list_all()), + {main, interp1, interp3}) + + def test_all(self): + before = set(interpreters.list_all()) + interps = set() + for _ in range(3): + interp = interpreters.create() + interps.add(interp) + self.assertEqual(set(interpreters.list_all()), before | interps) + for interp in interps: + interp.close() + self.assertEqual(set(interpreters.list_all()), before) + + def test_main(self): + main, = interpreters.list_all() + with self.assertRaises(RuntimeError): + main.close() + + def f(): + with self.assertRaises(RuntimeError): + main.close() + + t = threading.Thread(target=f) + t.start() + t.join() + + def test_already_destroyed(self): + interp = interpreters.create() + interp.close() + with self.assertRaises(InterpreterNotFoundError): + interp.close() + + def test_from_current(self): + main, = interpreters.list_all() + interp = interpreters.create() + out = _run_output(interp, dedent(f""" + from test.support import interpreters + interp = interpreters.Interpreter({interp.id}) + try: + interp.close() + except RuntimeError: + print('failed') + """)) + self.assertEqual(out.strip(), 'failed') + self.assertEqual(set(interpreters.list_all()), {main, interp}) + + def test_from_sibling(self): + main, = interpreters.list_all() + interp1 = interpreters.create() + interp2 = interpreters.create() + self.assertEqual(set(interpreters.list_all()), + {main, interp1, interp2}) + interp1.exec_sync(dedent(f""" + from test.support import interpreters + interp2 = interpreters.Interpreter({interp2.id}) + interp2.close() + interp3 = interpreters.create() + interp3.close() + """)) + self.assertEqual(set(interpreters.list_all()), {main, interp1}) + + def test_from_other_thread(self): + interp = interpreters.create() + def f(): + interp.close() + + t = threading.Thread(target=f) + t.start() + t.join() + + @unittest.skip('Fails on FreeBSD') + def test_still_running(self): + main, = interpreters.list_all() + interp = interpreters.create() + with _running(interp): + with self.assertRaises(RuntimeError): + interp.close() + self.assertTrue(interp.is_running()) + + def test_subthreads_still_running(self): + r_interp, w_interp = self.pipe() + r_thread, w_thread = self.pipe() + + FINISHED = b'F' + + interp = interpreters.create() + interp.exec_sync(f"""if True: + import os + import threading + import time + + done = False + + def notify_fini(): + global done + done = True + t.join() + threading._register_atexit(notify_fini) + + def task(): + while not done: + time.sleep(0.1) + os.write({w_interp}, {FINISHED!r}) + t = threading.Thread(target=task) + t.start() + """) + interp.close() + + self.assertEqual(os.read(r_interp, 1), FINISHED) + + +class TestInterpreterExecSync(TestBase): + + def test_success(self): + interp = interpreters.create() + script, file = _captured_script('print("it worked!", end="")') + with file: + interp.exec_sync(script) + out = file.read() + + self.assertEqual(out, 'it worked!') + + def test_failure(self): + interp = interpreters.create() + with self.assertRaises(interpreters.ExecFailure): + interp.exec_sync('raise Exception') + + def test_in_thread(self): + interp = interpreters.create() + script, file = _captured_script('print("it worked!", end="")') + with file: + def f(): + interp.exec_sync(script) + + t = threading.Thread(target=f) + t.start() + t.join() + out = file.read() + + self.assertEqual(out, 'it worked!') + + @support.requires_fork() + def test_fork(self): + interp = interpreters.create() + import tempfile + with tempfile.NamedTemporaryFile('w+', encoding='utf-8') as file: + file.write('') + file.flush() + + expected = 'spam spam spam spam spam' + script = dedent(f""" + import os + try: + os.fork() + except RuntimeError: + with open('{file.name}', 'w', encoding='utf-8') as out: + out.write('{expected}') + """) + interp.exec_sync(script) + + file.seek(0) + content = file.read() + self.assertEqual(content, expected) + + @unittest.skip('Fails on FreeBSD') + def test_already_running(self): + interp = interpreters.create() + with _running(interp): + with self.assertRaises(RuntimeError): + interp.exec_sync('print("spam")') + + def test_bad_script(self): + interp = interpreters.create() + with self.assertRaises(TypeError): + interp.exec_sync(10) + + def test_bytes_for_script(self): + interp = interpreters.create() + with self.assertRaises(TypeError): + interp.exec_sync(b'print("spam")') + + def test_with_background_threads_still_running(self): + r_interp, w_interp = self.pipe() + r_thread, w_thread = self.pipe() + + RAN = b'R' + DONE = b'D' + FINISHED = b'F' + + interp = interpreters.create() + interp.exec_sync(f"""if True: + import os + import threading + + def task(): + v = os.read({r_thread}, 1) + assert v == {DONE!r} + os.write({w_interp}, {FINISHED!r}) + t = threading.Thread(target=task) + t.start() + os.write({w_interp}, {RAN!r}) + """) + interp.exec_sync(f"""if True: + os.write({w_interp}, {RAN!r}) + """) + + os.write(w_thread, DONE) + interp.exec_sync('t.join()') + self.assertEqual(os.read(r_interp, 1), RAN) + self.assertEqual(os.read(r_interp, 1), RAN) + self.assertEqual(os.read(r_interp, 1), FINISHED) + + # test_xxsubinterpreters covers the remaining + # Interpreter.exec_sync() behavior. + + +class TestInterpreterRun(TestBase): + + def test_success(self): + interp = interpreters.create() + script, file = _captured_script('print("it worked!", end="")') + with file: + t = interp.run(script) + t.join() + out = file.read() + + self.assertEqual(out, 'it worked!') + + def test_failure(self): + caught = False + def excepthook(args): + nonlocal caught + caught = True + threading.excepthook = excepthook + try: + interp = interpreters.create() + t = interp.run('raise Exception') + t.join() + + self.assertTrue(caught) + except BaseException: + threading.excepthook = threading.__excepthook__ + + +class TestIsShareable(TestBase): + + def test_default_shareables(self): + shareables = [ + # singletons + None, + # builtin objects + b'spam', + 'spam', + 10, + -10, + True, + False, + 100.0, + (), + (1, ('spam', 'eggs'), True), + ] + for obj in shareables: + with self.subTest(obj): + shareable = interpreters.is_shareable(obj) + self.assertTrue(shareable) + + def test_not_shareable(self): + class Cheese: + def __init__(self, name): + self.name = name + def __str__(self): + return self.name + + class SubBytes(bytes): + """A subclass of a shareable type.""" + + not_shareables = [ + # singletons + NotImplemented, + ..., + # builtin types and objects + type, + object, + object(), + Exception(), + # user-defined types and objects + Cheese, + Cheese('Wensleydale'), + SubBytes(b'spam'), + ] + for obj in not_shareables: + with self.subTest(repr(obj)): + self.assertFalse( + interpreters.is_shareable(obj)) + + +if __name__ == '__main__': + # Test needs to be a package, so we can do relative imports. + unittest.main() diff --git a/Lib/test/test_interpreters/test_channels.py b/Lib/test/test_interpreters/test_channels.py new file mode 100644 index 00000000000000..3c3e18832d4168 --- /dev/null +++ b/Lib/test/test_interpreters/test_channels.py @@ -0,0 +1,328 @@ +import threading +from textwrap import dedent +import unittest +import time + +from test.support import import_helper +# Raise SkipTest if subinterpreters not supported. +_channels = import_helper.import_module('_xxinterpchannels') +from test.support import interpreters +from test.support.interpreters import channels +from .utils import _run_output, TestBase + + +class TestChannels(TestBase): + + def test_create(self): + r, s = channels.create() + self.assertIsInstance(r, channels.RecvChannel) + self.assertIsInstance(s, channels.SendChannel) + + def test_list_all(self): + self.assertEqual(channels.list_all(), []) + created = set() + for _ in range(3): + ch = channels.create() + created.add(ch) + after = set(channels.list_all()) + self.assertEqual(after, created) + + def test_shareable(self): + rch, sch = channels.create() + + self.assertTrue( + interpreters.is_shareable(rch)) + self.assertTrue( + interpreters.is_shareable(sch)) + + sch.send_nowait(rch) + sch.send_nowait(sch) + rch2 = rch.recv() + sch2 = rch.recv() + + self.assertEqual(rch2, rch) + self.assertEqual(sch2, sch) + + def test_is_closed(self): + rch, sch = channels.create() + rbefore = rch.is_closed + sbefore = sch.is_closed + rch.close() + rafter = rch.is_closed + safter = sch.is_closed + + self.assertFalse(rbefore) + self.assertFalse(sbefore) + self.assertTrue(rafter) + self.assertTrue(safter) + + +class TestRecvChannelAttrs(TestBase): + + def test_id_type(self): + rch, _ = channels.create() + self.assertIsInstance(rch.id, _channels.ChannelID) + + def test_custom_id(self): + rch = channels.RecvChannel(1) + self.assertEqual(rch.id, 1) + + with self.assertRaises(TypeError): + channels.RecvChannel('1') + + def test_id_readonly(self): + rch = channels.RecvChannel(1) + with self.assertRaises(AttributeError): + rch.id = 2 + + def test_equality(self): + ch1, _ = channels.create() + ch2, _ = channels.create() + self.assertEqual(ch1, ch1) + self.assertNotEqual(ch1, ch2) + + +class TestSendChannelAttrs(TestBase): + + def test_id_type(self): + _, sch = channels.create() + self.assertIsInstance(sch.id, _channels.ChannelID) + + def test_custom_id(self): + sch = channels.SendChannel(1) + self.assertEqual(sch.id, 1) + + with self.assertRaises(TypeError): + channels.SendChannel('1') + + def test_id_readonly(self): + sch = channels.SendChannel(1) + with self.assertRaises(AttributeError): + sch.id = 2 + + def test_equality(self): + _, ch1 = channels.create() + _, ch2 = channels.create() + self.assertEqual(ch1, ch1) + self.assertNotEqual(ch1, ch2) + + +class TestSendRecv(TestBase): + + def test_send_recv_main(self): + r, s = channels.create() + orig = b'spam' + s.send_nowait(orig) + obj = r.recv() + + self.assertEqual(obj, orig) + self.assertIsNot(obj, orig) + + def test_send_recv_same_interpreter(self): + interp = interpreters.create() + interp.exec_sync(dedent(""" + from test.support.interpreters import channels + r, s = channels.create() + orig = b'spam' + s.send_nowait(orig) + obj = r.recv() + assert obj == orig, 'expected: obj == orig' + assert obj is not orig, 'expected: obj is not orig' + """)) + + @unittest.skip('broken (see BPO-...)') + def test_send_recv_different_interpreters(self): + r1, s1 = channels.create() + r2, s2 = channels.create() + orig1 = b'spam' + s1.send_nowait(orig1) + out = _run_output( + interpreters.create(), + dedent(f""" + obj1 = r.recv() + assert obj1 == b'spam', 'expected: obj1 == orig1' + # When going to another interpreter we get a copy. + assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1' + orig2 = b'eggs' + print(id(orig2)) + s.send_nowait(orig2) + """), + channels=dict(r=r1, s=s2), + ) + obj2 = r2.recv() + + self.assertEqual(obj2, b'eggs') + self.assertNotEqual(id(obj2), int(out)) + + def test_send_recv_different_threads(self): + r, s = channels.create() + + def f(): + while True: + try: + obj = r.recv() + break + except channels.ChannelEmptyError: + time.sleep(0.1) + s.send(obj) + t = threading.Thread(target=f) + t.start() + + orig = b'spam' + s.send(orig) + obj = r.recv() + t.join() + + self.assertEqual(obj, orig) + self.assertIsNot(obj, orig) + + def test_send_recv_nowait_main(self): + r, s = channels.create() + orig = b'spam' + s.send_nowait(orig) + obj = r.recv_nowait() + + self.assertEqual(obj, orig) + self.assertIsNot(obj, orig) + + def test_send_recv_nowait_main_with_default(self): + r, _ = channels.create() + obj = r.recv_nowait(None) + + self.assertIsNone(obj) + + def test_send_recv_nowait_same_interpreter(self): + interp = interpreters.create() + interp.exec_sync(dedent(""" + from test.support.interpreters import channels + r, s = channels.create() + orig = b'spam' + s.send_nowait(orig) + obj = r.recv_nowait() + assert obj == orig, 'expected: obj == orig' + # When going back to the same interpreter we get the same object. + assert obj is not orig, 'expected: obj is not orig' + """)) + + @unittest.skip('broken (see BPO-...)') + def test_send_recv_nowait_different_interpreters(self): + r1, s1 = channels.create() + r2, s2 = channels.create() + orig1 = b'spam' + s1.send_nowait(orig1) + out = _run_output( + interpreters.create(), + dedent(f""" + obj1 = r.recv_nowait() + assert obj1 == b'spam', 'expected: obj1 == orig1' + # When going to another interpreter we get a copy. + assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1' + orig2 = b'eggs' + print(id(orig2)) + s.send_nowait(orig2) + """), + channels=dict(r=r1, s=s2), + ) + obj2 = r2.recv_nowait() + + self.assertEqual(obj2, b'eggs') + self.assertNotEqual(id(obj2), int(out)) + + def test_recv_timeout(self): + r, _ = channels.create() + with self.assertRaises(TimeoutError): + r.recv(timeout=1) + + def test_recv_channel_does_not_exist(self): + ch = channels.RecvChannel(1_000_000) + with self.assertRaises(channels.ChannelNotFoundError): + ch.recv() + + def test_send_channel_does_not_exist(self): + ch = channels.SendChannel(1_000_000) + with self.assertRaises(channels.ChannelNotFoundError): + ch.send(b'spam') + + def test_recv_nowait_channel_does_not_exist(self): + ch = channels.RecvChannel(1_000_000) + with self.assertRaises(channels.ChannelNotFoundError): + ch.recv_nowait() + + def test_send_nowait_channel_does_not_exist(self): + ch = channels.SendChannel(1_000_000) + with self.assertRaises(channels.ChannelNotFoundError): + ch.send_nowait(b'spam') + + def test_recv_nowait_empty(self): + ch, _ = channels.create() + with self.assertRaises(channels.ChannelEmptyError): + ch.recv_nowait() + + def test_recv_nowait_default(self): + default = object() + rch, sch = channels.create() + obj1 = rch.recv_nowait(default) + sch.send_nowait(None) + sch.send_nowait(1) + sch.send_nowait(b'spam') + sch.send_nowait(b'eggs') + obj2 = rch.recv_nowait(default) + obj3 = rch.recv_nowait(default) + obj4 = rch.recv_nowait() + obj5 = rch.recv_nowait(default) + obj6 = rch.recv_nowait(default) + + self.assertIs(obj1, default) + self.assertIs(obj2, None) + self.assertEqual(obj3, 1) + self.assertEqual(obj4, b'spam') + self.assertEqual(obj5, b'eggs') + self.assertIs(obj6, default) + + def test_send_buffer(self): + buf = bytearray(b'spamspamspam') + obj = None + rch, sch = channels.create() + + def f(): + nonlocal obj + while True: + try: + obj = rch.recv() + break + except channels.ChannelEmptyError: + time.sleep(0.1) + t = threading.Thread(target=f) + t.start() + + sch.send_buffer(buf) + t.join() + + self.assertIsNot(obj, buf) + self.assertIsInstance(obj, memoryview) + self.assertEqual(obj, buf) + + buf[4:8] = b'eggs' + self.assertEqual(obj, buf) + obj[4:8] = b'ham.' + self.assertEqual(obj, buf) + + def test_send_buffer_nowait(self): + buf = bytearray(b'spamspamspam') + rch, sch = channels.create() + sch.send_buffer_nowait(buf) + obj = rch.recv() + + self.assertIsNot(obj, buf) + self.assertIsInstance(obj, memoryview) + self.assertEqual(obj, buf) + + buf[4:8] = b'eggs' + self.assertEqual(obj, buf) + obj[4:8] = b'ham.' + self.assertEqual(obj, buf) + + +if __name__ == '__main__': + # Test needs to be a package, so we can do relative imports. + unittest.main() diff --git a/Lib/test/test_interpreters/test_lifecycle.py b/Lib/test/test_interpreters/test_lifecycle.py new file mode 100644 index 00000000000000..c2917d839904f9 --- /dev/null +++ b/Lib/test/test_interpreters/test_lifecycle.py @@ -0,0 +1,189 @@ +import contextlib +import json +import os +import os.path +import sys +from textwrap import dedent +import unittest + +from test import support +from test.support import import_helper +from test.support import os_helper +# Raise SkipTest if subinterpreters not supported. +import_helper.import_module('_xxsubinterpreters') +from .utils import TestBase + + +class StartupTests(TestBase): + + # We want to ensure the initial state of subinterpreters + # matches expectations. + + _subtest_count = 0 + + @contextlib.contextmanager + def subTest(self, *args): + with super().subTest(*args) as ctx: + self._subtest_count += 1 + try: + yield ctx + finally: + if self._debugged_in_subtest: + if self._subtest_count == 1: + # The first subtest adds a leading newline, so we + # compensate here by not printing a trailing newline. + print('### end subtest debug ###', end='') + else: + print('### end subtest debug ###') + self._debugged_in_subtest = False + + def debug(self, msg, *, header=None): + if header: + self._debug(f'--- {header} ---') + if msg: + if msg.endswith(os.linesep): + self._debug(msg[:-len(os.linesep)]) + else: + self._debug(msg) + self._debug('') + self._debug('------') + else: + self._debug(msg) + + _debugged = False + _debugged_in_subtest = False + def _debug(self, msg): + if not self._debugged: + print() + self._debugged = True + if self._subtest is not None: + if True: + if not self._debugged_in_subtest: + self._debugged_in_subtest = True + print('### start subtest debug ###') + print(msg) + else: + print(msg) + + def create_temp_dir(self): + import tempfile + tmp = tempfile.mkdtemp(prefix='test_interpreters_') + tmp = os.path.realpath(tmp) + self.addCleanup(os_helper.rmtree, tmp) + return tmp + + def write_script(self, *path, text): + filename = os.path.join(*path) + dirname = os.path.dirname(filename) + if dirname: + os.makedirs(dirname, exist_ok=True) + with open(filename, 'w', encoding='utf-8') as outfile: + outfile.write(dedent(text)) + return filename + + @support.requires_subprocess() + def run_python(self, argv, *, cwd=None): + # This method is inspired by + # EmbeddingTestsMixin.run_embedded_interpreter() in test_embed.py. + import shlex + import subprocess + if isinstance(argv, str): + argv = shlex.split(argv) + argv = [sys.executable, *argv] + try: + proc = subprocess.run( + argv, + cwd=cwd, + capture_output=True, + text=True, + ) + except Exception as exc: + self.debug(f'# cmd: {shlex.join(argv)}') + if isinstance(exc, FileNotFoundError) and not exc.filename: + if os.path.exists(argv[0]): + exists = 'exists' + else: + exists = 'does not exist' + self.debug(f'{argv[0]} {exists}') + raise # re-raise + assert proc.stderr == '' or proc.returncode != 0, proc.stderr + if proc.returncode != 0 and support.verbose: + self.debug(f'# python3 {shlex.join(argv[1:])} failed:') + self.debug(proc.stdout, header='stdout') + self.debug(proc.stderr, header='stderr') + self.assertEqual(proc.returncode, 0) + self.assertEqual(proc.stderr, '') + return proc.stdout + + def test_sys_path_0(self): + # The main interpreter's sys.path[0] should be used by subinterpreters. + script = ''' + import sys + from test.support import interpreters + + orig = sys.path[0] + + interp = interpreters.create() + interp.exec_sync(f"""if True: + import json + import sys + print(json.dumps({{ + 'main': {orig!r}, + 'sub': sys.path[0], + }}, indent=4), flush=True) + """) + ''' + # / + # pkg/ + # __init__.py + # __main__.py + # script.py + # script.py + cwd = self.create_temp_dir() + self.write_script(cwd, 'pkg', '__init__.py', text='') + self.write_script(cwd, 'pkg', '__main__.py', text=script) + self.write_script(cwd, 'pkg', 'script.py', text=script) + self.write_script(cwd, 'script.py', text=script) + + cases = [ + ('script.py', cwd), + ('-m script', cwd), + ('-m pkg', cwd), + ('-m pkg.script', cwd), + ('-c "import script"', ''), + ] + for argv, expected in cases: + with self.subTest(f'python3 {argv}'): + out = self.run_python(argv, cwd=cwd) + data = json.loads(out) + sp0_main, sp0_sub = data['main'], data['sub'] + self.assertEqual(sp0_sub, sp0_main) + self.assertEqual(sp0_sub, expected) + # XXX Also check them all with the -P cmdline flag? + + +class FinalizationTests(TestBase): + + def test_gh_109793(self): + # Make sure finalization finishes and the correct error code + # is reported, even when subinterpreters get cleaned up at the end. + import subprocess + argv = [sys.executable, '-c', '''if True: + from test.support import interpreters + interp = interpreters.create() + raise Exception + '''] + proc = subprocess.run(argv, capture_output=True, text=True) + self.assertIn('Traceback', proc.stderr) + if proc.returncode == 0 and support.verbose: + print() + print("--- cmd unexpected succeeded ---") + print(f"stdout:\n{proc.stdout}") + print(f"stderr:\n{proc.stderr}") + print("------") + self.assertEqual(proc.returncode, 1) + + +if __name__ == '__main__': + # Test needs to be a package, so we can do relative imports. + unittest.main() diff --git a/Lib/test/test_interpreters/test_queues.py b/Lib/test/test_interpreters/test_queues.py new file mode 100644 index 00000000000000..2af90b14d3e3c4 --- /dev/null +++ b/Lib/test/test_interpreters/test_queues.py @@ -0,0 +1,233 @@ +import threading +from textwrap import dedent +import unittest +import time + +from test.support import import_helper +# Raise SkipTest if subinterpreters not supported. +import_helper.import_module('_xxinterpchannels') +#import_helper.import_module('_xxinterpqueues') +from test.support import interpreters +from test.support.interpreters import queues +from .utils import _run_output, TestBase + + +class QueueTests(TestBase): + + def test_create(self): + with self.subTest('vanilla'): + queue = queues.create() + self.assertEqual(queue.maxsize, 0) + + with self.subTest('small maxsize'): + queue = queues.create(3) + self.assertEqual(queue.maxsize, 3) + + with self.subTest('big maxsize'): + queue = queues.create(100) + self.assertEqual(queue.maxsize, 100) + + with self.subTest('no maxsize'): + queue = queues.create(0) + self.assertEqual(queue.maxsize, 0) + + with self.subTest('negative maxsize'): + queue = queues.create(-1) + self.assertEqual(queue.maxsize, 0) + + with self.subTest('bad maxsize'): + with self.assertRaises(TypeError): + queues.create('1') + + @unittest.expectedFailure + def test_shareable(self): + queue1 = queues.create() + queue2 = queues.create() + queue1.put(queue2) + queue3 = queue1.get() + self.assertIs(queue3, queue1) + + def test_id_type(self): + queue = queues.create() + self.assertIsInstance(queue.id, int) + + def test_custom_id(self): + with self.assertRaises(queues.QueueNotFoundError): + queues.Queue(1_000_000) + + def test_id_readonly(self): + queue = queues.create() + with self.assertRaises(AttributeError): + queue.id = 1_000_000 + + def test_maxsize_readonly(self): + queue = queues.create(10) + with self.assertRaises(AttributeError): + queue.maxsize = 1_000_000 + + def test_hashable(self): + queue = queues.create() + expected = hash(queue.id) + actual = hash(queue) + self.assertEqual(actual, expected) + + def test_equality(self): + queue1 = queues.create() + queue2 = queues.create() + self.assertEqual(queue1, queue1) + self.assertNotEqual(queue1, queue2) + + +class TestQueueOps(TestBase): + + def test_empty(self): + queue = queues.create() + before = queue.empty() + queue.put(None) + during = queue.empty() + queue.get() + after = queue.empty() + + self.assertIs(before, True) + self.assertIs(during, False) + self.assertIs(after, True) + + def test_full(self): + expected = [False, False, False, True, False, False, False] + actual = [] + queue = queues.create(3) + for _ in range(3): + actual.append(queue.full()) + queue.put(None) + actual.append(queue.full()) + for _ in range(3): + queue.get() + actual.append(queue.full()) + + self.assertEqual(actual, expected) + + def test_qsize(self): + expected = [0, 1, 2, 3, 2, 3, 2, 1, 0, 1, 0] + actual = [] + queue = queues.create() + for _ in range(3): + actual.append(queue.qsize()) + queue.put(None) + actual.append(queue.qsize()) + queue.get() + actual.append(queue.qsize()) + queue.put(None) + actual.append(queue.qsize()) + for _ in range(3): + queue.get() + actual.append(queue.qsize()) + queue.put(None) + actual.append(queue.qsize()) + queue.get() + actual.append(queue.qsize()) + + self.assertEqual(actual, expected) + + def test_put_get_main(self): + expected = list(range(20)) + queue = queues.create() + for i in range(20): + queue.put(i) + actual = [queue.get() for _ in range(20)] + + self.assertEqual(actual, expected) + + @unittest.expectedFailure + def test_put_timeout(self): + queue = queues.create(2) + queue.put(None) + queue.put(None) + with self.assertRaises(queues.QueueFull): + queue.put(None, timeout=0.1) + queue.get() + queue.put(None) + + @unittest.expectedFailure + def test_put_nowait(self): + queue = queues.create(2) + queue.put_nowait(None) + queue.put_nowait(None) + with self.assertRaises(queues.QueueFull): + queue.put_nowait(None) + queue.get() + queue.put_nowait(None) + + def test_get_timeout(self): + queue = queues.create() + with self.assertRaises(queues.QueueEmpty): + queue.get(timeout=0.1) + + def test_get_nowait(self): + queue = queues.create() + with self.assertRaises(queues.QueueEmpty): + queue.get_nowait() + + def test_put_get_same_interpreter(self): + interp = interpreters.create() + interp.exec_sync(dedent(""" + from test.support.interpreters import queues + queue = queues.create() + orig = b'spam' + queue.put(orig) + obj = queue.get() + assert obj == orig, 'expected: obj == orig' + assert obj is not orig, 'expected: obj is not orig' + """)) + + @unittest.expectedFailure + def test_put_get_different_interpreters(self): + queue1 = queues.create() + queue2 = queues.create() + obj1 = b'spam' + queue1.put(obj1) + out = _run_output( + interpreters.create(), + dedent(f""" + import test.support.interpreters.queue as queues + queue1 = queues.Queue({queue1.id}) + queue2 = queues.Queue({queue2.id}) + obj = queue1.get() + assert obj == b'spam', 'expected: obj == obj1' + # When going to another interpreter we get a copy. + assert id(obj) != {id(obj1)}, 'expected: obj is not obj1' + obj2 = b'eggs' + print(id(obj2)) + queue2.put(obj2) + """)) + obj2 = queue2.get() + + self.assertEqual(obj2, b'eggs') + self.assertNotEqual(id(obj2), int(out)) + + def test_put_get_different_threads(self): + queue1 = queues.create() + queue2 = queues.create() + + def f(): + while True: + try: + obj = queue1.get(timeout=0.1) + break + except queues.QueueEmpty: + continue + queue2.put(obj) + t = threading.Thread(target=f) + t.start() + + orig = b'spam' + queue1.put(orig) + obj = queue2.get() + t.join() + + self.assertEqual(obj, orig) + self.assertIsNot(obj, orig) + + +if __name__ == '__main__': + # Test needs to be a package, so we can do relative imports. + unittest.main() diff --git a/Lib/test/test_interpreters/test_stress.py b/Lib/test/test_interpreters/test_stress.py new file mode 100644 index 00000000000000..3cc570b3bf7128 --- /dev/null +++ b/Lib/test/test_interpreters/test_stress.py @@ -0,0 +1,38 @@ +import threading +import unittest + +from test import support +from test.support import import_helper +from test.support import threading_helper +# Raise SkipTest if subinterpreters not supported. +import_helper.import_module('_xxsubinterpreters') +from test.support import interpreters +from .utils import TestBase + + +class StressTests(TestBase): + + # In these tests we generally want a lot of interpreters, + # but not so many that any test takes too long. + + @support.requires_resource('cpu') + def test_create_many_sequential(self): + alive = [] + for _ in range(100): + interp = interpreters.create() + alive.append(interp) + + @support.requires_resource('cpu') + def test_create_many_threaded(self): + alive = [] + def task(): + interp = interpreters.create() + alive.append(interp) + threads = (threading.Thread(target=task) for _ in range(200)) + with threading_helper.start_threads(threads): + pass + + +if __name__ == '__main__': + # Test needs to be a package, so we can do relative imports. + unittest.main() diff --git a/Lib/test/test_interpreters/utils.py b/Lib/test/test_interpreters/utils.py new file mode 100644 index 00000000000000..623c8737b79831 --- /dev/null +++ b/Lib/test/test_interpreters/utils.py @@ -0,0 +1,73 @@ +import contextlib +import os +import threading +from textwrap import dedent +import unittest + +from test.support import interpreters + + +def _captured_script(script): + r, w = os.pipe() + indented = script.replace('\n', '\n ') + wrapped = dedent(f""" + import contextlib + with open({w}, 'w', encoding='utf-8') as spipe: + with contextlib.redirect_stdout(spipe): + {indented} + """) + return wrapped, open(r, encoding='utf-8') + + +def clean_up_interpreters(): + for interp in interpreters.list_all(): + if interp.id == 0: # main + continue + try: + interp.close() + except RuntimeError: + pass # already destroyed + + +def _run_output(interp, request, channels=None): + script, rpipe = _captured_script(request) + with rpipe: + interp.exec_sync(script, channels=channels) + return rpipe.read() + + +@contextlib.contextmanager +def _running(interp): + r, w = os.pipe() + def run(): + interp.exec_sync(dedent(f""" + # wait for "signal" + with open({r}) as rpipe: + rpipe.read() + """)) + + t = threading.Thread(target=run) + t.start() + + yield + + with open(w, 'w') as spipe: + spipe.write('done') + t.join() + + +class TestBase(unittest.TestCase): + + def pipe(self): + def ensure_closed(fd): + try: + os.close(fd) + except OSError: + pass + r, w = os.pipe() + self.addCleanup(lambda: ensure_closed(r)) + self.addCleanup(lambda: ensure_closed(w)) + return r, w + + def tearDown(self): + clean_up_interpreters() diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py index db5ba16c4d9739..6c87dfabad9f0f 100644 --- a/Lib/test/test_sys.py +++ b/Lib/test/test_sys.py @@ -729,7 +729,7 @@ def test_subinterp_intern_dynamically_allocated(self): self.assertIs(t, s) interp = interpreters.create() - interp.run(textwrap.dedent(f''' + interp.exec_sync(textwrap.dedent(f''' import sys t = sys.intern({s!r}) assert id(t) != {id(s)}, (id(t), {id(s)}) @@ -744,7 +744,7 @@ def test_subinterp_intern_statically_allocated(self): t = sys.intern(s) interp = interpreters.create() - interp.run(textwrap.dedent(f''' + interp.exec_sync(textwrap.dedent(f''' import sys t = sys.intern({s!r}) assert id(t) == {id(t)}, (id(t), {id(t)}) diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 146e2dbc0fc396..a5744a4037ecea 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -1365,7 +1365,7 @@ def test_threads_join_with_no_main(self): DONE = b'D' interp = interpreters.create() - interp.run(f"""if True: + interp.exec_sync(f"""if True: import os import threading import time diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c index 9fdd67093338e4..b3ddfae58e6fc0 100644 --- a/Modules/_testcapimodule.c +++ b/Modules/_testcapimodule.c @@ -13,6 +13,7 @@ #include "_testcapi/parts.h" #include "frameobject.h" // PyFrame_New() +#include "interpreteridobject.h" // PyInterpreterID_Type #include "marshal.h" // PyMarshal_WriteLongToFile() #include // FLT_MAX @@ -1451,6 +1452,36 @@ run_in_subinterp(PyObject *self, PyObject *args) return PyLong_FromLong(r); } +static PyObject * +get_interpreterid_type(PyObject *self, PyObject *Py_UNUSED(ignored)) +{ + return Py_NewRef(&PyInterpreterID_Type); +} + +static PyObject * +link_interpreter_refcount(PyObject *self, PyObject *idobj) +{ + PyInterpreterState *interp = PyInterpreterID_LookUp(idobj); + if (interp == NULL) { + assert(PyErr_Occurred()); + return NULL; + } + _PyInterpreterState_RequireIDRef(interp, 1); + Py_RETURN_NONE; +} + +static PyObject * +unlink_interpreter_refcount(PyObject *self, PyObject *idobj) +{ + PyInterpreterState *interp = PyInterpreterID_LookUp(idobj); + if (interp == NULL) { + assert(PyErr_Occurred()); + return NULL; + } + _PyInterpreterState_RequireIDRef(interp, 0); + Py_RETURN_NONE; +} + static PyMethodDef ml; static PyObject * @@ -3237,6 +3268,9 @@ static PyMethodDef TestMethods[] = { {"crash_no_current_thread", crash_no_current_thread, METH_NOARGS}, {"test_current_tstate_matches", test_current_tstate_matches, METH_NOARGS}, {"run_in_subinterp", run_in_subinterp, METH_VARARGS}, + {"get_interpreterid_type", get_interpreterid_type, METH_NOARGS}, + {"link_interpreter_refcount", link_interpreter_refcount, METH_O}, + {"unlink_interpreter_refcount", unlink_interpreter_refcount, METH_O}, {"create_cfunction", create_cfunction, METH_NOARGS}, {"call_in_temporary_c_thread", call_in_temporary_c_thread, METH_VARARGS, PyDoc_STR("set_error_class(error_class) -> None")}, diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c index ba7653f2d9c7aa..7d277df164d3ec 100644 --- a/Modules/_testinternalcapi.c +++ b/Modules/_testinternalcapi.c @@ -1475,6 +1475,17 @@ run_in_subinterp_with_config(PyObject *self, PyObject *args, PyObject *kwargs) } +static PyObject * +get_interpreter_refcount(PyObject *self, PyObject *idobj) +{ + PyInterpreterState *interp = PyInterpreterID_LookUp(idobj); + if (interp == NULL) { + return NULL; + } + return PyLong_FromLongLong(interp->id_refcount); +} + + static void _xid_capsule_destructor(PyObject *capsule) { @@ -1693,6 +1704,7 @@ static PyMethodDef module_functions[] = { {"run_in_subinterp_with_config", _PyCFunction_CAST(run_in_subinterp_with_config), METH_VARARGS | METH_KEYWORDS}, + {"get_interpreter_refcount", get_interpreter_refcount, METH_O}, {"compile_perf_trampoline_entry", compile_perf_trampoline_entry, METH_VARARGS}, {"perf_trampoline_set_persist_after_fork", perf_trampoline_set_persist_after_fork, METH_VARARGS}, {"get_crossinterp_data", get_crossinterp_data, METH_VARARGS}, diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c index 1c9ae3b87adf7c..97729ec269cb62 100644 --- a/Modules/_xxinterpchannelsmodule.c +++ b/Modules/_xxinterpchannelsmodule.c @@ -8,7 +8,6 @@ #include "Python.h" #include "interpreteridobject.h" #include "pycore_crossinterp.h" // struct _xid -#include "pycore_pybuffer.h" // _PyBuffer_ReleaseInInterpreterAndRawFree() #include "pycore_interp.h" // _PyInterpreterState_LookUpID() #ifdef MS_WINDOWS @@ -263,136 +262,6 @@ wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout) } -/* Cross-interpreter Buffer Views *******************************************/ - -// XXX Release when the original interpreter is destroyed. - -typedef struct { - PyObject_HEAD - Py_buffer *view; - int64_t interpid; -} XIBufferViewObject; - -static PyObject * -xibufferview_from_xid(PyTypeObject *cls, _PyCrossInterpreterData *data) -{ - assert(data->data != NULL); - assert(data->obj == NULL); - assert(data->interpid >= 0); - XIBufferViewObject *self = PyObject_Malloc(sizeof(XIBufferViewObject)); - if (self == NULL) { - return NULL; - } - PyObject_Init((PyObject *)self, cls); - self->view = (Py_buffer *)data->data; - self->interpid = data->interpid; - return (PyObject *)self; -} - -static void -xibufferview_dealloc(XIBufferViewObject *self) -{ - PyInterpreterState *interp = _PyInterpreterState_LookUpID(self->interpid); - /* If the interpreter is no longer alive then we have problems, - since other objects may be using the buffer still. */ - assert(interp != NULL); - - if (_PyBuffer_ReleaseInInterpreterAndRawFree(interp, self->view) < 0) { - // XXX Emit a warning? - PyErr_Clear(); - } - - PyTypeObject *tp = Py_TYPE(self); - tp->tp_free(self); - /* "Instances of heap-allocated types hold a reference to their type." - * See: https://docs.python.org/3.11/howto/isolating-extensions.html#garbage-collection-protocol - * See: https://docs.python.org/3.11/c-api/typeobj.html#c.PyTypeObject.tp_traverse - */ - // XXX Why don't we implement Py_TPFLAGS_HAVE_GC, e.g. Py_tp_traverse, - // like we do for _abc._abc_data? - Py_DECREF(tp); -} - -static int -xibufferview_getbuf(XIBufferViewObject *self, Py_buffer *view, int flags) -{ - /* Only PyMemoryView_FromObject() should ever call this, - via _memoryview_from_xid() below. */ - *view = *self->view; - view->obj = (PyObject *)self; - // XXX Should we leave it alone? - view->internal = NULL; - return 0; -} - -static PyType_Slot XIBufferViewType_slots[] = { - {Py_tp_dealloc, (destructor)xibufferview_dealloc}, - {Py_bf_getbuffer, (getbufferproc)xibufferview_getbuf}, - // We don't bother with Py_bf_releasebuffer since we don't need it. - {0, NULL}, -}; - -static PyType_Spec XIBufferViewType_spec = { - .name = MODULE_NAME ".CrossInterpreterBufferView", - .basicsize = sizeof(XIBufferViewObject), - .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | - Py_TPFLAGS_DISALLOW_INSTANTIATION | Py_TPFLAGS_IMMUTABLETYPE), - .slots = XIBufferViewType_slots, -}; - - -/* extra XID types **********************************************************/ - -static PyTypeObject * _get_current_xibufferview_type(void); - -static PyObject * -_memoryview_from_xid(_PyCrossInterpreterData *data) -{ - PyTypeObject *cls = _get_current_xibufferview_type(); - if (cls == NULL) { - return NULL; - } - PyObject *obj = xibufferview_from_xid(cls, data); - if (obj == NULL) { - return NULL; - } - return PyMemoryView_FromObject(obj); -} - -static int -_memoryview_shared(PyThreadState *tstate, PyObject *obj, - _PyCrossInterpreterData *data) -{ - Py_buffer *view = PyMem_RawMalloc(sizeof(Py_buffer)); - if (view == NULL) { - return -1; - } - if (PyObject_GetBuffer(obj, view, PyBUF_FULL_RO) < 0) { - PyMem_RawFree(view); - return -1; - } - _PyCrossInterpreterData_Init(data, tstate->interp, view, NULL, - _memoryview_from_xid); - return 0; -} - -static int -register_builtin_xid_types(struct xid_class_registry *classes) -{ - PyTypeObject *cls; - crossinterpdatafunc func; - - // builtin memoryview - cls = &PyMemoryView_Type; - func = _memoryview_shared; - if (register_xid_class(cls, func, classes)) { - return -1; - } - - return 0; -} - - /* module state *************************************************************/ typedef struct { @@ -405,7 +274,6 @@ typedef struct { /* heap types */ PyTypeObject *ChannelInfoType; PyTypeObject *ChannelIDType; - PyTypeObject *XIBufferViewType; /* exceptions */ PyObject *ChannelError; @@ -449,7 +317,6 @@ traverse_module_state(module_state *state, visitproc visit, void *arg) /* heap types */ Py_VISIT(state->ChannelInfoType); Py_VISIT(state->ChannelIDType); - Py_VISIT(state->XIBufferViewType); /* exceptions */ Py_VISIT(state->ChannelError); @@ -474,7 +341,6 @@ clear_module_state(module_state *state) (void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType); } Py_CLEAR(state->ChannelIDType); - Py_CLEAR(state->XIBufferViewType); /* exceptions */ Py_CLEAR(state->ChannelError); @@ -487,17 +353,6 @@ clear_module_state(module_state *state) } -static PyTypeObject * -_get_current_xibufferview_type(void) -{ - module_state *state = _get_current_module_state(); - if (state == NULL) { - return NULL; - } - return state->XIBufferViewType; -} - - /* channel-specific code ****************************************************/ #define CHANNEL_SEND 1 @@ -3463,18 +3318,6 @@ module_exec(PyObject *mod) goto error; } - // XIBufferView - state->XIBufferViewType = add_new_type(mod, &XIBufferViewType_spec, NULL, - xid_classes); - if (state->XIBufferViewType == NULL) { - goto error; - } - - // Register external types. - if (register_builtin_xid_types(xid_classes) < 0) { - goto error; - } - /* Make sure chnnels drop objects owned by this interpreter. */ PyInterpreterState *interp = _get_current_interp(); PyUnstable_AtExit(interp, clear_interpreter, (void *)interp); diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index 02c2abed27ddfa..37959e953ee4f5 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -6,11 +6,14 @@ #endif #include "Python.h" +#include "pycore_abstract.h" // _PyIndex_Check() #include "pycore_crossinterp.h" // struct _xid -#include "pycore_pyerrors.h" // _Py_excinfo +#include "pycore_interp.h" // _PyInterpreterState_IDIncref() #include "pycore_initconfig.h" // _PyErr_SetFromPyStatus() +#include "pycore_long.h" // _PyLong_IsNegative() #include "pycore_modsupport.h" // _PyArg_BadArgument() -#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1() +#include "pycore_pybuffer.h" // _PyBuffer_ReleaseInInterpreterAndRawFree() +#include "pycore_pyerrors.h" // _Py_excinfo #include "pycore_pystate.h" // _PyInterpreterState_SetRunningMain() #include "interpreteridobject.h" @@ -28,11 +31,260 @@ _get_current_interp(void) return PyInterpreterState_Get(); } +static int64_t +pylong_to_interpid(PyObject *idobj) +{ + assert(PyLong_CheckExact(idobj)); + + if (_PyLong_IsNegative((PyLongObject *)idobj)) { + PyErr_Format(PyExc_ValueError, + "interpreter ID must be a non-negative int, got %R", + idobj); + return -1; + } + + int overflow; + long long id = PyLong_AsLongLongAndOverflow(idobj, &overflow); + if (id == -1) { + if (!overflow) { + assert(PyErr_Occurred()); + return -1; + } + assert(!PyErr_Occurred()); + // For now, we don't worry about if LLONG_MAX < INT64_MAX. + goto bad_id; + } +#if LLONG_MAX > INT64_MAX + if (id > INT64_MAX) { + goto bad_id; + } +#endif + return (int64_t)id; + +bad_id: + PyErr_Format(PyExc_RuntimeError, + "unrecognized interpreter ID %O", idobj); + return -1; +} + +static int64_t +convert_interpid_obj(PyObject *arg) +{ + int64_t id = -1; + if (_PyIndex_Check(arg)) { + PyObject *idobj = PyNumber_Long(arg); + if (idobj == NULL) { + return -1; + } + id = pylong_to_interpid(idobj); + Py_DECREF(idobj); + if (id < 0) { + return -1; + } + } + else { + PyErr_Format(PyExc_TypeError, + "interpreter ID must be an int, got %.100s", + Py_TYPE(arg)->tp_name); + return -1; + } + return id; +} + +static PyInterpreterState * +look_up_interp(PyObject *arg) +{ + int64_t id = convert_interpid_obj(arg); + if (id < 0) { + return NULL; + } + return _PyInterpreterState_LookUpID(id); +} + + +static PyObject * +interpid_to_pylong(int64_t id) +{ + assert(id < LLONG_MAX); + return PyLong_FromLongLong(id); +} + +static PyObject * +get_interpid_obj(PyInterpreterState *interp) +{ + if (_PyInterpreterState_IDInitref(interp) != 0) { + return NULL; + }; + int64_t id = PyInterpreterState_GetID(interp); + if (id < 0) { + return NULL; + } + return interpid_to_pylong(id); +} + +static PyObject * +_get_current_module(void) +{ + PyObject *name = PyUnicode_FromString(MODULE_NAME); + if (name == NULL) { + return NULL; + } + PyObject *mod = PyImport_GetModule(name); + Py_DECREF(name); + if (mod == NULL) { + return NULL; + } + assert(mod != Py_None); + return mod; +} + + +/* Cross-interpreter Buffer Views *******************************************/ + +// XXX Release when the original interpreter is destroyed. + +typedef struct { + PyObject_HEAD + Py_buffer *view; + int64_t interpid; +} XIBufferViewObject; + +static PyObject * +xibufferview_from_xid(PyTypeObject *cls, _PyCrossInterpreterData *data) +{ + assert(data->data != NULL); + assert(data->obj == NULL); + assert(data->interpid >= 0); + XIBufferViewObject *self = PyObject_Malloc(sizeof(XIBufferViewObject)); + if (self == NULL) { + return NULL; + } + PyObject_Init((PyObject *)self, cls); + self->view = (Py_buffer *)data->data; + self->interpid = data->interpid; + return (PyObject *)self; +} + +static void +xibufferview_dealloc(XIBufferViewObject *self) +{ + PyInterpreterState *interp = _PyInterpreterState_LookUpID(self->interpid); + /* If the interpreter is no longer alive then we have problems, + since other objects may be using the buffer still. */ + assert(interp != NULL); + + if (_PyBuffer_ReleaseInInterpreterAndRawFree(interp, self->view) < 0) { + // XXX Emit a warning? + PyErr_Clear(); + } + + PyTypeObject *tp = Py_TYPE(self); + tp->tp_free(self); + /* "Instances of heap-allocated types hold a reference to their type." + * See: https://docs.python.org/3.11/howto/isolating-extensions.html#garbage-collection-protocol + * See: https://docs.python.org/3.11/c-api/typeobj.html#c.PyTypeObject.tp_traverse + */ + // XXX Why don't we implement Py_TPFLAGS_HAVE_GC, e.g. Py_tp_traverse, + // like we do for _abc._abc_data? + Py_DECREF(tp); +} + +static int +xibufferview_getbuf(XIBufferViewObject *self, Py_buffer *view, int flags) +{ + /* Only PyMemoryView_FromObject() should ever call this, + via _memoryview_from_xid() below. */ + *view = *self->view; + view->obj = (PyObject *)self; + // XXX Should we leave it alone? + view->internal = NULL; + return 0; +} + +static PyType_Slot XIBufferViewType_slots[] = { + {Py_tp_dealloc, (destructor)xibufferview_dealloc}, + {Py_bf_getbuffer, (getbufferproc)xibufferview_getbuf}, + // We don't bother with Py_bf_releasebuffer since we don't need it. + {0, NULL}, +}; + +static PyType_Spec XIBufferViewType_spec = { + .name = MODULE_NAME ".CrossInterpreterBufferView", + .basicsize = sizeof(XIBufferViewObject), + .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | + Py_TPFLAGS_DISALLOW_INSTANTIATION | Py_TPFLAGS_IMMUTABLETYPE), + .slots = XIBufferViewType_slots, +}; + + +static PyTypeObject * _get_current_xibufferview_type(void); + +static PyObject * +_memoryview_from_xid(_PyCrossInterpreterData *data) +{ + PyTypeObject *cls = _get_current_xibufferview_type(); + if (cls == NULL) { + return NULL; + } + PyObject *obj = xibufferview_from_xid(cls, data); + if (obj == NULL) { + return NULL; + } + return PyMemoryView_FromObject(obj); +} + +static int +_memoryview_shared(PyThreadState *tstate, PyObject *obj, + _PyCrossInterpreterData *data) +{ + Py_buffer *view = PyMem_RawMalloc(sizeof(Py_buffer)); + if (view == NULL) { + return -1; + } + if (PyObject_GetBuffer(obj, view, PyBUF_FULL_RO) < 0) { + PyMem_RawFree(view); + return -1; + } + _PyCrossInterpreterData_Init(data, tstate->interp, view, NULL, + _memoryview_from_xid); + return 0; +} + +static int +register_memoryview_xid(PyObject *mod, PyTypeObject **p_state) +{ + // XIBufferView + assert(*p_state == NULL); + PyTypeObject *cls = (PyTypeObject *)PyType_FromModuleAndSpec( + mod, &XIBufferViewType_spec, NULL); + if (cls == NULL) { + return -1; + } + if (PyModule_AddType(mod, cls) < 0) { + Py_DECREF(cls); + return -1; + } + *p_state = cls; + + // Register XID for the builtin memoryview type. + if (_PyCrossInterpreterData_RegisterClass( + &PyMemoryView_Type, _memoryview_shared) < 0) { + return -1; + } + // We don't ever bother un-registering memoryview. + + return 0; +} + + /* module state *************************************************************/ typedef struct { int _notused; + + /* heap types */ + PyTypeObject *XIBufferViewType; } module_state; static inline module_state * @@ -44,19 +296,51 @@ get_module_state(PyObject *mod) return state; } +static module_state * +_get_current_module_state(void) +{ + PyObject *mod = _get_current_module(); + if (mod == NULL) { + // XXX import it? + PyErr_SetString(PyExc_RuntimeError, + MODULE_NAME " module not imported yet"); + return NULL; + } + module_state *state = get_module_state(mod); + Py_DECREF(mod); + return state; +} + static int traverse_module_state(module_state *state, visitproc visit, void *arg) { + /* heap types */ + Py_VISIT(state->XIBufferViewType); + return 0; } static int clear_module_state(module_state *state) { + /* heap types */ + Py_CLEAR(state->XIBufferViewType); + return 0; } +static PyTypeObject * +_get_current_xibufferview_type(void) +{ + module_state *state = _get_current_module_state(); + if (state == NULL) { + return NULL; + } + return state->XIBufferViewType; +} + + /* Python code **************************************************************/ static const char * @@ -254,7 +538,7 @@ interp_create(PyObject *self, PyObject *args, PyObject *kwds) assert(tstate != NULL); PyInterpreterState *interp = PyThreadState_GetInterpreter(tstate); - PyObject *idobj = PyInterpreterState_GetIDObject(interp); + PyObject *idobj = get_interpid_obj(interp); if (idobj == NULL) { // XXX Possible GILState issues? save_tstate = PyThreadState_Swap(tstate); @@ -273,7 +557,9 @@ interp_create(PyObject *self, PyObject *args, PyObject *kwds) PyDoc_STRVAR(create_doc, "create() -> ID\n\ \n\ -Create a new interpreter and return a unique generated ID."); +Create a new interpreter and return a unique generated ID.\n\ +\n\ +The caller is responsible for destroying the interpreter before exiting."); static PyObject * @@ -288,7 +574,7 @@ interp_destroy(PyObject *self, PyObject *args, PyObject *kwds) } // Look up the interpreter. - PyInterpreterState *interp = PyInterpreterID_LookUp(id); + PyInterpreterState *interp = look_up_interp(id); if (interp == NULL) { return NULL; } @@ -345,7 +631,7 @@ interp_list_all(PyObject *self, PyObject *Py_UNUSED(ignored)) interp = PyInterpreterState_Head(); while (interp != NULL) { - id = PyInterpreterState_GetIDObject(interp); + id = get_interpid_obj(interp); if (id == NULL) { Py_DECREF(ids); return NULL; @@ -377,7 +663,7 @@ interp_get_current(PyObject *self, PyObject *Py_UNUSED(ignored)) if (interp == NULL) { return NULL; } - return PyInterpreterState_GetIDObject(interp); + return get_interpid_obj(interp); } PyDoc_STRVAR(get_current_doc, @@ -391,7 +677,7 @@ interp_get_main(PyObject *self, PyObject *Py_UNUSED(ignored)) { // Currently, 0 is always the main interpreter. int64_t id = 0; - return PyInterpreterID_New(id); + return PyLong_FromLongLong(id); } PyDoc_STRVAR(get_main_doc, @@ -481,7 +767,7 @@ _interp_exec(PyObject *self, PyObject **p_excinfo) { // Look up the interpreter. - PyInterpreterState *interp = PyInterpreterID_LookUp(id_arg); + PyInterpreterState *interp = look_up_interp(id_arg); if (interp == NULL) { return -1; } @@ -667,7 +953,7 @@ interp_is_running(PyObject *self, PyObject *args, PyObject *kwds) return NULL; } - PyInterpreterState *interp = PyInterpreterID_LookUp(id); + PyInterpreterState *interp = look_up_interp(id); if (interp == NULL) { return NULL; } @@ -683,6 +969,49 @@ PyDoc_STRVAR(is_running_doc, Return whether or not the identified interpreter is running."); +static PyObject * +interp_incref(PyObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"id", NULL}; + PyObject *id; + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "O:_incref", kwlist, &id)) { + return NULL; + } + + PyInterpreterState *interp = look_up_interp(id); + if (interp == NULL) { + return NULL; + } + if (_PyInterpreterState_IDInitref(interp) < 0) { + return NULL; + } + _PyInterpreterState_IDIncref(interp); + + Py_RETURN_NONE; +} + + +static PyObject * +interp_decref(PyObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"id", NULL}; + PyObject *id; + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "O:_incref", kwlist, &id)) { + return NULL; + } + + PyInterpreterState *interp = look_up_interp(id); + if (interp == NULL) { + return NULL; + } + _PyInterpreterState_IDDecref(interp); + + Py_RETURN_NONE; +} + + static PyMethodDef module_functions[] = { {"create", _PyCFunction_CAST(interp_create), METH_VARARGS | METH_KEYWORDS, create_doc}, @@ -707,6 +1036,11 @@ static PyMethodDef module_functions[] = { {"is_shareable", _PyCFunction_CAST(object_is_shareable), METH_VARARGS | METH_KEYWORDS, is_shareable_doc}, + {"_incref", _PyCFunction_CAST(interp_incref), + METH_VARARGS | METH_KEYWORDS, NULL}, + {"_decref", _PyCFunction_CAST(interp_decref), + METH_VARARGS | METH_KEYWORDS, NULL}, + {NULL, NULL} /* sentinel */ }; @@ -720,8 +1054,17 @@ The 'interpreters' module provides a more convenient interface."); static int module_exec(PyObject *mod) { - // PyInterpreterID - if (PyModule_AddType(mod, &PyInterpreterID_Type) < 0) { + module_state *state = get_module_state(mod); + + // exceptions + if (PyModule_AddType(mod, (PyTypeObject *)PyExc_InterpreterError) < 0) { + goto error; + } + if (PyModule_AddType(mod, (PyTypeObject *)PyExc_InterpreterNotFoundError) < 0) { + goto error; + } + + if (register_memoryview_xid(mod, &state->XIBufferViewType) < 0) { goto error; } diff --git a/Python/crossinterp.c b/Python/crossinterp.c index f74fee38648266..a31b5ef4613dbd 100644 --- a/Python/crossinterp.c +++ b/Python/crossinterp.c @@ -12,6 +12,53 @@ #include "pycore_weakref.h" // _PyWeakref_GET_REF() +/**************/ +/* exceptions */ +/**************/ + +/* InterpreterError extends Exception */ + +static PyTypeObject _PyExc_InterpreterError = { + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "InterpreterError", + .tp_doc = PyDoc_STR("An interpreter was not found."), + //.tp_base = (PyTypeObject *)PyExc_BaseException, +}; +PyObject *PyExc_InterpreterError = (PyObject *)&_PyExc_InterpreterError; + +/* InterpreterNotFoundError extends InterpreterError */ + +static PyTypeObject _PyExc_InterpreterNotFoundError = { + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "InterpreterNotFoundError", + .tp_doc = PyDoc_STR("An interpreter was not found."), + .tp_base = &_PyExc_InterpreterError, +}; +PyObject *PyExc_InterpreterNotFoundError = (PyObject *)&_PyExc_InterpreterNotFoundError; + +/* lifecycle */ + +static int +init_exceptions(PyInterpreterState *interp) +{ + _PyExc_InterpreterError.tp_base = (PyTypeObject *)PyExc_BaseException; + if (_PyStaticType_InitBuiltin(interp, &_PyExc_InterpreterError) < 0) { + return -1; + } + if (_PyStaticType_InitBuiltin(interp, &_PyExc_InterpreterNotFoundError) < 0) { + return -1; + } + return 0; +} + +static void +fini_exceptions(PyInterpreterState *interp) +{ + _PyStaticType_Dealloc(interp, &_PyExc_InterpreterNotFoundError); + _PyStaticType_Dealloc(interp, &_PyExc_InterpreterError); +} + + /***************************/ /* cross-interpreter calls */ /***************************/ @@ -2099,3 +2146,18 @@ _PyXI_Fini(PyInterpreterState *interp) _xidregistry_fini(_get_global_xidregistry(interp->runtime)); } } + +PyStatus +_PyXI_InitTypes(PyInterpreterState *interp) +{ + if (init_exceptions(interp) < 0) { + return _PyStatus_ERR("failed to initialize an exception type"); + } + return _PyStatus_OK(); +} + +void +_PyXI_FiniTypes(PyInterpreterState *interp) +{ + fini_exceptions(interp); +} diff --git a/Python/pylifecycle.c b/Python/pylifecycle.c index 45a119fcca7f2c..b5c7dc5da596de 100644 --- a/Python/pylifecycle.c +++ b/Python/pylifecycle.c @@ -734,6 +734,11 @@ pycore_init_types(PyInterpreterState *interp) return status; } + status = _PyXI_InitTypes(interp); + if (_PyStatus_EXCEPTION(status)) { + return status; + } + return _PyStatus_OK(); } @@ -1742,6 +1747,7 @@ finalize_interp_types(PyInterpreterState *interp) { _PyUnicode_FiniTypes(interp); _PySys_FiniTypes(interp); + _PyXI_FiniTypes(interp); _PyExc_Fini(interp); _PyAsyncGen_Fini(interp); _PyContext_Fini(interp); diff --git a/Python/pystate.c b/Python/pystate.c index 1a7c0c968504d1..f0c5259967d907 100644 --- a/Python/pystate.c +++ b/Python/pystate.c @@ -1216,7 +1216,7 @@ _PyInterpreterState_LookUpID(int64_t requested_id) HEAD_UNLOCK(runtime); } if (interp == NULL && !PyErr_Occurred()) { - PyErr_Format(PyExc_RuntimeError, + PyErr_Format(PyExc_InterpreterNotFoundError, "unrecognized interpreter ID %lld", requested_id); } return interp; diff --git a/Tools/c-analyzer/cpython/globals-to-fix.tsv b/Tools/c-analyzer/cpython/globals-to-fix.tsv index aa8ce49ae86376..e3a1b5d532bda2 100644 --- a/Tools/c-analyzer/cpython/globals-to-fix.tsv +++ b/Tools/c-analyzer/cpython/globals-to-fix.tsv @@ -290,6 +290,10 @@ Objects/exceptions.c - PyExc_UnicodeWarning - Objects/exceptions.c - PyExc_BytesWarning - Objects/exceptions.c - PyExc_ResourceWarning - Objects/exceptions.c - PyExc_EncodingWarning - +Python/crossinterp.c - _PyExc_InterpreterError - +Python/crossinterp.c - _PyExc_InterpreterNotFoundError - +Python/crossinterp.c - PyExc_InterpreterError - +Python/crossinterp.c - PyExc_InterpreterNotFoundError - ##----------------------- ## singletons