Skip to content

Commit

Permalink
gh-76785: Add Interpreter.prepare_main() (gh-113021)
Browse files Browse the repository at this point in the history
This is one of the last pieces to get test.support.interpreters in sync with PEP 734.
  • Loading branch information
ericsnowcurrently authored Dec 12, 2023
1 parent a49b427 commit 9898e61
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 17 deletions.
16 changes: 12 additions & 4 deletions Lib/test/support/interpreters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,15 @@ def close(self):
"""
return _interpreters.destroy(self._id)

def exec_sync(self, code, /, channels=None):
def prepare_main(self, ns=None, /, **kwargs):
"""Bind the given values into the interpreter's __main__.
The values must be shareable.
"""
ns = dict(ns, **kwargs) if ns is not None else kwargs
_interpreters.set___main___attrs(self._id, ns)

def exec_sync(self, code, /):
"""Run the given source code in the interpreter.
This is essentially the same as calling the builtin "exec"
Expand All @@ -148,13 +156,13 @@ def exec_sync(self, code, /, channels=None):
that time, the previous interpreter is allowed to run
in other threads.
"""
excinfo = _interpreters.exec(self._id, code, channels)
excinfo = _interpreters.exec(self._id, code)
if excinfo is not None:
raise ExecFailure(excinfo)

def run(self, code, /, channels=None):
def run(self, code, /):
def task():
self.exec_sync(code, channels=channels)
self.exec_sync(code)
t = threading.Thread(target=task)
t.start()
return t
4 changes: 2 additions & 2 deletions Lib/test/test__xxinterpchannels.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,12 +586,12 @@ def test_run_string_arg_unresolved(self):
cid = channels.create()
interp = interpreters.create()

interpreters.set___main___attrs(interp, dict(cid=cid.send))
out = _run_output(interp, dedent("""
import _xxinterpchannels as _channels
print(cid.end)
_channels.send(cid, b'spam', blocking=False)
"""),
dict(cid=cid.send))
"""))
obj = channels.recv(cid)

self.assertEqual(obj, b'spam')
Expand Down
24 changes: 15 additions & 9 deletions Lib/test/test__xxsubinterpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def _captured_script(script):
return wrapped, open(r, encoding="utf-8")


def _run_output(interp, request, shared=None):
def _run_output(interp, request):
script, rpipe = _captured_script(request)
with rpipe:
interpreters.run_string(interp, script, shared)
interpreters.run_string(interp, script)
return rpipe.read()


Expand Down Expand Up @@ -630,10 +630,10 @@ def test_shareable_types(self):
]
for obj in objects:
with self.subTest(obj):
interpreters.set___main___attrs(interp, dict(obj=obj))
interpreters.run_string(
interp,
f'assert(obj == {obj!r})',
shared=dict(obj=obj),
)

def test_os_exec(self):
Expand Down Expand Up @@ -721,7 +721,8 @@ def test_with_shared(self):
with open({w}, 'wb') as chan:
pickle.dump(ns, chan)
""")
interpreters.run_string(self.id, script, shared)
interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan:
ns = pickle.load(chan)

Expand All @@ -742,7 +743,8 @@ def test_shared_overwrites(self):
ns2 = dict(vars())
del ns2['__builtins__']
""")
interpreters.run_string(self.id, script, shared)
interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)

r, w = os.pipe()
script = dedent(f"""
Expand Down Expand Up @@ -773,7 +775,8 @@ def test_shared_overwrites_default_vars(self):
with open({w}, 'wb') as chan:
pickle.dump(ns, chan)
""")
interpreters.run_string(self.id, script, shared)
interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan:
ns = pickle.load(chan)

Expand Down Expand Up @@ -1036,7 +1039,8 @@ def script():
with open(w, 'w', encoding="utf-8") as spipe:
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
interpreters.run_func(self.id, script, shared=dict(w=w))
interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, script)

with open(r, encoding="utf-8") as outfile:
out = outfile.read()
Expand All @@ -1052,7 +1056,8 @@ def script():
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
def f():
interpreters.run_func(self.id, script, shared=dict(w=w))
interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, script)
t = threading.Thread(target=f)
t.start()
t.join()
Expand All @@ -1072,7 +1077,8 @@ def script():
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
code = script.__code__
interpreters.run_func(self.id, code, shared=dict(w=w))
interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, code)

with open(r, encoding="utf-8") as outfile:
out = outfile.read()
Expand Down
57 changes: 57 additions & 0 deletions Lib/test/test_interpreters/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,63 @@ def task():
self.assertEqual(os.read(r_interp, 1), FINISHED)


class TestInterpreterPrepareMain(TestBase):

def test_empty(self):
interp = interpreters.create()
with self.assertRaises(ValueError):
interp.prepare_main()

def test_dict(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.prepare_main(values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')

def test_tuple(self):
values = {'spam': 42, 'eggs': 'ham'}
values = tuple(values.items())
interp = interpreters.create()
interp.prepare_main(values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')

def test_kwargs(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.prepare_main(**values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')

def test_dict_and_kwargs(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.prepare_main(values, foo='bar')
out = _run_output(interp, dedent("""
print(spam, eggs, foo)
"""))
self.assertEqual(out.strip(), '42 ham bar')

def test_not_shareable(self):
interp = interpreters.create()
# XXX TypeError?
with self.assertRaises(ValueError):
interp.prepare_main(spam={'spam': 'eggs', 'foo': 'bar'})

# Make sure neither was actually bound.
with self.assertRaises(interpreters.ExecFailure):
interp.exec_sync('print(foo)')
with self.assertRaises(interpreters.ExecFailure):
interp.exec_sync('print(spam)')


class TestInterpreterExecSync(TestBase):

def test_success(self):
Expand Down
6 changes: 4 additions & 2 deletions Lib/test/test_interpreters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ def clean_up_interpreters():
pass # already destroyed


def _run_output(interp, request, channels=None):
def _run_output(interp, request, init=None):
script, rpipe = _captured_script(request)
with rpipe:
interp.exec_sync(script, channels=channels)
if init:
interp.prepare_main(init)
interp.exec_sync(script)
return rpipe.read()


Expand Down
56 changes: 56 additions & 0 deletions Modules/_xxsubinterpretersmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,60 @@ PyDoc_STRVAR(get_main_doc,
\n\
Return the ID of main interpreter.");

static PyObject *
interp_set___main___attrs(PyObject *self, PyObject *args)
{
PyObject *id, *updates;
if (!PyArg_ParseTuple(args, "OO:" MODULE_NAME ".set___main___attrs",
&id, &updates))
{
return NULL;
}

// Look up the interpreter.
PyInterpreterState *interp = PyInterpreterID_LookUp(id);
if (interp == NULL) {
return NULL;
}

// Check the updates.
if (updates != Py_None) {
Py_ssize_t size = PyObject_Size(updates);
if (size < 0) {
return NULL;
}
if (size == 0) {
PyErr_SetString(PyExc_ValueError,
"arg 2 must be a non-empty mapping");
return NULL;
}
}

_PyXI_session session = {0};

// Prep and switch interpreters, including apply the updates.
if (_PyXI_Enter(&session, interp, updates) < 0) {
if (!PyErr_Occurred()) {
_PyXI_ApplyCapturedException(&session);
assert(PyErr_Occurred());
}
else {
assert(!_PyXI_HasCapturedException(&session));
}
return NULL;
}

// Clean up and switch back.
_PyXI_Exit(&session);

Py_RETURN_NONE;
}

PyDoc_STRVAR(set___main___attrs_doc,
"set___main___attrs(id, ns)\n\
\n\
Bind the given attributes in the interpreter's __main__ module.");

static PyUnicodeObject *
convert_script_arg(PyObject *arg, const char *fname, const char *displayname,
const char *expected)
Expand Down Expand Up @@ -1033,6 +1087,8 @@ static PyMethodDef module_functions[] = {
{"run_func", _PyCFunction_CAST(interp_run_func),
METH_VARARGS | METH_KEYWORDS, run_func_doc},

{"set___main___attrs", _PyCFunction_CAST(interp_set___main___attrs),
METH_VARARGS, set___main___attrs_doc},
{"is_shareable", _PyCFunction_CAST(object_is_shareable),
METH_VARARGS | METH_KEYWORDS, is_shareable_doc},

Expand Down

0 comments on commit 9898e61

Please sign in to comment.