Fix streaming, add minimal tests #592

Merged (1 commit, Jan 22, 2024)
6 changes: 2 additions & 4 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -137,7 +137,7 @@ async def handle_exc(self, e: Exception, request: InlineCompletionRequest):
         """
         Handles an exception raised in either `handle_request()` or
         `handle_stream_request()`. This base class provides a default
-        implementation, which may be overriden by subclasses.
+        implementation, which may be overridden by subclasses.
         """
         error = CompletionError(
             type=e.__class__.__name__,
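
The docstring fixed above notes that `handle_exc()` provides a default implementation which subclasses may override. A minimal sketch of such an override, assuming the handler exposes `self.log` as used later in this diff; the subclass name and logging behaviour are illustrative, not part of this PR:

from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler
from jupyter_ai.completions.models import InlineCompletionRequest


class LoggingCompletionHandler(DefaultInlineCompletionHandler):
    # Hypothetical subclass: log the failure, then fall back to the default
    # error reply sent by the base implementation.
    async def handle_exc(self, e: Exception, request: InlineCompletionRequest):
        self.log.warning("Inline completion failed: %s", e)
        await super().handle_exc(e, request)
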
@@ -162,8 +162,6 @@ async def _handle_request(self, request: InlineCompletionRequest):
     async def _handle_stream_request(self, request: InlineCompletionRequest):
         """Private wrapper around `self.handle_stream_request()`."""
         start = time.time()
-        await self._handle_stream_request(request)
-        async for chunk in self.stream(request):
-            self.write_message(chunk.dict())
+        await self.handle_stream_request(request)
         latency_ms = round((time.time() - start) * 1000)
         self.log.info(f"Inline completion streaming completed in {latency_ms} ms.")
6 changes: 2 additions & 4 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -73,9 +73,7 @@ def create_llm_chain(
         self.llm = llm
         self.llm_chain = prompt_template | llm | StrOutputParser()
 
-    async def handle_request(
-        self, request: InlineCompletionRequest
-    ) -> InlineCompletionReply:
+    async def handle_request(self, request: InlineCompletionRequest) -> None:
         """Handles an inline completion request without streaming."""
         self.get_llm_chain()
         model_arguments = self._template_inputs_from_request(request)
@@ -111,7 +109,7 @@ def _write_incomplete_reply(self, request: InlineCompletionRequest):
 
     async def handle_stream_request(self, request: InlineCompletionRequest):
         # first, send empty initial reply.
-        self._write_incomplete_reply()
+        self._write_incomplete_reply(request)
 
         # then, generate and stream LLM output over this connection.
         self.get_llm_chain()
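
With `request` now passed through to `_write_incomplete_reply()`, the streaming path produces three messages per request, as asserted in `test_handle_stream_request` below. Their approximate shapes, written as plain dicts for illustration (the handler actually writes the pydantic models from `jupyter_ai.completions.models`):

# Illustrative shapes only, reconstructed from the test assertions below.
expected_stream_messages = [
    # 1. empty initial reply that opens the stream
    {"list": {"items": [{"insertText": "", "isIncomplete": True}]}},
    # 2. stream chunk carrying the generated text
    {"type": "stream", "response": {"insertText": "Test response"}, "done": False},
    # 3. closing chunk that marks the stream as finished
    {"type": "stream", "response": {"insertText": "Test response"}, "done": True},
]
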
Empty file.
116 changes: 116 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py
@@ -0,0 +1,116 @@
import json
from types import SimpleNamespace

from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler
from jupyter_ai.completions.models import InlineCompletionRequest
from jupyter_ai_magics import BaseProvider
from langchain_community.llms import FakeListLLM
from pytest import fixture
from tornado.httputil import HTTPServerRequest
from tornado.web import Application


class MockProvider(BaseProvider, FakeListLLM):
    id = "my_provider"
    name = "My Provider"
    model_id_key = "model"
    models = ["model"]

    def __init__(self, **kwargs):
        kwargs["responses"] = ["Test response"]
        super().__init__(**kwargs)


class MockCompletionHandler(DefaultInlineCompletionHandler):
    def __init__(self):
        self.request = HTTPServerRequest()
        self.application = Application()
        self.messages = []
        self.tasks = []
        self.settings["jai_config_manager"] = SimpleNamespace(
            lm_provider=MockProvider, lm_provider_params={"model_id": "model"}
        )
        self.settings["jai_event_loop"] = SimpleNamespace(
            create_task=lambda x: self.tasks.append(x)
        )
        self.settings["model_parameters"] = {}
        self.llm_params = {}
        self.create_llm_chain(MockProvider, {"model_id": "model"})

    def write_message(self, message: str) -> None:  # type: ignore
        self.messages.append(message)

    async def handle_exc(self, e: Exception, _request: InlineCompletionRequest):
        # raise all exceptions during testing rather than handling them
        raise e


@fixture
def inline_handler() -> MockCompletionHandler:
    return MockCompletionHandler()


async def test_on_message(inline_handler):
    request = InlineCompletionRequest(
        number=1, prefix="", suffix="", mime="", stream=False
    )
    # Test end to end, without checking details of the replies,
    # which are tested in appropriate method unit tests.
    await inline_handler.on_message(json.dumps(dict(request)))
    assert len(inline_handler.tasks) == 1
    await inline_handler.tasks[0]
    assert len(inline_handler.messages) == 1


async def test_on_message_stream(inline_handler):
    stream_request = InlineCompletionRequest(
        number=1, prefix="", suffix="", mime="", stream=True
    )
    # Test end to end, without checking details of the replies,
    # which are tested in appropriate method unit tests.
    await inline_handler.on_message(json.dumps(dict(stream_request)))
    assert len(inline_handler.tasks) == 1
    await inline_handler.tasks[0]
    assert len(inline_handler.messages) == 3


async def test_handle_request(inline_handler):
    dummy_request = InlineCompletionRequest(
        number=1, prefix="", suffix="", mime="", stream=False
    )
    await inline_handler.handle_request(dummy_request)
    # should write a single reply
    assert len(inline_handler.messages) == 1
    # reply should contain a single suggestion
    suggestions = inline_handler.messages[0].list.items
    assert len(suggestions) == 1
    # the suggestion should include insert text from LLM
    assert suggestions[0].insertText == "Test response"


async def test_handle_stream_request(inline_handler):
    inline_handler.llm_chain = FakeListLLM(responses=["test"])
    dummy_request = InlineCompletionRequest(
        number=1, prefix="", suffix="", mime="", stream=True
    )
    await inline_handler.handle_stream_request(dummy_request)

    # should write three replies
    assert len(inline_handler.messages) == 3

    # first reply should be empty to start the stream
    first = inline_handler.messages[0].list.items[0]
    assert first.insertText == ""
    assert first.isIncomplete == True

    # second reply should be a chunk containing the token
    second = inline_handler.messages[1]
    assert second.type == "stream"
    assert second.response.insertText == "Test response"
    assert second.done == False

    # third reply should be a closing chunk
    third = inline_handler.messages[2]
    assert third.type == "stream"
    assert third.response.insertText == "Test response"
    assert third.done == True
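
To run just these tests, assuming pytest and an async-capable plugin are installed as part of the project's test dependencies (the exact invocation may differ from the project's own tooling):

# Convenience runner; equivalent to invoking pytest on the file from a shell.
import pytest

pytest.main([
    "packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py",
    "-v",
])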