From de3f125df4c3ff19165ac2bbf67e2df6e73d1ef3 Mon Sep 17 00:00:00 2001 From: alangenfeld Date: Fri, 19 Jan 2024 15:41:48 -0600 Subject: [PATCH] [dagster webserver] fix websockets on py3.11 --- .../dagster-webserver/dagster_webserver/graphql.py | 2 +- .../dagster_webserver_tests/test_subscriptions.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python_modules/dagster-webserver/dagster_webserver/graphql.py b/python_modules/dagster-webserver/dagster_webserver/graphql.py index 1a0a70ebbd6b7..f95af22296564 100644 --- a/python_modules/dagster-webserver/dagster_webserver/graphql.py +++ b/python_modules/dagster-webserver/dagster_webserver/graphql.py @@ -182,7 +182,7 @@ async def graphql_ws_endpoint(self, websocket: WebSocket): """ tasks: Dict[str, Task] = {} - await websocket.accept(subprotocol=GraphQLWS.PROTOCOL) + await websocket.accept(subprotocol=GraphQLWS.PROTOCOL.value) try: while ( diff --git a/python_modules/dagster-webserver/dagster_webserver_tests/test_subscriptions.py b/python_modules/dagster-webserver/dagster_webserver_tests/test_subscriptions.py index 07c41320768bd..9e9091878ebb3 100644 --- a/python_modules/dagster-webserver/dagster_webserver_tests/test_subscriptions.py +++ b/python_modules/dagster-webserver/dagster_webserver_tests/test_subscriptions.py @@ -1,6 +1,7 @@ import gc import sys from contextlib import contextmanager +from typing import Iterator from unittest import mock import objgraph @@ -36,7 +37,7 @@ @contextmanager -def create_asgi_client(instance): +def create_asgi_client(instance) -> Iterator[TestClient]: yaml_paths = [file_relative_path(__file__, "./workspace.yaml")] with WorkspaceProcessContext( @@ -82,7 +83,7 @@ def example_job(): example_op() -def test_event_log_subscription(): +def test_event_log_subscription() -> None: with instance_for_test() as instance: run = example_job.execute_in_process(instance=instance) assert run.success @@ -90,6 +91,7 @@ def test_event_log_subscription(): with create_asgi_client(instance) as client: with client.websocket_connect("/graphql", GraphQLWS.PROTOCOL) as ws: + assert str(ws.accepted_subprotocol) == "graphql-ws" start_subscription(ws, EVENT_LOG_SUBSCRIPTION, {"runId": run.run_id}) gc.collect() assert len(objgraph.by_type("async_generator")) == 1