Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-76785: Add Interpreter.prepare_main() #113021

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions Lib/test/support/interpreters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,15 @@ def close(self):
"""
return _interpreters.destroy(self._id)

def exec_sync(self, code, /, channels=None):
def prepare_main(self, ns=None, /, **kwargs):
"""Bind the given values into the interpreter's __main__.

The values must be shareable.
"""
ns = dict(ns, **kwargs) if ns is not None else kwargs
_interpreters.set___main___attrs(self._id, ns)

def exec_sync(self, code, /):
"""Run the given source code in the interpreter.

This is essentially the same as calling the builtin "exec"
Expand All @@ -148,13 +156,13 @@ def exec_sync(self, code, /, channels=None):
that time, the previous interpreter is allowed to run
in other threads.
"""
excinfo = _interpreters.exec(self._id, code, channels)
excinfo = _interpreters.exec(self._id, code)
if excinfo is not None:
raise ExecFailure(excinfo)

def run(self, code, /, channels=None):
def run(self, code, /):
def task():
self.exec_sync(code, channels=channels)
self.exec_sync(code)
t = threading.Thread(target=task)
t.start()
return t
4 changes: 2 additions & 2 deletions Lib/test/test__xxinterpchannels.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,12 +586,12 @@ def test_run_string_arg_unresolved(self):
cid = channels.create()
interp = interpreters.create()

interpreters.set___main___attrs(interp, dict(cid=cid.send))
out = _run_output(interp, dedent("""
import _xxinterpchannels as _channels
print(cid.end)
_channels.send(cid, b'spam', blocking=False)
"""),
dict(cid=cid.send))
"""))
obj = channels.recv(cid)

self.assertEqual(obj, b'spam')
Expand Down
24 changes: 15 additions & 9 deletions Lib/test/test__xxsubinterpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def _captured_script(script):
return wrapped, open(r, encoding="utf-8")


def _run_output(interp, request, shared=None):
def _run_output(interp, request):
script, rpipe = _captured_script(request)
with rpipe:
interpreters.run_string(interp, script, shared)
interpreters.run_string(interp, script)
return rpipe.read()


Expand Down Expand Up @@ -630,10 +630,10 @@ def test_shareable_types(self):
]
for obj in objects:
with self.subTest(obj):
interpreters.set___main___attrs(interp, dict(obj=obj))
interpreters.run_string(
interp,
f'assert(obj == {obj!r})',
shared=dict(obj=obj),
)

def test_os_exec(self):
Expand Down Expand Up @@ -721,7 +721,8 @@ def test_with_shared(self):
with open({w}, 'wb') as chan:
pickle.dump(ns, chan)
""")
interpreters.run_string(self.id, script, shared)
interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan:
ns = pickle.load(chan)

Expand All @@ -742,7 +743,8 @@ def test_shared_overwrites(self):
ns2 = dict(vars())
del ns2['__builtins__']
""")
interpreters.run_string(self.id, script, shared)
interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)

r, w = os.pipe()
script = dedent(f"""
Expand Down Expand Up @@ -773,7 +775,8 @@ def test_shared_overwrites_default_vars(self):
with open({w}, 'wb') as chan:
pickle.dump(ns, chan)
""")
interpreters.run_string(self.id, script, shared)
interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan:
ns = pickle.load(chan)

Expand Down Expand Up @@ -1036,7 +1039,8 @@ def script():
with open(w, 'w', encoding="utf-8") as spipe:
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
interpreters.run_func(self.id, script, shared=dict(w=w))
interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, script)

with open(r, encoding="utf-8") as outfile:
out = outfile.read()
Expand All @@ -1052,7 +1056,8 @@ def script():
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
def f():
interpreters.run_func(self.id, script, shared=dict(w=w))
interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, script)
t = threading.Thread(target=f)
t.start()
t.join()
Expand All @@ -1072,7 +1077,8 @@ def script():
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
code = script.__code__
interpreters.run_func(self.id, code, shared=dict(w=w))
interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, code)

with open(r, encoding="utf-8") as outfile:
out = outfile.read()
Expand Down
57 changes: 57 additions & 0 deletions Lib/test/test_interpreters/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,63 @@ def task():
self.assertEqual(os.read(r_interp, 1), FINISHED)


class TestInterpreterPrepareMain(TestBase):

def test_empty(self):
interp = interpreters.create()
with self.assertRaises(ValueError):
interp.prepare_main()

def test_dict(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.prepare_main(values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')

def test_tuple(self):
values = {'spam': 42, 'eggs': 'ham'}
values = tuple(values.items())
interp = interpreters.create()
interp.prepare_main(values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')

def test_kwargs(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.prepare_main(**values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')

def test_dict_and_kwargs(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.prepare_main(values, foo='bar')
out = _run_output(interp, dedent("""
print(spam, eggs, foo)
"""))
self.assertEqual(out.strip(), '42 ham bar')

def test_not_shareable(self):
interp = interpreters.create()
# XXX TypeError?
with self.assertRaises(ValueError):
interp.prepare_main(spam={'spam': 'eggs', 'foo': 'bar'})

# Make sure neither was actually bound.
with self.assertRaises(interpreters.ExecFailure):
interp.exec_sync('print(foo)')
with self.assertRaises(interpreters.ExecFailure):
interp.exec_sync('print(spam)')


class TestInterpreterExecSync(TestBase):

def test_success(self):
Expand Down
6 changes: 4 additions & 2 deletions Lib/test/test_interpreters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ def clean_up_interpreters():
pass # already destroyed


def _run_output(interp, request, channels=None):
def _run_output(interp, request, init=None):
script, rpipe = _captured_script(request)
with rpipe:
interp.exec_sync(script, channels=channels)
if init:
interp.prepare_main(init)
interp.exec_sync(script)
return rpipe.read()


Expand Down
56 changes: 56 additions & 0 deletions Modules/_xxsubinterpretersmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,60 @@ PyDoc_STRVAR(get_main_doc,
\n\
Return the ID of main interpreter.");

static PyObject *
interp_set___main___attrs(PyObject *self, PyObject *args)
{
PyObject *id, *updates;
if (!PyArg_ParseTuple(args, "OO:" MODULE_NAME ".set___main___attrs",
&id, &updates))
{
return NULL;
}

// Look up the interpreter.
PyInterpreterState *interp = PyInterpreterID_LookUp(id);
if (interp == NULL) {
return NULL;
}

// Check the updates.
if (updates != Py_None) {
Py_ssize_t size = PyObject_Size(updates);
if (size < 0) {
return NULL;
}
if (size == 0) {
PyErr_SetString(PyExc_ValueError,
"arg 2 must be a non-empty mapping");
return NULL;
}
}

_PyXI_session session = {0};

// Prep and switch interpreters, including apply the updates.
if (_PyXI_Enter(&session, interp, updates) < 0) {
if (!PyErr_Occurred()) {
_PyXI_ApplyCapturedException(&session);
assert(PyErr_Occurred());
}
else {
assert(!_PyXI_HasCapturedException(&session));
}
return NULL;
}

// Clean up and switch back.
_PyXI_Exit(&session);

Py_RETURN_NONE;
}

PyDoc_STRVAR(set___main___attrs_doc,
"set___main___attrs(id, ns)\n\
\n\
Bind the given attributes in the interpreter's __main__ module.");

static PyUnicodeObject *
convert_script_arg(PyObject *arg, const char *fname, const char *displayname,
const char *expected)
Expand Down Expand Up @@ -1033,6 +1087,8 @@ static PyMethodDef module_functions[] = {
{"run_func", _PyCFunction_CAST(interp_run_func),
METH_VARARGS | METH_KEYWORDS, run_func_doc},

{"set___main___attrs", _PyCFunction_CAST(interp_set___main___attrs),
METH_VARARGS, set___main___attrs_doc},
{"is_shareable", _PyCFunction_CAST(object_is_shareable),
METH_VARARGS | METH_KEYWORDS, is_shareable_doc},

Expand Down
Loading