Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🎨Clusters keeper/use ssm (🚨change in private clusters) #6361

Merged
merged 14 commits into from
Sep 20, 2024
1 change: 1 addition & 0 deletions .env-devel
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ CLUSTERS_KEEPER_COMPUTATIONAL_BACKEND_DOCKER_IMAGE_TAG=master-github-latest
CLUSTERS_KEEPER_DASK_NTHREADS=0
CLUSTERS_KEEPER_DASK_WORKER_SATURATION=inf
CLUSTERS_KEEPER_EC2_ACCESS=null
CLUSTERS_KEEPER_SSM_ACCESS=null
CLUSTERS_KEEPER_EC2_INSTANCES_PREFIX=""
CLUSTERS_KEEPER_LOGLEVEL=WARNING
CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION=5
Expand Down
3 changes: 3 additions & 0 deletions packages/models-library/src/models_library/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ class Config(BaseAuthentication.Config):
class NoAuthentication(BaseAuthentication):
type: Literal["none"] = "none"

class Config(BaseAuthentication.Config):
schema_extra: ClassVar[dict[str, Any]] = {"examples": [{"type": "none"}]}


class TLSAuthentication(BaseAuthentication):
type: Literal["tls"] = "tls"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@router.get("/", include_in_schema=True, response_class=PlainTextResponse)
async def health_check():
# NOTE: sync url in docker/healthcheck.py with this entrypoint!
return f"{__name__}.health_check@{datetime.datetime.now(datetime.timezone.utc).isoformat()}"
return f"{__name__}.health_check@{datetime.datetime.now(datetime.UTC).isoformat()}"
sanderegg marked this conversation as resolved.
Show resolved Hide resolved


class _ComponentStatus(BaseModel):
Expand All @@ -33,25 +33,34 @@ class _StatusGet(BaseModel):
rabbitmq: _ComponentStatus
ec2: _ComponentStatus
redis_client_sdk: _ComponentStatus
ssm: _ComponentStatus


@router.get("/status", include_in_schema=True, response_model=_StatusGet)
async def get_status(app: Annotated[FastAPI, Depends(get_app)]) -> _StatusGet:
return _StatusGet(
rabbitmq=_ComponentStatus(
is_enabled=is_rabbitmq_enabled(app),
is_responsive=await get_rabbitmq_client(app).ping()
if is_rabbitmq_enabled(app)
else False,
is_responsive=(
await get_rabbitmq_client(app).ping()
if is_rabbitmq_enabled(app)
else False
),
),
ec2=_ComponentStatus(
is_enabled=bool(app.state.ec2_client),
is_responsive=await app.state.ec2_client.ping()
if app.state.ec2_client
else False,
is_responsive=(
await app.state.ec2_client.ping() if app.state.ec2_client else False
),
),
redis_client_sdk=_ComponentStatus(
is_enabled=bool(app.state.redis_client_sdk),
is_responsive=await get_redis_client(app).ping(),
),
ssm=_ComponentStatus(
is_enabled=(app.state.ssm_client is not None),
is_responsive=(
sanderegg marked this conversation as resolved.
Show resolved Hide resolved
await app.state.ssm_client.ping() if app.state.ssm_client else False
),
),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Final

from aws_library.ec2._models import AWSTagKey, AWSTagValue
from pydantic import parse_obj_as

DOCKER_STACK_DEPLOY_COMMAND_NAME: Final[str] = "private cluster docker deploy"
DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY: Final[AWSTagKey] = parse_obj_as(
AWSTagKey, "io.simcore.clusters-keeper.private_cluster_docker_deploy"
)

USER_ID_TAG_KEY: Final[AWSTagKey] = parse_obj_as(AWSTagKey, "user_id")
sanderegg marked this conversation as resolved.
Show resolved Hide resolved
WALLET_ID_TAG_KEY: Final[AWSTagKey] = parse_obj_as(AWSTagKey, "wallet_id")
ROLE_TAG_KEY: Final[AWSTagKey] = parse_obj_as(AWSTagKey, "role")
WORKER_ROLE_TAG_VALUE: Final[AWSTagValue] = parse_obj_as(AWSTagValue, "worker")
MANAGER_ROLE_TAG_VALUE: Final[AWSTagValue] = parse_obj_as(AWSTagValue, "manager")
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..modules.ec2 import setup as setup_ec2
from ..modules.rabbitmq import setup as setup_rabbitmq
from ..modules.redis import setup as setup_redis
from ..modules.ssm import setup as setup_ssm
from ..rpc.rpc_routes import setup_rpc_routes
from .settings import ApplicationSettings

Expand Down Expand Up @@ -55,6 +56,7 @@ def create_app(settings: ApplicationSettings) -> FastAPI:
setup_rabbitmq(app)
setup_rpc_routes(app)
setup_ec2(app)
setup_ssm(app)
setup_redis(app)
setup_clusters_management(app)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from settings_library.ec2 import EC2Settings
from settings_library.rabbit import RabbitSettings
from settings_library.redis import RedisSettings
from settings_library.ssm import SSMSettings
from settings_library.tracing import TracingSettings
from settings_library.utils_logging import MixinLoggingSettings
from types_aiobotocore_ec2.literals import InstanceTypeType
Expand All @@ -50,6 +51,21 @@ class Config(EC2Settings.Config):
}


class ClustersKeeperSSMSettings(SSMSettings):
class Config(SSMSettings.Config):
env_prefix = CLUSTERS_KEEPER_ENV_PREFIX
sanderegg marked this conversation as resolved.
Show resolved Hide resolved

schema_extra: ClassVar[dict[str, Any]] = { # type: ignore[misc]
"examples": [
{
f"{CLUSTERS_KEEPER_ENV_PREFIX}{key}": var
for key, var in example.items()
}
for example in SSMSettings.Config.schema_extra["examples"]
],
}


class WorkersEC2InstancesSettings(BaseCustomSettings):
WORKERS_EC2_INSTANCES_ALLOWED_TYPES: dict[str, EC2InstanceBootSpecific] = Field(
...,
Expand Down Expand Up @@ -183,6 +199,12 @@ class PrimaryEC2InstancesSettings(BaseCustomSettings):
"that take longer than this time will be terminated as sometimes it happens that EC2 machine fail on start.",
)

PRIMARY_EC2_INSTANCES_DOCKER_DEFAULT_ADDRESS_POOL: str = Field(
default="172.20.0.0/14",
description="defines the docker swarm default address pool in CIDR format "
"(see https://docs.docker.com/reference/cli/docker/swarm/init/)",
)

@validator("PRIMARY_EC2_INSTANCES_ALLOWED_TYPES")
@classmethod
def check_valid_instance_names(
Expand Down Expand Up @@ -250,6 +272,10 @@ class ApplicationSettings(BaseCustomSettings, MixinLoggingSettings):
auto_default_from_env=True
)

CLUSTERS_KEEPER_SSM_ACCESS: ClustersKeeperSSMSettings | None = Field(
auto_default_from_env=True
)

CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES: PrimaryEC2InstancesSettings | None = Field(
auto_default_from_env=True
)
Expand Down Expand Up @@ -285,9 +311,11 @@ class ApplicationSettings(BaseCustomSettings, MixinLoggingSettings):
"(default to seconds, or see https://pydantic-docs.helpmanual.io/usage/types/#datetime-types for string formating)",
)

CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION: NonNegativeInt = Field(
default=5,
description="Max number of missed heartbeats before a cluster is terminated",
CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION: NonNegativeInt = (
Field(
default=5,
description="Max number of missed heartbeats before a cluster is terminated",
)
)

CLUSTERS_KEEPER_COMPUTATIONAL_BACKEND_DOCKER_IMAGE_TAG: str = Field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def _get_primary_ec2_params(
ec2_instance_types: list[
EC2InstanceType
] = await ec2_client.get_ec2_instance_capabilities(
instance_type_names=[ec2_type_name]
instance_type_names={ec2_type_name}
)
assert ec2_instance_types # nosec
assert len(ec2_instance_types) == 1 # nosec
Expand All @@ -72,15 +72,7 @@ async def create_cluster(
tags=creation_ec2_tags(app_settings, user_id=user_id, wallet_id=wallet_id),
startup_script=create_startup_script(
app_settings,
cluster_machines_name_prefix=get_cluster_name(
app_settings, user_id=user_id, wallet_id=wallet_id, is_manager=False
),
ec2_boot_specific=ec2_instance_boot_specs,
additional_custom_tags={
AWSTagKey("user_id"): AWSTagValue(f"{user_id}"),
AWSTagKey("wallet_id"): AWSTagValue(f"{wallet_id}"),
AWSTagKey("role"): AWSTagValue("worker"),
},
),
ami_id=ec2_instance_boot_specs.ami_id,
key_name=app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES.PRIMARY_EC2_INSTANCES_KEY_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,40 @@

import arrow
from aws_library.ec2 import AWSTagKey, EC2InstanceData
from aws_library.ec2._models import AWSTagValue
from fastapi import FastAPI
from models_library.users import UserID
from models_library.wallets import WalletID
from pydantic import parse_obj_as
from servicelib.logging_utils import log_catch

from servicelib.utils import limited_gather

from ..constants import (
DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY,
DOCKER_STACK_DEPLOY_COMMAND_NAME,
ROLE_TAG_KEY,
USER_ID_TAG_KEY,
WALLET_ID_TAG_KEY,
WORKER_ROLE_TAG_VALUE,
)
from ..core.settings import get_application_settings
from ..modules.clusters import (
delete_clusters,
get_all_clusters,
get_cluster_workers,
set_instance_heartbeat,
)
from ..utils.clusters import create_deploy_cluster_stack_script
from ..utils.dask import get_scheduler_auth, get_scheduler_url
from ..utils.ec2 import HEARTBEAT_TAG_KEY
from ..utils.ec2 import (
HEARTBEAT_TAG_KEY,
get_cluster_name,
user_id_from_instance_tags,
wallet_id_from_instance_tags,
)
from .dask import is_scheduler_busy, ping_scheduler
from .ec2 import get_ec2_client
from .ssm import get_ssm_client

_logger = logging.getLogger(__name__)

Expand All @@ -42,8 +60,8 @@ def _get_instance_last_heartbeat(instance: EC2InstanceData) -> datetime.datetime
async def _get_all_associated_worker_instances(
app: FastAPI,
primary_instances: Iterable[EC2InstanceData],
) -> list[EC2InstanceData]:
worker_instances = []
) -> set[EC2InstanceData]:
worker_instances: set[EC2InstanceData] = set()
for instance in primary_instances:
assert "user_id" in instance.tags # nosec
user_id = UserID(instance.tags[_USER_ID_TAG_KEY])
Expand All @@ -55,20 +73,20 @@ async def _get_all_associated_worker_instances(
else None
)

worker_instances.extend(
worker_instances.update(
await get_cluster_workers(app, user_id=user_id, wallet_id=wallet_id)
)
return worker_instances


async def _find_terminateable_instances(
app: FastAPI, instances: Iterable[EC2InstanceData]
) -> list[EC2InstanceData]:
) -> set[EC2InstanceData]:
app_settings = get_application_settings(app)
assert app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES # nosec

# get the corresponding ec2 instance data
terminateable_instances: list[EC2InstanceData] = []
terminateable_instances: set[EC2InstanceData] = set()

time_to_wait_before_termination = (
app_settings.CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION
Expand All @@ -82,7 +100,7 @@ async def _find_terminateable_instances(
elapsed_time_since_heartbeat = arrow.utcnow().datetime - last_heartbeat
allowed_time_to_wait = time_to_wait_before_termination
if elapsed_time_since_heartbeat >= allowed_time_to_wait:
terminateable_instances.append(instance)
terminateable_instances.add(instance)
else:
_logger.info(
"%s has still %ss before being terminateable",
Expand All @@ -93,14 +111,14 @@ async def _find_terminateable_instances(
elapsed_time_since_startup = arrow.utcnow().datetime - instance.launch_time
allowed_time_to_wait = startup_delay
if elapsed_time_since_startup >= allowed_time_to_wait:
terminateable_instances.append(instance)
terminateable_instances.add(instance)

# get all terminateable instances associated worker instances
worker_instances = await _get_all_associated_worker_instances(
app, terminateable_instances
)

return terminateable_instances + worker_instances
return terminateable_instances.union(worker_instances)


async def check_clusters(app: FastAPI) -> None:
Expand All @@ -112,6 +130,7 @@ async def check_clusters(app: FastAPI) -> None:
if await ping_scheduler(get_scheduler_url(instance), get_scheduler_auth(app))
}

# set intance heartbeat if scheduler is busy
for instance in connected_intances:
with log_catch(_logger, reraise=False):
# NOTE: some connected instance could in theory break between these 2 calls, therefore this is silenced and will
Expand All @@ -124,6 +143,7 @@ async def check_clusters(app: FastAPI) -> None:
f"{instance.id=} for {instance.tags=}",
)
await set_instance_heartbeat(app, instance=instance)
# clean any cluster that is not doing anything
if terminateable_instances := await _find_terminateable_instances(
app, connected_intances
):
Expand All @@ -138,7 +158,7 @@ async def check_clusters(app: FastAPI) -> None:
for instance in disconnected_instances
if _get_instance_last_heartbeat(instance) is None
}

# remove instances that were starting for too long
if terminateable_instances := await _find_terminateable_instances(
app, starting_instances
):
Expand All @@ -149,7 +169,72 @@ async def check_clusters(app: FastAPI) -> None:
)
await delete_clusters(app, instances=terminateable_instances)

# the other instances are broken (they were at some point connected but now not anymore)
# NOTE: transmit command to start docker swarm/stack if needed
# once the instance is connected to the SSM server,
# use ssm client to send the command to these instances,
# we send a command that contain:
# the docker-compose file in binary,
# the call to init the docker swarm and the call to deploy the stack
instances_in_need_of_deployment = {
i
for i in starting_instances - terminateable_instances
if DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY not in i.tags
}

if instances_in_need_of_deployment:
app_settings = get_application_settings(app)
ssm_client = get_ssm_client(app)
ec2_client = get_ec2_client(app)
instances_in_need_of_deployment_ssm_connection_state = await limited_gather(
*[
ssm_client.is_instance_connected_to_ssm_server(i.id)
for i in instances_in_need_of_deployment
],
reraise=False,
log=_logger,
limit=20,
)
ec2_connected_to_ssm_server = [
i
for i, c in zip(
instances_in_need_of_deployment,
instances_in_need_of_deployment_ssm_connection_state,
strict=True,
)
if c is True
]
started_instances_ready_for_command = ec2_connected_to_ssm_server
if started_instances_ready_for_command:
# we need to send 1 command per machine here, as the user_id/wallet_id changes
for i in started_instances_ready_for_command:
ssm_command = await ssm_client.send_command(
[i.id],
command=create_deploy_cluster_stack_script(
app_settings,
cluster_machines_name_prefix=get_cluster_name(
app_settings,
user_id=user_id_from_instance_tags(i.tags),
wallet_id=wallet_id_from_instance_tags(i.tags),
is_manager=False,
),
additional_custom_tags={
USER_ID_TAG_KEY: i.tags[USER_ID_TAG_KEY],
WALLET_ID_TAG_KEY: i.tags[WALLET_ID_TAG_KEY],
ROLE_TAG_KEY: WORKER_ROLE_TAG_VALUE,
},
),
command_name=DOCKER_STACK_DEPLOY_COMMAND_NAME,
)
await ec2_client.set_instances_tags(
started_instances_ready_for_command,
tags={
DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY: AWSTagValue(
ssm_command.command_id
),
},
)

# the remaining instances are broken (they were at some point connected but now not anymore)
broken_instances = disconnected_instances - starting_instances
if terminateable_instances := await _find_terminateable_instances(
app, broken_instances
Expand Down
Loading
Loading