From 95f5b3f109e951ab96616591c8de258544974747 Mon Sep 17 00:00:00 2001 From: Min RK Date: Tue, 3 Oct 2017 19:48:00 +0200 Subject: [PATCH] restore actual zmq channels when resuming connection rather than establishing new connections fixes failure to resume shell channel --- notebook/services/kernels/handlers.py | 41 ++++++++++++---------- notebook/services/kernels/kernelmanager.py | 26 ++++++++++++++ 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/notebook/services/kernels/handlers.py b/notebook/services/kernels/handlers.py index dfdd8a3d652..359b5e2dd8d 100644 --- a/notebook/services/kernels/handlers.py +++ b/notebook/services/kernels/handlers.py @@ -261,25 +261,28 @@ def open(self, kernel_id): self.kernel_manager.notify_connect(kernel_id) # on new connections, flush the message buffer - replay_buffer = self.kernel_manager.stop_buffering(kernel_id, self.session_key) - - try: - self.create_stream() - except web.HTTPError as e: - self.log.error("Error opening stream: %s", e) - # WebSockets don't response to traditional error codes so we - # close the connection. - for channel, stream in self.channels.items(): - if not stream.closed(): - stream.close() - self.close() - return - - if replay_buffer: - self.log.info("Replaying %s buffered messages", len(replay_buffer)) - for channel, msg_list in replay_buffer: - stream = self.channels[channel] - self._on_zmq_reply(stream, msg_list) + buffer_info = self.kernel_manager.get_buffer(kernel_id, self.session_key) + if buffer_info: + self.log.info("Restoring connection for %s", self.session_key) + self.channels = buffer_info['channels'] + replay_buffer = buffer_info['buffer'] + if replay_buffer: + self.log.info("Replaying %s buffered messages", len(replay_buffer)) + for channel, msg_list in replay_buffer: + stream = self.channels[channel] + self._on_zmq_reply(stream, msg_list) + else: + try: + self.create_stream() + except web.HTTPError as e: + self.log.error("Error opening stream: %s", e) + # WebSockets don't response to traditional error codes so we + # close the connection. + for channel, stream in self.channels.items(): + if not stream.closed(): + stream.close() + self.close() + return for channel, stream in self.channels.items(): stream.on_recv_stream(self._on_zmq_reply) diff --git a/notebook/services/kernels/kernelmanager.py b/notebook/services/kernels/kernelmanager.py index ed733a9c1c0..9c0924a06ea 100644 --- a/notebook/services/kernels/kernelmanager.py +++ b/notebook/services/kernels/kernelmanager.py @@ -182,6 +182,32 @@ def buffer_msg(channel, msg_parts): for channel, stream in channels.items(): stream.on_recv(partial(buffer_msg, channel)) + + def get_buffer(self, kernel_id, session_key): + """Get the buffer for a given kernel + + Parameters + ---------- + kernel_id : str + The id of the kernel to stop buffering. + session_key: str, optional + The session_key, if any, that should get the buffer. + If the session_key matches the current buffered session_key, + the buffer will be returned. + """ + self.log.debug("Getting buffer for %s", kernel_id) + if kernel_id not in self._kernel_buffers: + return + + buffer_info = self._kernel_buffers[kernel_id] + if buffer_info['session_key'] == session_key: + # remove buffer + self._kernel_buffers.pop(kernel_id) + # only return buffer_info if it's a match + return buffer_info + else: + self.stop_buffering(kernel_id) + def stop_buffering(self, kernel_id, session_key=None): """Stop buffering kernel messages