From 8db93d1f01b57a2309448a83c9930299f0e51296 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Wed, 24 May 2023 12:14:09 +0200 Subject: [PATCH] checking pubsub --- services/dask-sidecar/tests/unit/conftest.py | 43 ++++++++++-- .../tests/unit/test_dask_utils.py | 68 ++++++++++++++++++- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/services/dask-sidecar/tests/unit/conftest.py b/services/dask-sidecar/tests/unit/conftest.py index 043f5d06531..d9cf2bbf9e5 100644 --- a/services/dask-sidecar/tests/unit/conftest.py +++ b/services/dask-sidecar/tests/unit/conftest.py @@ -5,7 +5,7 @@ from pathlib import Path from pprint import pformat -from typing import AsyncIterator, Callable, Iterable, Iterator +from typing import AsyncIterator, Callable, Iterator import dask import distributed @@ -76,7 +76,7 @@ def mock_service_envs( @pytest.fixture -def dask_client(mock_service_envs: None) -> Iterable[distributed.Client]: +def local_cluster(mock_service_envs: None) -> Iterator[distributed.LocalCluster]: print(pformat(dask.config.get("distributed"))) with distributed.LocalCluster( worker_class=distributed.Worker, @@ -85,8 +85,43 @@ def dask_client(mock_service_envs: None) -> Iterable[distributed.Client]: "preload": "simcore_service_dask_sidecar.tasks", }, ) as cluster: - with distributed.Client(cluster) as client: - yield client + assert cluster + assert isinstance(cluster, distributed.LocalCluster) + yield cluster + + +@pytest.fixture +def dask_client( + local_cluster: distributed.LocalCluster, +) -> Iterator[distributed.Client]: + with distributed.Client(local_cluster) as client: + yield client + + +@pytest.fixture +async def async_local_cluster( + mock_service_envs: None, +) -> AsyncIterator[distributed.LocalCluster]: + print(pformat(dask.config.get("distributed"))) + async with distributed.LocalCluster( + worker_class=distributed.Worker, + **{ + "resources": {"CPU": 10, "GPU": 10}, + "preload": "simcore_service_dask_sidecar.tasks", + }, + asynchronous=True, + ) as cluster: + assert cluster + assert isinstance(cluster, distributed.LocalCluster) + yield cluster + + +@pytest.fixture +async def async_dask_client( + async_local_cluster: distributed.LocalCluster, +) -> AsyncIterator[distributed.Client]: + async with distributed.Client(async_local_cluster, asynchronous=True) as client: + yield client @pytest.fixture(scope="module") diff --git a/services/dask-sidecar/tests/unit/test_dask_utils.py b/services/dask-sidecar/tests/unit/test_dask_utils.py index 9123a7cd896..5354fd1925d 100644 --- a/services/dask-sidecar/tests/unit/test_dask_utils.py +++ b/services/dask-sidecar/tests/unit/test_dask_utils.py @@ -8,7 +8,7 @@ import concurrent.futures import logging import time -from typing import Any +from typing import Any, Coroutine import distributed import pytest @@ -22,12 +22,16 @@ monitor_task_abortion, publish_event, ) +from tenacity._asyncio import AsyncRetrying +from tenacity.retry import retry_if_exception_type +from tenacity.stop import stop_after_delay +from tenacity.wait import wait_fixed DASK_TASK_STARTED_EVENT = "task_started" DASK_TESTING_TIMEOUT_S = 25 -async def test_publish_event(dask_client: distributed.Client): +def test_publish_event(dask_client: distributed.Client): dask_pub = distributed.Pub("some_topic", client=dask_client) dask_sub = distributed.Sub("some_topic", client=dask_client) event_to_publish = TaskLogEvent( @@ -44,6 +48,66 @@ async def test_publish_event(dask_client: distributed.Client): assert received_task_log_event == event_to_publish +async def test_publish_event_async(async_dask_client: distributed.Client): + dask_pub = distributed.Pub("some_topic", client=async_dask_client) + dask_sub = distributed.Sub("some_topic", client=async_dask_client) + event_to_publish = TaskLogEvent( + job_id="some_fake_job_id", log="the log", log_level=logging.INFO + ) + publish_event(dask_pub=dask_pub, event=event_to_publish) + + # NOTE: this tests runs a sync dask client, + # and the CI seems to have sometimes difficulties having this run in a reasonable time + # hence the long time out + message = dask_sub.get(timeout=DASK_TESTING_TIMEOUT_S) + assert isinstance(message, Coroutine) + message = await message + assert message is not None + received_task_log_event = TaskLogEvent.parse_raw(message) # type: ignore + assert received_task_log_event == event_to_publish + + +async def test_publish_event_async_using_task(async_dask_client: distributed.Client): + dask_pub = distributed.Pub("some_topic", client=async_dask_client) + dask_sub = distributed.Sub("some_topic", client=async_dask_client) + + received_messages = [] + + async def _dask_sub_consumer_task(sub: distributed.Sub) -> None: + print("--> starting consumer task") + async for dask_event in sub: + print(f"received {dask_event}") + received_messages.append(dask_event) + print("<-- finished consumer task") + + task = asyncio.create_task( + _dask_sub_consumer_task(dask_sub), name="pytest_dask_sub_consumer" + ) + + event_to_publish = TaskLogEvent( + job_id="some_fake_job_id", log="the log", log_level=logging.INFO + ) + publish_event(dask_pub=dask_pub, event=event_to_publish) + async for attempt in AsyncRetrying( + retry=retry_if_exception_type(AssertionError), + stop=stop_after_delay(DASK_TESTING_TIMEOUT_S), + wait=wait_fixed(0.01), + reraise=True, + ): + with attempt: + print(f"checking number of received messages...{received_messages=}") + assert len(received_messages) == 1 + + # NOTE: this tests runs a sync dask client, + # and the CI seems to have sometimes difficulties having this run in a reasonable time + # hence the long time out + + message = received_messages[0] + assert message is not None + received_task_log_event = TaskLogEvent.parse_raw(message) + assert received_task_log_event == event_to_publish + + def _wait_for_task_to_start(): start_event = distributed.Event(DASK_TASK_STARTED_EVENT) start_event.wait(timeout=DASK_TESTING_TIMEOUT_S)