Fix prefect-dask test suite (#16920)
desertaxle authored Jan 31, 2025
1 parent da6a517 commit 5b5725a
Showing 3 changed files with 29 additions and 67 deletions.
35 changes: 0 additions & 35 deletions src/integrations/prefect-dask/tests/conftest.py
@@ -1,7 +1,3 @@
-import asyncio
-import logging
-import sys
-
 import pytest
 
 from prefect.testing.utilities import prefect_test_harness
@@ -14,34 +10,3 @@ def prefect_db():
     """
     with prefect_test_harness():
         yield
-
-
-@pytest.fixture(scope="session")
-def event_loop(request):
-    """
-    Redefine the event loop to support session/module-scoped fixtures;
-    see https://github.com/pytest-dev/pytest-asyncio/issues/68
-    When running on Windows we need to use a non-default loop for subprocess support.
-    """
-    if sys.platform == "win32":
-        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
-
-    policy = asyncio.get_event_loop_policy()
-
-    loop = policy.new_event_loop()
-
-    # configure asyncio logging to capture long running tasks
-    asyncio_logger = logging.getLogger("asyncio")
-    asyncio_logger.setLevel("WARNING")
-    asyncio_logger.addHandler(logging.StreamHandler())
-    loop.set_debug(True)
-    loop.slow_callback_duration = 0.25
-
-    try:
-        yield loop
-    finally:
-        loop.close()
-
-    # Workaround for failures in pytest_asyncio 0.17;
-    # see https://github.com/pytest-dev/pytest-asyncio/issues/257
-    policy.set_event_loop(loop)
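
For reference, what survives in conftest.py is the harness fixture alone: the suite now leans on prefect_test_harness for an ephemeral Prefect database instead of a custom event loop. A minimal sketch of that pattern in a standalone test module (fixture scope and names here are illustrative, not copied from the repository):

import pytest

from prefect import flow
from prefect.testing.utilities import prefect_test_harness


@pytest.fixture(autouse=True, scope="session")
def prefect_db():
    """Run the whole session against a temporary, local Prefect database."""
    with prefect_test_harness():
        yield


@flow
def my_flow() -> int:
    return 42


def test_my_flow():
    # Executes against the ephemeral backend provided by the fixture above.
    assert my_flow() == 42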
7 changes: 0 additions & 7 deletions src/integrations/prefect-dask/tests/test_client.py
@@ -1,4 +1,3 @@
-import pytest
 from prefect_dask.client import PrefectDaskClient
 
 from prefect.client.orchestration import get_client
@@ -13,12 +12,6 @@
 from prefect.context import FlowRunContext
 from prefect.flows import flow
 from prefect.tasks import task
-from prefect.testing.fixtures import (  # noqa: F401
-    hosted_api_server,
-    use_hosted_api_server,
-)
-
-pytestmark = pytest.mark.usefixtures("use_hosted_api_server")
 
 
 class TestSubmit:
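
With the hosted-API-server fixtures and the usefixtures mark gone, test_client.py runs against the same ephemeral database as the rest of the suite. A rough sketch of verifying orchestration state through get_client under those conditions; it assumes a session-scoped prefect_test_harness fixture (as in conftest.py) is active, and the flow, task, and test names are illustrative:

import asyncio

from prefect import flow, task
from prefect.client.orchestration import get_client


@task
def add_one(x: int) -> int:
    return x + 1


@flow
def small_flow() -> int:
    return add_one(41)


async def count_flow_runs() -> int:
    # Query the ephemeral Prefect API stood up by the test harness.
    async with get_client() as client:
        return len(await client.read_flow_runs())


def test_small_flow_is_recorded():
    assert small_flow() == 42
    assert asyncio.run(count_flow_runs()) >= 1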
54 changes: 29 additions & 25 deletions src/integrations/prefect-dask/tests/test_task_runners.py
@@ -14,45 +14,49 @@
 from prefect.futures import as_completed
 from prefect.server.schemas.states import StateType
 from prefect.states import State
-from prefect.testing.fixtures import (  # noqa: F401
-    hosted_api_server,
-    use_hosted_api_server,
-)
+
+
+@pytest.fixture(scope="module")
+def cluster():
+    with distributed.LocalCluster(dashboard_address=None) as cluster:
+        yield cluster
 
 
 @pytest.fixture
-def dask_task_runner_with_existing_cluster(use_hosted_api_server):  # noqa
+def dask_task_runner_with_existing_cluster(cluster: distributed.LocalCluster):  # noqa
     """
     Generate a dask task runner that's connected to a local cluster
     """
-    with distributed.LocalCluster(n_workers=2) as cluster:
-        yield DaskTaskRunner(cluster=cluster)
+    yield DaskTaskRunner(cluster=cluster)
 
 
 @pytest.fixture
-def dask_task_runner_with_existing_cluster_address(use_hosted_api_server):  # noqa
+def dask_task_runner_with_existing_cluster_address(cluster: distributed.LocalCluster):  # noqa
     """
     Generate a dask task runner that's connected to a local cluster
     """
-    with distributed.LocalCluster(n_workers=2) as cluster:
-        with distributed.Client(cluster) as client:
-            address = client.scheduler.address
-            yield DaskTaskRunner(address=address)
+    with distributed.Client(cluster) as client:
+        address = client.scheduler.address
+        yield DaskTaskRunner(address=address)
 
 
 @pytest.fixture
-def dask_task_runner_with_process_pool(use_hosted_api_server):  # noqa
-    yield DaskTaskRunner(cluster_kwargs={"processes": True})
+def dask_task_runner_with_process_pool():  # noqa
+    yield DaskTaskRunner(cluster_kwargs={"processes": True, "dashboard_address": None})
 
 
 @pytest.fixture
-def dask_task_runner_with_thread_pool(use_hosted_api_server):  # noqa
-    yield DaskTaskRunner(cluster_kwargs={"processes": False})
+def dask_task_runner_with_thread_pool():  # noqa
+    yield DaskTaskRunner(cluster_kwargs={"processes": False, "dashboard_address": None})
 
 
 @pytest.fixture
-def default_dask_task_runner(use_hosted_api_server):  # noqa
-    yield DaskTaskRunner()
+def default_dask_task_runner():  # noqa
+    yield DaskTaskRunner(
+        cluster_kwargs={
+            "dashboard_address": None,  # Prevent port conflicts
+        }
+    )
 
 
 class TestDaskTaskRunner:
@@ -192,7 +196,7 @@ async def fake_orchestrate_task_run(example_kwarg):
             state = future.state
             assert await state.result() == "a"
 
-    async def test_async_task_timeout(self, task_runner):
+    async def test_async_task_timeout(self, task_runner: DaskTaskRunner):
         @task(timeout_seconds=0.1)
         async def my_timeout_task():
             await asyncio.sleep(2)
@@ -342,9 +346,9 @@ async def adapt(self, **kwargs):
         with task_runner:
             assert task_runner._cluster._adapt_called
 
-    def test_warns_if_future_garbage_collection_before_resolving(
-        self, caplog, task_runner
-    ):
+    def test_warns_if_future_garbage_collection_before_resolving(self, caplog):
+        task_runner = DaskTaskRunner(cluster_kwargs={"dashboard_address": None})
+
         @task
         def test_task():
             return 42
@@ -358,9 +362,9 @@ def test_flow():
 
         assert "A future was garbage collected before it resolved" in caplog.text
 
-    def test_does_not_warn_if_future_resolved_when_garbage_collected(
-        self, task_runner, caplog
-    ):
+    def test_does_not_warn_if_future_resolved_when_garbage_collected(self, caplog):
+        task_runner = DaskTaskRunner(cluster_kwargs={"dashboard_address": None})
+
         @task
         def test_task():
             return 42
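
The recurring change in test_task_runners.py is a single module-scoped distributed.LocalCluster, created with the dashboard disabled to prevent port conflicts, that every task-runner fixture reuses instead of spinning up its own cluster. A hedged sketch of consuming such a fixture with DaskTaskRunner, assuming the suite-level prefect_test_harness fixture is active (test and task names are illustrative, not copied from the suite):

import distributed
import pytest

from prefect import flow, task
from prefect_dask import DaskTaskRunner


@pytest.fixture(scope="module")
def cluster():
    # One local cluster shared by the whole module; the dashboard is
    # disabled so concurrent test sessions do not fight over a port.
    with distributed.LocalCluster(dashboard_address=None) as cluster:
        yield cluster


@task
def double(x: int) -> int:
    return x * 2


def test_runs_tasks_on_existing_cluster(cluster: distributed.LocalCluster):
    @flow(task_runner=DaskTaskRunner(cluster=cluster))
    def my_flow() -> int:
        # submit() hands the task to the Dask runner; result() blocks until done.
        return double.submit(21).result()

    assert my_flow() == 42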
