Commit 829cf50: adding test
sanderegg committed Oct 17, 2023
1 parent f0410f7
Showing 5 changed files with 282 additions and 144 deletions.

services/autoscaling/src/simcore_service_autoscaling/modules/dask.py (150 changes: 68 additions & 82 deletions)
@@ -1,11 +1,16 @@
 import contextlib
 import logging
-from typing import Any, AsyncIterator, Coroutine, Final
+from collections.abc import AsyncIterator, Coroutine
+from typing import Any, Final
 
 import distributed
 from pydantic import AnyUrl, ByteSize, parse_obj_as
 
-from ..core.errors import DaskSchedulerNotFoundError, DaskWorkerNotFoundError
+from ..core.errors import (
+    DaskSchedulerNotFoundError,
+    DaskWorkerNotFoundError,
+    Ec2InvalidDnsNameError,
+)
 from ..models import (
     AssociatedInstance,
     DaskTask,
@@ -50,42 +55,35 @@ def _list_tasks(
             task.key: task.resource_restrictions for task in dask_scheduler.unrunnable
         }
 
-    with contextlib.suppress(DaskSchedulerNotFoundError):
-        async with _scheduler_client(url) as client:
-            list_of_tasks: dict[
-                DaskTaskId, DaskTaskResources
-            ] = await _wrap_client_async_routine(client.run_on_scheduler(_list_tasks))
-            _logger.info("found unrunnable tasks: %s", list_of_tasks)
-            return [
-                DaskTask(task_id=task_id, required_resources=task_resources)
-                for task_id, task_resources in list_of_tasks.items()
-            ]
-    return []
+    async with _scheduler_client(url) as client:
+        list_of_tasks: dict[
+            DaskTaskId, DaskTaskResources
+        ] = await _wrap_client_async_routine(client.run_on_scheduler(_list_tasks))
+        _logger.info("found unrunnable tasks: %s", list_of_tasks)
+        return [
+            DaskTask(task_id=task_id, required_resources=task_resources)
+            for task_id, task_resources in list_of_tasks.items()
+        ]
 
 
 async def list_processing_tasks(url: AnyUrl) -> list[DaskTaskId]:
-    try:
-        async with _scheduler_client(url) as client:
-            processing_tasks = set()
-            if worker_to_processing_tasks := await _wrap_client_async_routine(
-                client.processing()
-            ):
-                _logger.info(
-                    "cluster worker processing: %s", worker_to_processing_tasks
-                )
-                for tasks in worker_to_processing_tasks.values():
-                    processing_tasks |= set(tasks)
-
-            return list(processing_tasks)
-    except DaskSchedulerNotFoundError:
-        return []
+    async with _scheduler_client(url) as client:
+        processing_tasks = set()
+        if worker_to_processing_tasks := await _wrap_client_async_routine(
+            client.processing()
+        ):
+            _logger.info("cluster worker processing: %s", worker_to_processing_tasks)
+            for tasks in worker_to_processing_tasks.values():
+                processing_tasks |= set(tasks)
+
+        return list(processing_tasks)
 
 
 async def get_worker_still_has_results_in_memory(
     url: AnyUrl, ec2_instance: EC2InstanceData
 ) -> int:
-    try:
-        async with _scheduler_client(url) as client:
+    async with _scheduler_client(url) as client:
+        with contextlib.suppress(Ec2InvalidDnsNameError):
             scheduler_info = client.scheduler_info()
             if "workers" not in scheduler_info or not scheduler_info["workers"]:
                 raise DaskWorkerNotFoundError(url=url)
@@ -103,9 +101,7 @@ async def get_worker_still_has_results_in_memory(
             worker_metrics: dict[str, Any] = workers[node_worker_name]["metrics"]
             if worker_metrics.get("task_counts", {}) != {}:
                 return 1
-            return 0
-    except DaskSchedulerNotFoundError as exc:
-        raise DaskWorkerNotFoundError(url=url) from exc
+    return 0
 
 
 async def get_worker_used_resources(
@@ -119,59 +115,49 @@ def _get_worker_used_resources(
         used_resources[worker_name] = worker_state.used_resources
     return used_resources
 
-    try:
-        async with _scheduler_client(url) as client:
-            scheduler_info = client.scheduler_info()
-            if "workers" not in scheduler_info or not scheduler_info["workers"]:
-                raise DaskWorkerNotFoundError(url=url)
-            workers: dict[str, Any] = scheduler_info["workers"]
-            # dict is of type dask_worker_address: worker_details
-            node_worker_name = None
-
-            for worker_name, worker_details in workers.items():
-                if worker_details["host"] == node_ip_from_ec2_private_dns(ec2_instance):
-                    node_worker_name = worker_name
-                    break
-            if not node_worker_name:
-                raise DaskWorkerNotFoundError(url=url)
-
-            # now get the used resources
-            used_resources_per_worker: dict[
-                str, dict[str, Any]
-            ] = await _wrap_client_async_routine(
-                client.run_on_scheduler(_get_worker_used_resources)
-            )
-            if node_worker_name not in used_resources_per_worker:
-                raise DaskWorkerNotFoundError(url=url)
-            worker_used_resources = used_resources_per_worker[node_worker_name]
-            return Resources(
-                cpus=worker_used_resources.get("CPU", 0),
-                ram=parse_obj_as(ByteSize, worker_used_resources.get("RAM", 0)),
-            )
-
-    except DaskSchedulerNotFoundError as exc:
-        raise DaskWorkerNotFoundError(url=url) from exc
+    async with _scheduler_client(url) as client:
+        scheduler_info = client.scheduler_info()
+        if "workers" not in scheduler_info or not scheduler_info["workers"]:
+            raise DaskWorkerNotFoundError(url=url)
+        workers: dict[str, Any] = scheduler_info["workers"]
+        # dict is of type dask_worker_address: worker_details
+        node_worker_name = None
+
+        for worker_name, worker_details in workers.items():
+            if worker_details["host"] == node_ip_from_ec2_private_dns(ec2_instance):
+                node_worker_name = worker_name
+                break
+        if not node_worker_name:
+            raise DaskWorkerNotFoundError(url=url)
+
+        # now get the used resources
+        used_resources_per_worker: dict[
+            str, dict[str, Any]
+        ] = await _wrap_client_async_routine(
+            client.run_on_scheduler(_get_worker_used_resources)
+        )
+        if node_worker_name not in used_resources_per_worker:
+            raise DaskWorkerNotFoundError(url=url)
+        worker_used_resources = used_resources_per_worker[node_worker_name]
+        return Resources(
+            cpus=worker_used_resources.get("CPU", 0),
+            ram=parse_obj_as(ByteSize, worker_used_resources.get("RAM", 0)),
+        )
 
 
 async def compute_cluster_total_resources(
     url: AnyUrl, instances: list[AssociatedInstance]
 ) -> Resources:
-    try:
-        async with _scheduler_client(url) as client:
-            instance_hosts = (
-                node_ip_from_ec2_private_dns(i.ec2_instance) for i in instances
-            )
-            scheduler_info = client.scheduler_info()
-            if "workers" not in scheduler_info or not scheduler_info["workers"]:
-                raise DaskWorkerNotFoundError(url=url)
-            workers: dict[str, Any] = scheduler_info["workers"]
-            workers_to_consider: dict[str, Any] = {}
-            for worker_details in workers.values():
-                if worker_details["host"] not in instance_hosts:
-                    continue
-                worker_metrics = worker_details["metrics"]
-
-        return Resources.create_as_empty()
+    async with _scheduler_client(url) as client:
+        instance_hosts = (
+            node_ip_from_ec2_private_dns(i.ec2_instance) for i in instances
+        )
+        scheduler_info = client.scheduler_info()
+        if "workers" not in scheduler_info or not scheduler_info["workers"]:
+            raise DaskWorkerNotFoundError(url=url)
+        workers: dict[str, Any] = scheduler_info["workers"]
+        for worker_details in workers.values():
+            if worker_details["host"] not in instance_hosts:
+                continue
 
-    except DaskSchedulerNotFoundError:
-        return Resources.create_as_empty()
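
The changes above to dask.py drop the defensive try/except and contextlib.suppress wrappers around DaskSchedulerNotFoundError, so that error now propagates to callers instead of being silently converted into empty results. A minimal sketch of what caller-side handling could then look like, assuming the enclosing function in the first hunk is named list_unrunnable_tasks (the hunk header only shows its nested _list_tasks helper):

    # Illustrative sketch, not part of this commit: a caller that previously
    # relied on the in-module fallback now handles the scheduler error itself.
    from simcore_service_autoscaling.core.errors import DaskSchedulerNotFoundError
    from simcore_service_autoscaling.modules import dask

    async def safe_list_unrunnable_tasks(url):
        try:
            return await dask.list_unrunnable_tasks(url)
        except DaskSchedulerNotFoundError:
            return []  # the fallback the removed suppress() used to provide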

services/autoscaling/tests/unit/conftest.py (30 changes: 29 additions & 1 deletion)
@@ -12,6 +12,7 @@
 from typing import Any, Final, cast
 
 import aiodocker
+import distributed
 import httpx
 import psutil
 import pytest
@@ -41,7 +42,7 @@
 from settings_library.rabbit import RabbitSettings
 from simcore_service_autoscaling.core.application import create_app
 from simcore_service_autoscaling.core.settings import ApplicationSettings, EC2Settings
-from simcore_service_autoscaling.models import Cluster
+from simcore_service_autoscaling.models import Cluster, DaskTaskResources
 from simcore_service_autoscaling.modules.docker import AutoscalingDocker
 from simcore_service_autoscaling.modules.ec2 import AutoscalingEC2, EC2InstanceData
 from tenacity import retry
@@ -678,6 +679,15 @@ def _creator(**overrides) -> EC2InstanceData:
     return _creator
 
 
+@pytest.fixture
+def fake_localhost_ec2_instance_data(
+    fake_ec2_instance_data: Callable[..., EC2InstanceData]
+) -> EC2InstanceData:
+    local_ip = get_localhost_ip()
+    fake_local_ec2_private_dns = f"ip-{local_ip.replace('.', '-')}.ec2.internal"
+    return fake_ec2_instance_data(aws_private_dns=fake_local_ec2_private_dns)
+
+
 @pytest.fixture
 async def mocked_redis_server(mocker: MockerFixture) -> None:
     mock_redis = FakeRedis()
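
The fake_localhost_ec2_instance_data fixture added above mimics the AWS private-DNS naming scheme (ip-<dashed-ip>.ec2.internal) so that the service's node_ip_from_ec2_private_dns helper can map the fake instance back to the locally running dask worker. As an illustration only, a round trip of that naming convention could look like the sketch below; this parser is a hypothetical stand-in, not the service's actual helper, which, per the imports in dask.py, apparently raises Ec2InvalidDnsNameError on malformed names:

    import re

    # Hypothetical illustration of the ip-<dashed-ip>.ec2.internal convention
    _EC2_PRIVATE_DNS_RE = re.compile(r"^ip-(?P<ip>\d{1,3}(?:-\d{1,3}){3})\.ec2\.internal$")

    def ip_from_ec2_private_dns(private_dns: str) -> str:
        """Turn 'ip-10-0-1-2.ec2.internal' back into '10.0.1.2'."""
        match = _EC2_PRIVATE_DNS_RE.match(private_dns)
        if match is None:
            raise ValueError(f"not an EC2 private DNS name: {private_dns!r}")
        return match.group("ip").replace("-", ".")

    assert ip_from_ec2_private_dns("ip-10-0-1-2.ec2.internal") == "10.0.1.2"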
@@ -700,3 +710,21 @@ def _creator(**cluter_overrides) -> Cluster:
     )
 
     return _creator
+
+
+@pytest.fixture
+async def create_dask_task(
+    dask_spec_cluster_client: distributed.Client,
+) -> Callable[[DaskTaskResources], distributed.Future]:
+    def _remote_pytest_fct(x: int, y: int) -> int:
+        return x + y
+
+    def _creator(required_resources: DaskTaskResources) -> distributed.Future:
+        # NOTE: pure will ensure dask does not re-use the task results if we run it several times
+        future = dask_spec_cluster_client.submit(
+            _remote_pytest_fct, 23, 43, resources=required_resources, pure=False
+        )
+        assert future
+        return future
+
+    return _creator
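
For context, a sketch of how the tests below use the create_dask_task fixture; the "RAM" resource key and the ByteSize parsing mirror test_cluster_scaling_with_task_with_too_much_resources_starts_nothing:

    # Usage sketch: submit a task whose resource requirements no worker can
    # satisfy, so it stays unrunnable and exercises the scale-up logic.
    dask_future = create_dask_task({"RAM": int(parse_obj_as(ByteSize, "12800GiB"))})
    assert dask_future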

(third changed file: filename not loaded in this view)
@@ -22,7 +22,6 @@
 from models_library.rabbitmq_messages import RabbitAutoscalingStatusMessage
 from pydantic import ByteSize, parse_obj_as
 from pytest_mock import MockerFixture
-from pytest_simcore.helpers.utils_docker import get_localhost_ip
 from pytest_simcore.helpers.utils_envs import EnvVarsDict, setenvs_from_dict
 from simcore_service_autoscaling.core.settings import ApplicationSettings
 from simcore_service_autoscaling.models import (
@@ -89,24 +88,6 @@ def empty_cluster(cluster: Callable[..., Cluster]) -> Cluster:
     return cluster()
 
 
-@pytest.fixture
-async def create_dask_task(
-    dask_spec_cluster_client: distributed.Client,
-) -> Callable[[DaskTaskResources], distributed.Future]:
-    def _remote_pytest_fct(x: int, y: int) -> int:
-        return x + y
-
-    def _creator(required_resources: DaskTaskResources) -> distributed.Future:
-        # NOTE: pure will ensure dask does not re-use the task results if we run it several times
-        future = dask_spec_cluster_client.submit(
-            _remote_pytest_fct, 23, 43, resources=required_resources, pure=False
-        )
-        assert future
-        return future
-
-    return _creator
-
-
 async def _assert_ec2_instances(
     ec2_client: EC2Client,
     *,
@@ -265,14 +246,18 @@ async def test_cluster_scaling_with_no_tasks_does_nothing(
     mock_start_aws_instance: mock.Mock,
     mock_terminate_instances: mock.Mock,
     mock_rabbitmq_post_message: mock.Mock,
+    dask_spec_local_cluster: distributed.SpecCluster,
 ):
     await auto_scale_cluster(
         app=initialized_app, auto_scaling_mode=ComputationalAutoscaling()
     )
     mock_start_aws_instance.assert_not_called()
     mock_terminate_instances.assert_not_called()
     _assert_rabbit_autoscaling_message_sent(
-        mock_rabbitmq_post_message, app_settings, initialized_app
+        mock_rabbitmq_post_message,
+        app_settings,
+        initialized_app,
+        dask_spec_local_cluster.scheduler_address,
     )
 
@@ -284,6 +269,7 @@ async def test_cluster_scaling_with_task_with_too_much_resources_starts_nothing(
     mock_start_aws_instance: mock.Mock,
     mock_terminate_instances: mock.Mock,
     mock_rabbitmq_post_message: mock.Mock,
+    dask_spec_local_cluster: distributed.SpecCluster,
 ):
     # create a task that needs too much power
     dask_future = create_dask_task({"RAM": int(parse_obj_as(ByteSize, "12800GiB"))})
@@ -295,7 +281,10 @@
     mock_start_aws_instance.assert_not_called()
     mock_terminate_instances.assert_not_called()
     _assert_rabbit_autoscaling_message_sent(
-        mock_rabbitmq_post_message, app_settings, initialized_app
+        mock_rabbitmq_post_message,
+        app_settings,
+        initialized_app,
+        dask_spec_local_cluster.scheduler_address,
     )
 
@@ -312,6 +301,7 @@ async def test_cluster_scaling_up(
     mock_set_node_availability: mock.Mock,
     mock_compute_node_used_resources: mock.Mock,
     mocker: MockerFixture,
+    dask_spec_local_cluster: distributed.SpecCluster,
 ):
     # we have nothing running now
     all_instances = await ec2_client.describe_instances()
@@ -345,6 +335,7 @@
         mock_rabbitmq_post_message,
         app_settings,
         initialized_app,
+        dask_spec_local_cluster.scheduler_address,
         instances_running=0,
         instances_pending=1,
     )
@@ -459,6 +450,7 @@ async def test_cluster_scaling_up_starts_multiple_instances(
     mock_rabbitmq_post_message: mock.Mock,
     mock_find_node_with_name: mock.Mock,
     mock_set_node_availability: mock.Mock,
+    dask_spec_local_cluster: distributed.SpecCluster,
 ):
     # we have nothing running now
     all_instances = await ec2_client.describe_instances()
@@ -500,6 +492,7 @@
         mock_rabbitmq_post_message,
         app_settings,
         initialized_app,
+        dask_spec_local_cluster.scheduler_address,
         instances_pending=scale_up_params.expected_num_instances,
     )
     mock_rabbitmq_post_message.reset_mock()
@@ -508,13 +501,11 @@
 @pytest.fixture
 def fake_associated_host_instance(
     host_node: DockerNode,
-    fake_ec2_instance_data: Callable[..., EC2InstanceData],
+    fake_localhost_ec2_instance_data: EC2InstanceData,
 ) -> AssociatedInstance:
-    local_ip = get_localhost_ip()
-    fake_local_ec2_private_dns = f"ip-{local_ip.replace('.', '-')}.ec2.internal"
     return AssociatedInstance(
         host_node,
-        fake_ec2_instance_data(aws_private_dns=fake_local_ec2_private_dns),
+        fake_localhost_ec2_instance_data,
     )
 

(diff truncated: remaining changed files in this commit are not loaded in this view)