From 9898e6104171dcdd88b32776e69ca2cddf515e63 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 12 Dec 2023 11:06:06 -0700 Subject: [PATCH] gh-76785: Add Interpreter.prepare_main() (gh-113021) This is one of the last pieces to get test.support.interpreters in sync with PEP 734. --- Lib/test/support/interpreters/__init__.py | 16 +++++-- Lib/test/test__xxinterpchannels.py | 4 +- Lib/test/test__xxsubinterpreters.py | 24 ++++++---- Lib/test/test_interpreters/test_api.py | 57 +++++++++++++++++++++++ Lib/test/test_interpreters/utils.py | 6 ++- Modules/_xxsubinterpretersmodule.c | 56 ++++++++++++++++++++++ 6 files changed, 146 insertions(+), 17 deletions(-) diff --git a/Lib/test/support/interpreters/__init__.py b/Lib/test/support/interpreters/__init__.py index 2d6376deb5907e..9cd1c3de0274d2 100644 --- a/Lib/test/support/interpreters/__init__.py +++ b/Lib/test/support/interpreters/__init__.py @@ -130,7 +130,15 @@ def close(self): """ return _interpreters.destroy(self._id) - def exec_sync(self, code, /, channels=None): + def prepare_main(self, ns=None, /, **kwargs): + """Bind the given values into the interpreter's __main__. + + The values must be shareable. + """ + ns = dict(ns, **kwargs) if ns is not None else kwargs + _interpreters.set___main___attrs(self._id, ns) + + def exec_sync(self, code, /): """Run the given source code in the interpreter. This is essentially the same as calling the builtin "exec" @@ -148,13 +156,13 @@ def exec_sync(self, code, /, channels=None): that time, the previous interpreter is allowed to run in other threads. """ - excinfo = _interpreters.exec(self._id, code, channels) + excinfo = _interpreters.exec(self._id, code) if excinfo is not None: raise ExecFailure(excinfo) - def run(self, code, /, channels=None): + def run(self, code, /): def task(): - self.exec_sync(code, channels=channels) + self.exec_sync(code) t = threading.Thread(target=task) t.start() return t diff --git a/Lib/test/test__xxinterpchannels.py b/Lib/test/test__xxinterpchannels.py index 13c8a10296e502..cc2ed7849b0c0f 100644 --- a/Lib/test/test__xxinterpchannels.py +++ b/Lib/test/test__xxinterpchannels.py @@ -586,12 +586,12 @@ def test_run_string_arg_unresolved(self): cid = channels.create() interp = interpreters.create() + interpreters.set___main___attrs(interp, dict(cid=cid.send)) out = _run_output(interp, dedent(""" import _xxinterpchannels as _channels print(cid.end) _channels.send(cid, b'spam', blocking=False) - """), - dict(cid=cid.send)) + """)) obj = channels.recv(cid) self.assertEqual(obj, b'spam') diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index 260ab64b07cb2d..a76e4d0ade5b8a 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -33,10 +33,10 @@ def _captured_script(script): return wrapped, open(r, encoding="utf-8") -def _run_output(interp, request, shared=None): +def _run_output(interp, request): script, rpipe = _captured_script(request) with rpipe: - interpreters.run_string(interp, script, shared) + interpreters.run_string(interp, script) return rpipe.read() @@ -630,10 +630,10 @@ def test_shareable_types(self): ] for obj in objects: with self.subTest(obj): + interpreters.set___main___attrs(interp, dict(obj=obj)) interpreters.run_string( interp, f'assert(obj == {obj!r})', - shared=dict(obj=obj), ) def test_os_exec(self): @@ -721,7 +721,8 @@ def test_with_shared(self): with open({w}, 'wb') as chan: pickle.dump(ns, chan) """) - interpreters.run_string(self.id, script, shared) + interpreters.set___main___attrs(self.id, shared) + interpreters.run_string(self.id, script) with open(r, 'rb') as chan: ns = pickle.load(chan) @@ -742,7 +743,8 @@ def test_shared_overwrites(self): ns2 = dict(vars()) del ns2['__builtins__'] """) - interpreters.run_string(self.id, script, shared) + interpreters.set___main___attrs(self.id, shared) + interpreters.run_string(self.id, script) r, w = os.pipe() script = dedent(f""" @@ -773,7 +775,8 @@ def test_shared_overwrites_default_vars(self): with open({w}, 'wb') as chan: pickle.dump(ns, chan) """) - interpreters.run_string(self.id, script, shared) + interpreters.set___main___attrs(self.id, shared) + interpreters.run_string(self.id, script) with open(r, 'rb') as chan: ns = pickle.load(chan) @@ -1036,7 +1039,8 @@ def script(): with open(w, 'w', encoding="utf-8") as spipe: with contextlib.redirect_stdout(spipe): print('it worked!', end='') - interpreters.run_func(self.id, script, shared=dict(w=w)) + interpreters.set___main___attrs(self.id, dict(w=w)) + interpreters.run_func(self.id, script) with open(r, encoding="utf-8") as outfile: out = outfile.read() @@ -1052,7 +1056,8 @@ def script(): with contextlib.redirect_stdout(spipe): print('it worked!', end='') def f(): - interpreters.run_func(self.id, script, shared=dict(w=w)) + interpreters.set___main___attrs(self.id, dict(w=w)) + interpreters.run_func(self.id, script) t = threading.Thread(target=f) t.start() t.join() @@ -1072,7 +1077,8 @@ def script(): with contextlib.redirect_stdout(spipe): print('it worked!', end='') code = script.__code__ - interpreters.run_func(self.id, code, shared=dict(w=w)) + interpreters.set___main___attrs(self.id, dict(w=w)) + interpreters.run_func(self.id, code) with open(r, encoding="utf-8") as outfile: out = outfile.read() diff --git a/Lib/test/test_interpreters/test_api.py b/Lib/test/test_interpreters/test_api.py index e4ae9d005b5282..b702338c3de1ad 100644 --- a/Lib/test/test_interpreters/test_api.py +++ b/Lib/test/test_interpreters/test_api.py @@ -452,6 +452,63 @@ def task(): self.assertEqual(os.read(r_interp, 1), FINISHED) +class TestInterpreterPrepareMain(TestBase): + + def test_empty(self): + interp = interpreters.create() + with self.assertRaises(ValueError): + interp.prepare_main() + + def test_dict(self): + values = {'spam': 42, 'eggs': 'ham'} + interp = interpreters.create() + interp.prepare_main(values) + out = _run_output(interp, dedent(""" + print(spam, eggs) + """)) + self.assertEqual(out.strip(), '42 ham') + + def test_tuple(self): + values = {'spam': 42, 'eggs': 'ham'} + values = tuple(values.items()) + interp = interpreters.create() + interp.prepare_main(values) + out = _run_output(interp, dedent(""" + print(spam, eggs) + """)) + self.assertEqual(out.strip(), '42 ham') + + def test_kwargs(self): + values = {'spam': 42, 'eggs': 'ham'} + interp = interpreters.create() + interp.prepare_main(**values) + out = _run_output(interp, dedent(""" + print(spam, eggs) + """)) + self.assertEqual(out.strip(), '42 ham') + + def test_dict_and_kwargs(self): + values = {'spam': 42, 'eggs': 'ham'} + interp = interpreters.create() + interp.prepare_main(values, foo='bar') + out = _run_output(interp, dedent(""" + print(spam, eggs, foo) + """)) + self.assertEqual(out.strip(), '42 ham bar') + + def test_not_shareable(self): + interp = interpreters.create() + # XXX TypeError? + with self.assertRaises(ValueError): + interp.prepare_main(spam={'spam': 'eggs', 'foo': 'bar'}) + + # Make sure neither was actually bound. + with self.assertRaises(interpreters.ExecFailure): + interp.exec_sync('print(foo)') + with self.assertRaises(interpreters.ExecFailure): + interp.exec_sync('print(spam)') + + class TestInterpreterExecSync(TestBase): def test_success(self): diff --git a/Lib/test/test_interpreters/utils.py b/Lib/test/test_interpreters/utils.py index 623c8737b79831..11b6f126dff0f4 100644 --- a/Lib/test/test_interpreters/utils.py +++ b/Lib/test/test_interpreters/utils.py @@ -29,10 +29,12 @@ def clean_up_interpreters(): pass # already destroyed -def _run_output(interp, request, channels=None): +def _run_output(interp, request, init=None): script, rpipe = _captured_script(request) with rpipe: - interp.exec_sync(script, channels=channels) + if init: + interp.prepare_main(init) + interp.exec_sync(script) return rpipe.read() diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index 37959e953ee4f5..4bb54c93b0a61b 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -685,6 +685,60 @@ PyDoc_STRVAR(get_main_doc, \n\ Return the ID of main interpreter."); +static PyObject * +interp_set___main___attrs(PyObject *self, PyObject *args) +{ + PyObject *id, *updates; + if (!PyArg_ParseTuple(args, "OO:" MODULE_NAME ".set___main___attrs", + &id, &updates)) + { + return NULL; + } + + // Look up the interpreter. + PyInterpreterState *interp = PyInterpreterID_LookUp(id); + if (interp == NULL) { + return NULL; + } + + // Check the updates. + if (updates != Py_None) { + Py_ssize_t size = PyObject_Size(updates); + if (size < 0) { + return NULL; + } + if (size == 0) { + PyErr_SetString(PyExc_ValueError, + "arg 2 must be a non-empty mapping"); + return NULL; + } + } + + _PyXI_session session = {0}; + + // Prep and switch interpreters, including apply the updates. + if (_PyXI_Enter(&session, interp, updates) < 0) { + if (!PyErr_Occurred()) { + _PyXI_ApplyCapturedException(&session); + assert(PyErr_Occurred()); + } + else { + assert(!_PyXI_HasCapturedException(&session)); + } + return NULL; + } + + // Clean up and switch back. + _PyXI_Exit(&session); + + Py_RETURN_NONE; +} + +PyDoc_STRVAR(set___main___attrs_doc, +"set___main___attrs(id, ns)\n\ +\n\ +Bind the given attributes in the interpreter's __main__ module."); + static PyUnicodeObject * convert_script_arg(PyObject *arg, const char *fname, const char *displayname, const char *expected) @@ -1033,6 +1087,8 @@ static PyMethodDef module_functions[] = { {"run_func", _PyCFunction_CAST(interp_run_func), METH_VARARGS | METH_KEYWORDS, run_func_doc}, + {"set___main___attrs", _PyCFunction_CAST(interp_set___main___attrs), + METH_VARARGS, set___main___attrs_doc}, {"is_shareable", _PyCFunction_CAST(object_is_shareable), METH_VARARGS | METH_KEYWORDS, is_shareable_doc},