Skip to content

Commit

Permalink
Fix in-process kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Feb 28, 2023
1 parent 7f1c67e commit 9363910
Show file tree
Hide file tree
Showing 18 changed files with 301 additions and 226 deletions.
40 changes: 4 additions & 36 deletions examples/embedding/inprocess_terminal.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""An in-process terminal example."""
import os
import sys

import tornado
from anyio import run
from jupyter_console.ptshell import ZMQTerminalInteractiveShell

from ipykernel.inprocess.manager import InProcessKernelManager
Expand All @@ -13,46 +12,15 @@ def print_process_id():
print("Process ID is:", os.getpid())


def init_asyncio_patch():
"""set default asyncio policy to be compatible with tornado
Tornado 6 (at least) is not compatible with the default
asyncio implementation on Windows
Pick the older SelectorEventLoopPolicy on Windows
if the known-incompatible default policy is in use.
do this as early as possible to make it a low priority and overrideable
ref: https://github.com/tornadoweb/tornado/issues/2608
FIXME: if/when tornado supports the defaults in asyncio,
remove and bump tornado requirement for py38
"""
if (
sys.platform.startswith("win")
and sys.version_info >= (3, 8)
and tornado.version_info < (6, 1)
):
import asyncio

try:
from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy
except ImportError:
pass
# not affected
else:
if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
# WindowsProactorEventLoopPolicy is not compatible with tornado 6
# fallback to the pre-3.8 default of Selector
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())


def main():
async def main():
"""The main function."""
print_process_id()

# Create an in-process kernel
# >>> print_process_id()
# will print the same process ID as the main process
init_asyncio_patch()
kernel_manager = InProcessKernelManager()
kernel_manager.start_kernel()
await kernel_manager.start_kernel()
kernel = kernel_manager.kernel
kernel.gui = "qt4"
kernel.shell.push({"foo": 43, "print_process_id": print_process_id})
Expand All @@ -64,4 +32,4 @@ def main():


if __name__ == "__main__":
main()
run(main)
5 changes: 2 additions & 3 deletions ipykernel/inprocess/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ class BlockingInProcessKernelClient(InProcessKernelClient):
iopub_channel_class = Type(BlockingInProcessChannel)
stdin_channel_class = Type(BlockingInProcessStdInChannel)

def wait_for_ready(self):
async def wait_for_ready(self):
"""Wait for kernel info reply on shell channel."""
while True:
self.kernel_info()
await self.kernel_info()
try:
msg = self.shell_channel.get_msg(block=True, timeout=1)
except Empty:
Expand All @@ -103,6 +103,5 @@ def wait_for_ready(self):
while True:
try:
msg = self.iopub_channel.get_msg(block=True, timeout=0.2)
print(msg["msg_type"])
except Empty:
break
48 changes: 19 additions & 29 deletions ipykernel/inprocess/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
# Imports
# -----------------------------------------------------------------------------

import asyncio

from jupyter_client.client import KernelClient
from jupyter_client.clientabc import KernelClientABC
from jupyter_core.utils import run_sync

# IPython imports
from traitlets import Instance, Type, default
Expand Down Expand Up @@ -101,7 +99,7 @@ def hb_channel(self):
# Methods for sending specific messages
# -------------------------------------

def execute(
async def execute(
self, code, silent=False, store_history=True, user_expressions=None, allow_stdin=None
):
"""Execute code on the client."""
Expand All @@ -115,19 +113,19 @@ def execute(
allow_stdin=allow_stdin,
)
msg = self.session.msg("execute_request", content)
self._dispatch_to_kernel(msg)
await self._dispatch_to_kernel(msg)
return msg["header"]["msg_id"]

def complete(self, code, cursor_pos=None):
async def complete(self, code, cursor_pos=None):
"""Get code completion."""
if cursor_pos is None:
cursor_pos = len(code)
content = dict(code=code, cursor_pos=cursor_pos)
msg = self.session.msg("complete_request", content)
self._dispatch_to_kernel(msg)
await self._dispatch_to_kernel(msg)
return msg["header"]["msg_id"]

def inspect(self, code, cursor_pos=None, detail_level=0):
async def inspect(self, code, cursor_pos=None, detail_level=0):
"""Get code inspection."""
if cursor_pos is None:
cursor_pos = len(code)
Expand All @@ -137,14 +135,14 @@ def inspect(self, code, cursor_pos=None, detail_level=0):
detail_level=detail_level,
)
msg = self.session.msg("inspect_request", content)
self._dispatch_to_kernel(msg)
await self._dispatch_to_kernel(msg)
return msg["header"]["msg_id"]

def history(self, raw=True, output=False, hist_access_type="range", **kwds):
async def history(self, raw=True, output=False, hist_access_type="range", **kwds):
"""Get code history."""
content = dict(raw=raw, output=output, hist_access_type=hist_access_type, **kwds)
msg = self.session.msg("history_request", content)
self._dispatch_to_kernel(msg)
await self._dispatch_to_kernel(msg)
return msg["header"]["msg_id"]

def shutdown(self, restart=False):
Expand All @@ -153,17 +151,17 @@ def shutdown(self, restart=False):
msg = "Cannot shutdown in-process kernel"
raise NotImplementedError(msg)

def kernel_info(self):
async def kernel_info(self):
"""Request kernel info."""
msg = self.session.msg("kernel_info_request")
self._dispatch_to_kernel(msg)
await self._dispatch_to_kernel(msg)
return msg["header"]["msg_id"]

def comm_info(self, target_name=None):
async def comm_info(self, target_name=None):
"""Request a dictionary of valid comms and their targets."""
content = {} if target_name is None else dict(target_name=target_name)
msg = self.session.msg("comm_info_request", content)
self._dispatch_to_kernel(msg)
await self._dispatch_to_kernel(msg)
return msg["header"]["msg_id"]

def input(self, string):
Expand All @@ -173,29 +171,21 @@ def input(self, string):
raise RuntimeError(msg)
self.kernel.raw_input_str = string

def is_complete(self, code):
async def is_complete(self, code):
"""Handle an is_complete request."""
msg = self.session.msg("is_complete_request", {"code": code})
self._dispatch_to_kernel(msg)
await self._dispatch_to_kernel(msg)
return msg["header"]["msg_id"]

def _dispatch_to_kernel(self, msg):
async def _dispatch_to_kernel(self, msg):
"""Send a message to the kernel and handle a reply."""
kernel = self.kernel
if kernel is None:
msg = "Cannot send request. No kernel exists."
raise RuntimeError(msg)
error_message = "Cannot send request. No kernel exists."
raise RuntimeError(error_message)

stream = kernel.shell_stream
self.session.send(stream, msg)
msg_parts = stream.recv_multipart()
if run_sync is not None:
dispatch_shell = run_sync(kernel.dispatch_shell)
dispatch_shell(msg_parts)
else:
loop = asyncio.get_event_loop()
loop.run_until_complete(kernel.dispatch_shell(msg_parts))
idents, reply_msg = self.session.recv(stream, copy=False)
kernel.shell_socket.put(msg)
reply_msg = await kernel.shell_socket.get()
self.shell_channel.call_handlers_later(reply_msg)

def get_shell_msg(self, block=True, timeout=None):
Expand Down
29 changes: 19 additions & 10 deletions ipykernel/inprocess/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import sys
from contextlib import contextmanager

from anyio import TASK_STATUS_IGNORED
from anyio.abc import TaskStatus
from IPython.core.interactiveshell import InteractiveShellABC
from traitlets import Any, Enum, Instance, List, Type, default

Expand Down Expand Up @@ -48,10 +50,10 @@ class InProcessKernel(IPythonKernel):
# -------------------------------------------------------------------------

shell_class = Type(allow_none=True)
_underlying_iopub_socket = Instance(DummySocket, ())
_underlying_iopub_socket = Instance(DummySocket, (False,))
iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment]

shell_stream = Instance(DummySocket, ())
shell_socket = Instance(DummySocket, (True,))

@default("iopub_thread")
def _default_iopub_thread(self):
Expand All @@ -65,23 +67,27 @@ def _default_iopub_thread(self):
def _default_iopub_socket(self):
return self.iopub_thread.background_socket

stdin_socket = Instance(DummySocket, ()) # type:ignore[assignment]
stdin_socket = Instance(DummySocket, (False,)) # type:ignore[assignment]

def __init__(self, **traits):
"""Initialize the kernel."""
super().__init__(**traits)

self._underlying_iopub_socket.observe(self._io_dispatch, names=["message_sent"])
self._io_dispatch()
self.shell.kernel = self

async def execute_request(self, stream, ident, parent):
"""Override for temporary IO redirection."""
with self._redirected_io():
await super().execute_request(stream, ident, parent)

def start(self):
async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
"""Override registration of dispatchers for streams."""
self.shell.exit_now = False
await super().start(task_status=task_status)

def stop(self):
super().stop()

def _abort_queues(self):
"""The in-process kernel doesn't abort requests."""
Expand Down Expand Up @@ -127,12 +133,15 @@ def _redirected_io(self):

# ------ Trait change handlers --------------------------------------------

def _io_dispatch(self, change):
def _io_dispatch(self):
"""Called when a message is sent to the IO socket."""
assert self.iopub_socket.io_thread is not None
ident, msg = self.session.recv(self.iopub_socket.io_thread.socket, copy=False)
for frontend in self.frontends:
frontend.iopub_channel.call_handlers(msg)

def callback(msg):
for frontend in self.frontends:
frontend.iopub_channel.call_handlers(msg)

self.iopub_thread.socket.on_recv = callback

# ------ Trait initializers -----------------------------------------------

Expand All @@ -142,7 +151,7 @@ def _default_log(self):

@default("session")
def _default_session(self):
from jupyter_client.session import Session
from .session import Session

return Session(parent=self, key=INPROCESS_KEY)

Expand Down
14 changes: 10 additions & 4 deletions ipykernel/inprocess/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.

from anyio import TASK_STATUS_IGNORED
from anyio.abc import TaskStatus
from jupyter_client.manager import KernelManager
from jupyter_client.managerabc import KernelManagerABC
from jupyter_client.session import Session
from traitlets import DottedObjectName, Instance, default

from .constants import INPROCESS_KEY
from .session import Session


class InProcessKernelManager(KernelManager):
Expand Down Expand Up @@ -41,27 +43,31 @@ def _default_session(self):
# Kernel management methods
# --------------------------------------------------------------------------

def start_kernel(self, **kwds):
async def start_kernel(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED, **kwds) -> None:
"""Start the kernel."""
from ipykernel.inprocess.ipkernel import InProcessKernel

self.kernel = InProcessKernel(parent=self, session=self.session)
await self.kernel.start(task_status=task_status)

def shutdown_kernel(self):
"""Shutdown the kernel."""
self.kernel.iopub_thread.stop()
self._kill_kernel()

def restart_kernel(self, now=False, **kwds):
async def restart_kernel(
self, now=False, *, task_status: TaskStatus = TASK_STATUS_IGNORED, **kwds
) -> None:
"""Restart the kernel."""
self.shutdown_kernel()
self.start_kernel(**kwds)
await self.start_kernel(task_status=task_status, **kwds)

@property
def has_kernel(self):
return self.kernel is not None

def _kill_kernel(self):
self.kernel.stop()
self.kernel = None

def interrupt_kernel(self):
Expand Down
41 changes: 41 additions & 0 deletions ipykernel/inprocess/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from jupyter_client.session import Session as _Session


class Session(_Session):
async def recv(self, socket, copy=True):
return await socket.recv_multipart()

def send(
self,
socket,
msg_or_type,
content=None,
parent=None,
ident=None,
buffers=None,
track=False,
header=None,
metadata=None,
):
if isinstance(msg_or_type, str):
msg = self.msg(
msg_or_type,
content=content,
parent=parent,
header=header,
metadata=metadata,
)
else:
# We got a Message or message dict, not a msg_type so don't
# build a new Message.
msg = msg_or_type
buffers = buffers or msg.get("buffers", [])

socket.send_multipart(msg)
return msg

def feed_identities(self, msg, copy=True):
return "", msg

def deserialize(self, msg, content=True, copy=True):
return msg
Loading

0 comments on commit 9363910

Please sign in to comment.