Skip to content

Commit

Permalink
Merge pull request #479 from ganisback/support-stream
Browse files Browse the repository at this point in the history
feat: support stream api
  • Loading branch information
yuvipanda authored Aug 29, 2024
2 parents b689a4d + e201ffa commit 7f78e1b
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 4 deletions.
96 changes: 93 additions & 3 deletions jupyter_server_proxy/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import os
import re
import socket
from asyncio import Lock
from copy import copy
Expand Down Expand Up @@ -287,7 +288,7 @@ def get_client_uri(self, protocol, host, port, proxied_path):

return client_uri

def _build_proxy_request(self, host, port, proxied_path, body):
def _build_proxy_request(self, host, port, proxied_path, body, **extra_opts):
headers = self.proxy_request_headers()

client_uri = self.get_client_uri("http", host, port, proxied_path)
Expand All @@ -307,6 +308,7 @@ def _build_proxy_request(self, host, port, proxied_path, body):
decompress_response=False,
headers=headers,
**self.proxy_request_options(),
**extra_opts,
)
return req

Expand Down Expand Up @@ -365,7 +367,6 @@ async def proxy(self, host, port, proxied_path):
body = b""
else:
body = None

if self.unix_socket is not None:
# Port points to a Unix domain socket
self.log.debug("Making client for Unix socket %r", self.unix_socket)
Expand All @@ -374,8 +375,97 @@ async def proxy(self, host, port, proxied_path):
force_instance=True, resolver=UnixResolver(self.unix_socket)
)
else:
client = httpclient.AsyncHTTPClient()
client = httpclient.AsyncHTTPClient(force_instance=True)
# check if the request is stream request
accept_header = self.request.headers.get("Accept")
if accept_header == "text/event-stream":
return await self._proxy_progressive(host, port, proxied_path, body, client)
else:
return await self._proxy_buffered(host, port, proxied_path, body, client)

async def _proxy_progressive(self, host, port, proxied_path, body, client):
    """Proxy a request in progressive (streaming) mode.

    Chunks are flushed to the caller as soon as they arrive from the
    backend, instead of buffering the whole response. Potentially slower
    overall, but delivers partial results quickly (e.g. for voila or
    Server-Sent Events endpoints).

    Parameters mirror ``_proxy_buffered``; ``client`` is the
    ``AsyncHTTPClient`` instance to fetch with.
    """

    # Raw header lines accumulated by header_callback until we are ready
    # to emit them (they arrive before/alongside the first body chunk).
    headers_raw = []

    def dump_headers(headers_raw):
        # Parse accumulated raw header lines and copy them onto our own
        # response, then clear the list so a second call is a no-op.
        for line in headers_raw:
            r = re.match("^([a-zA-Z0-9\\-_]+)\\s*\\:\\s*([^\r\n]+)[\r\n]*$", line)
            if r:
                # BUGFIX: was `r.groups([1, 2])` — the argument to
                # Match.groups() is a *default* for non-participating
                # groups, not a selector; it only worked by accident.
                k, v = r.group(1), r.group(2)
                # NOTE(review): this comparison is case-sensitive while
                # HTTP header names are not — backends emitting e.g.
                # "content-length" would slip through; confirm upstream.
                if k not in (
                    "Content-Length",
                    "Transfer-Encoding",
                    "Content-Encoding",
                    "Connection",
                ):
                    # some header appear multiple times, eg 'Set-Cookie'
                    self.set_header(k, v)
            else:
                # Status line, e.g. "HTTP/1.1 200 OK" — mirror its code.
                r = re.match(r"^HTTP[^\s]* ([0-9]+)", line)
                if r:
                    status_code = r.group(1)
                    self.set_status(int(status_code))
        headers_raw.clear()

    # clear tornado default header
    self._headers = httputil.HTTPHeaders()

    def header_callback(line):
        headers_raw.append(line)

    def streaming_callback(chunk):
        # record activity at start and end of requests
        self._record_activity()
        # Do this here, not in header_callback so we can be sure headers are out of the way first
        dump_headers(
            headers_raw
        )  # array will be empty if this was already called before
        self.write(chunk)
        self.flush()

    # Now make the request

    req = self._build_proxy_request(
        host,
        port,
        proxied_path,
        body,
        streaming_callback=streaming_callback,
        header_callback=header_callback,
    )

    # no timeout for stream api — long-lived connections are expected,
    # so allow up to 2h per request and 10min to connect.
    req.request_timeout = 7200
    req.connect_timeout = 600

    try:
        response = await client.fetch(req, raise_error=False)
    except httpclient.HTTPError as err:
        # 599 is tornado's synthetic "connection/timeout error" code;
        # report it to the caller rather than crashing the handler.
        if err.code == 599:
            self._record_activity()
            self.set_status(599)
            self.write(str(err))
            return
        else:
            raise

    # For all non http errors... (exact type check on purpose: real
    # HTTPError statuses were already mirrored via the status line)
    if response.error and type(response.error) is not httpclient.HTTPError:
        self.set_status(500)
        self.write(str(response.error))
    else:
        self.set_status(
            response.code, response.reason
        )  # Should already have been set

    dump_headers(headers_raw)  # Should already have been emptied

    if response.body:  # Likewise, should already be chunked out and flushed
        self.write(response.body)

async def _proxy_buffered(self, host, port, proxied_path, body, client):
req = self._build_proxy_request(host, port, proxied_path, body)

self.log.debug(f"Proxying request to {req.url}")
Expand Down
36 changes: 36 additions & 0 deletions tests/resources/eventstream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import asyncio

import tornado.escape
import tornado.ioloop
import tornado.options
import tornado.web
import tornado.websocket
from tornado.options import define, options


class Application(tornado.web.Application):
    """Minimal tornado application with one SSE-style streaming route."""

    def __init__(self):
        routes = [(r"/stream/(\d+)", StreamHandler)]
        super().__init__(routes)


class StreamHandler(tornado.web.RequestHandler):
    """Emit numbered event-stream entries, one every 500 ms."""

    async def get(self, seconds):
        total = int(seconds)
        for count in range(total):
            await asyncio.sleep(0.5)
            # Flush after each entry so the client sees a true stream.
            self.write(f"data: {count}\n\n")
            await self.flush()


def main():
    """Parse command-line options and serve the app until interrupted."""
    define("port", default=8888, help="run on the given port", type=int)
    options.parse_command_line()
    Application().listen(options.port)
    tornado.ioloop.IOLoop.current().start()


if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions tests/resources/jupyter_server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def my_env():
"X-Custom-Header": "pytest-23456",
},
},
"python-eventstream": {
"command": [sys.executable, "./tests/resources/eventstream.py", "--port={port}"]
},
"python-unix-socket-true": {
"command": [
sys.executable,
Expand Down
37 changes: 36 additions & 1 deletion tests/test_proxies.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import gzip
import json
import sys
import time
from http.client import HTTPConnection
from io import BytesIO
from typing import Tuple
from urllib.parse import quote

import pytest
from tornado.httpclient import HTTPClientError
from tornado.httpclient import AsyncHTTPClient, HTTPClientError
from tornado.websocket import websocket_connect

# use ipv4 for CI, etc.
Expand Down Expand Up @@ -343,6 +344,40 @@ def test_server_content_encoding_header(
assert f.read() == b"this is a test"


async def test_eventstream(a_server_port_and_token: Tuple[int, str]) -> None:
    PORT, TOKEN = a_server_port_and_token
    # The eventstream.py test server emits monotonically increasing numbers
    # starting at 0 up to the requested limit, one every 500ms. We verify:
    # 1. the streaming callback fires exactly once per entry, since the
    #    server flushes after writing each one, and
    # 2. entries arrive (with some error margin) roughly 500ms apart, to
    #    ensure the response is *actually* being streamed, not buffered.
    limit = 3
    prev_time = time.perf_counter()
    intervals = []
    chunks = []

    def on_chunk(data):
        nonlocal prev_time
        now = time.perf_counter()
        intervals.append(now - prev_time)
        prev_time = now
        chunks.append(data)

    url = f"http://{LOCALHOST}:{PORT}/python-eventstream/stream/{limit}?token={TOKEN}"
    client = AsyncHTTPClient()
    await client.fetch(
        url,
        headers={"Accept": "text/event-stream"},
        streaming_callback=on_chunk,
    )
    assert len(chunks) == limit
    assert all(0.45 < t < 3.0 for t in intervals)
    assert chunks == [b"data: 0\n\n", b"data: 1\n\n", b"data: 2\n\n"]


async def test_server_proxy_websocket_messages(
a_server_port_and_token: Tuple[int, str]
) -> None:
Expand Down

0 comments on commit 7f78e1b

Please sign in to comment.