Skip to content

Commit

Permalink
checking pubsub
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg committed May 24, 2023
1 parent fb5d322 commit 8db93d1
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 6 deletions.
43 changes: 39 additions & 4 deletions services/dask-sidecar/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pathlib import Path
from pprint import pformat
from typing import AsyncIterator, Callable, Iterable, Iterator
from typing import AsyncIterator, Callable, Iterator

import dask
import distributed
Expand Down Expand Up @@ -76,7 +76,7 @@ def mock_service_envs(


@pytest.fixture
def dask_client(mock_service_envs: None) -> Iterable[distributed.Client]:
def local_cluster(mock_service_envs: None) -> Iterator[distributed.LocalCluster]:
print(pformat(dask.config.get("distributed")))
with distributed.LocalCluster(
worker_class=distributed.Worker,
Expand All @@ -85,8 +85,43 @@ def dask_client(mock_service_envs: None) -> Iterable[distributed.Client]:
"preload": "simcore_service_dask_sidecar.tasks",
},
) as cluster:
with distributed.Client(cluster) as client:
yield client
assert cluster
assert isinstance(cluster, distributed.LocalCluster)
yield cluster


@pytest.fixture
def dask_client(
    local_cluster: distributed.LocalCluster,
) -> Iterator[distributed.Client]:
    """Yield a ``distributed.Client`` connected to ``local_cluster``.

    The client is used as a context manager, so it is exited once the
    consuming test is done with it.
    """
    client_ctx = distributed.Client(local_cluster)
    with client_ctx as connected_client:
        yield connected_client


@pytest.fixture
async def async_local_cluster(
    mock_service_envs: None,
) -> AsyncIterator[distributed.LocalCluster]:
    """Yield a ``distributed.LocalCluster`` created with ``asynchronous=True``.

    Requires the ``mock_service_envs`` fixture. The cluster is entered with
    ``async with`` and therefore exited once the consuming test is done.
    """
    # print the "distributed" section of the dask configuration in use
    print(pformat(dask.config.get("distributed")))

    # extra keyword arguments forwarded to the cluster on top of worker_class
    extra_cluster_kwargs = {
        "resources": {"CPU": 10, "GPU": 10},
        "preload": "simcore_service_dask_sidecar.tasks",
    }
    async with distributed.LocalCluster(
        worker_class=distributed.Worker,
        asynchronous=True,
        **extra_cluster_kwargs,
    ) as started_cluster:
        assert started_cluster
        assert isinstance(started_cluster, distributed.LocalCluster)
        yield started_cluster


@pytest.fixture
async def async_dask_client(
    async_local_cluster: distributed.LocalCluster,
) -> AsyncIterator[distributed.Client]:
    """Yield a ``distributed.Client`` created with ``asynchronous=True``.

    The client is connected to ``async_local_cluster`` and entered with
    ``async with``, so it is exited once the consuming test is done.
    """
    client_ctx = distributed.Client(async_local_cluster, asynchronous=True)
    async with client_ctx as connected_client:
        yield connected_client


@pytest.fixture(scope="module")
Expand Down
68 changes: 66 additions & 2 deletions services/dask-sidecar/tests/unit/test_dask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import concurrent.futures
import logging
import time
from typing import Any
from typing import Any, Coroutine

import distributed
import pytest
Expand All @@ -22,12 +22,16 @@
monitor_task_abortion,
publish_event,
)
from tenacity._asyncio import AsyncRetrying
from tenacity.retry import retry_if_exception_type
from tenacity.stop import stop_after_delay
from tenacity.wait import wait_fixed

DASK_TASK_STARTED_EVENT = "task_started"
DASK_TESTING_TIMEOUT_S = 25


async def test_publish_event(dask_client: distributed.Client):
def test_publish_event(dask_client: distributed.Client):
dask_pub = distributed.Pub("some_topic", client=dask_client)
dask_sub = distributed.Sub("some_topic", client=dask_client)
event_to_publish = TaskLogEvent(
Expand All @@ -44,6 +48,66 @@ async def test_publish_event(dask_client: distributed.Client):
assert received_task_log_event == event_to_publish


async def test_publish_event_async(async_dask_client: distributed.Client):
    """Publish a ``TaskLogEvent`` through an asynchronous client and read it back."""
    topic_pub = distributed.Pub("some_topic", client=async_dask_client)
    topic_sub = distributed.Sub("some_topic", client=async_dask_client)
    event_to_publish = TaskLogEvent(
        job_id="some_fake_job_id", log="the log", log_level=logging.INFO
    )
    publish_event(dask_pub=topic_pub, event=event_to_publish)

    # NOTE: the CI sometimes has difficulties running this in a reasonable time,
    # hence the long timeout
    pending_message = topic_sub.get(timeout=DASK_TESTING_TIMEOUT_S)
    # with the asynchronous client the test expects get() to hand back a
    # coroutine, which has to be awaited to obtain the actual message
    assert isinstance(pending_message, Coroutine)
    raw_message = await pending_message
    assert raw_message is not None
    received_task_log_event = TaskLogEvent.parse_raw(raw_message)  # type: ignore
    assert received_task_log_event == event_to_publish


async def test_publish_event_async_using_task(async_dask_client: distributed.Client):
    """Publish a ``TaskLogEvent`` and receive it through a background consumer task.

    A task iterating over the subscriber appends every received message to a
    list; the test polls that list until exactly one message has arrived and
    checks that it parses back to the published event.
    """
    dask_pub = distributed.Pub("some_topic", client=async_dask_client)
    dask_sub = distributed.Sub("some_topic", client=async_dask_client)

    received_messages = []

    async def _dask_sub_consumer_task(sub: distributed.Sub) -> None:
        print("--> starting consumer task")
        async for dask_event in sub:
            print(f"received {dask_event}")
            received_messages.append(dask_event)
        print("<-- finished consumer task")

    task = asyncio.create_task(
        _dask_sub_consumer_task(dask_sub), name="pytest_dask_sub_consumer"
    )
    try:
        event_to_publish = TaskLogEvent(
            job_id="some_fake_job_id", log="the log", log_level=logging.INFO
        )
        publish_event(dask_pub=dask_pub, event=event_to_publish)

        # NOTE: the CI sometimes has difficulties running this in a reasonable
        # time, hence the long timeout: retry on AssertionError until the
        # consumer task has received the message or the delay is exhausted
        async for attempt in AsyncRetrying(
            retry=retry_if_exception_type(AssertionError),
            stop=stop_after_delay(DASK_TESTING_TIMEOUT_S),
            wait=wait_fixed(0.01),
            reraise=True,
        ):
            with attempt:
                print(f"checking number of received messages...{received_messages=}")
                assert len(received_messages) == 1

        message = received_messages[0]
        assert message is not None
        received_task_log_event = TaskLogEvent.parse_raw(message)
        assert received_task_log_event == event_to_publish
    finally:
        # the consumer loops over the subscriber and is not stopped by anything
        # in this test: cancel it and wait for it so that no pending task is
        # left behind once the test returns
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass


def _wait_for_task_to_start():
start_event = distributed.Event(DASK_TASK_STARTED_EVENT)
start_event.wait(timeout=DASK_TESTING_TIMEOUT_S)
Expand Down

0 comments on commit 8db93d1

Please sign in to comment.