diff --git a/docs/source/changes/218.respect_ssh_maxsessions.yaml b/docs/source/changes/218.respect_ssh_maxsessions.yaml new file mode 100644 index 00000000..d0bb08db --- /dev/null +++ b/docs/source/changes/218.respect_ssh_maxsessions.yaml @@ -0,0 +1,10 @@ +category: changed +summary: "SSHExecutor respects the remote MaxSessions via queueing" +description: | + The SSHExecutor is now aware of the sshd MaxSessions setting, which limits the number of + concurrent operations per connection. If more operations are requested at once, the + excess operations are queued until a session becomes available. +pull requests: +- 218 +issues: +- 217 diff --git a/setup.py b/setup.py index 86f0705d..db2183a9 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,7 @@ def get_cryptography_version(): "aioprometheus>=21.9.0", "kubernetes_asyncio", "pydantic", + "asyncstdlib", ], extras_require={ "docs": [ diff --git a/tardis/utilities/executors/sshexecutor.py b/tardis/utilities/executors/sshexecutor.py index 3003b0e7..40c6785c 100644 --- a/tardis/utilities/executors/sshexecutor.py +++ b/tardis/utilities/executors/sshexecutor.py @@ -1,3 +1,4 @@ +from typing import Optional from ...configuration.utilities import enable_yaml_load from ...exceptions.executorexceptions import CommandExecutionFailure from ...interfaces.executor import Executor @@ -5,13 +6,39 @@ import asyncio import asyncssh +from asyncstdlib import ( + ExitStack as AsyncExitStack, + contextmanager as asynccontextmanager, +) + + +async def probe_max_session(connection: asyncssh.SSHClientConnection): + """ + Probe the sshd `MaxSessions`, i.e. the multiplexing limit per connection + """ + sessions = 0 + # It does not actually matter what kind of session we open here, but: + # - it should stay open without a separate task to manage it + # - it should reliably and promptly clean up when done probing + # `create_process` is a bit heavy but does all that. 
+ async with AsyncExitStack() as aes: + try: + while True: + await aes.enter_context(await connection.create_process()) + sessions += 1 + except asyncssh.ChannelOpenError: + pass + return sessions @enable_yaml_load("!SSHExecutor") class SSHExecutor(Executor): def __init__(self, **parameters): self._parameters = parameters - self._ssh_connection = None + # the current SSH connection or None if it must be (re-)established + self._ssh_connection: Optional[asyncssh.SSHClientConnection] = None + # the bound on MaxSession running concurrently + self._session_bound: Optional[asyncio.Semaphore] = None self._lock = None async def _establish_connection(self): @@ -28,16 +55,31 @@ async def _establish_connection(self): return await asyncssh.connect(**self._parameters) @property - async def ssh_connection(self): + @asynccontextmanager + async def bounded_connection(self): + """ + Get the current connection with a single reserved session slot + + This is a context manager that guards the current + :py:class:`~asyncssh.SSHClientConnection` + so that only `MaxSessions` commands run at once. + """ if self._ssh_connection is None: async with self.lock: - # check that connection has not yet been initialize in a different task + # check that connection has not been initialized in a different task while self._ssh_connection is None: self._ssh_connection = await self._establish_connection() - return self._ssh_connection + max_session = await probe_max_session(self._ssh_connection) + self._session_bound = asyncio.Semaphore(value=max_session) + assert self._ssh_connection is not None + assert self._session_bound is not None + bound, session = self._session_bound, self._ssh_connection + async with bound: + yield session @property def lock(self): + """Lock protecting the connection""" # Create lock once tardis event loop is running. 
# To avoid got Future attached to a different loop exception if self._lock is None: @@ -45,31 +87,35 @@ def lock(self): return self._lock async def run_command(self, command, stdin_input=None): - ssh_connection = await self.ssh_connection - try: - response = await ssh_connection.run( - command, check=True, input=stdin_input and stdin_input.encode() - ) - except asyncssh.ProcessError as pe: - raise CommandExecutionFailure( - message=f"Run command {command} via SSHExecutor failed", - exit_code=pe.exit_status, - stdin=stdin_input, - stdout=pe.stdout, - stderr=pe.stderr, - ) from pe - except asyncssh.ChannelOpenError as coe: - # Broken connection will be replaced by a new connection during next call - self._ssh_connection = None - raise CommandExecutionFailure( - message=f"Could not run command {command} due to SSH failure: {coe}", - exit_code=255, - stdout="", - stderr="SSH Broken Connection", - ) from coe - else: - return AttributeDict( - stdout=response.stdout, - stderr=response.stderr, - exit_code=response.exit_status, - ) + async with self.bounded_connection as ssh_connection: + try: + response = await ssh_connection.run( + command, check=True, input=stdin_input and stdin_input.encode() + ) + except asyncssh.ProcessError as pe: + raise CommandExecutionFailure( + message=f"Run command {command} via SSHExecutor failed", + exit_code=pe.exit_status, + stdin=stdin_input, + stdout=pe.stdout, + stderr=pe.stderr, + ) from pe + except asyncssh.ChannelOpenError as coe: + # clear broken connection to get it replaced + # by a new connection during next command + if ssh_connection is self._ssh_connection: + self._ssh_connection = None + raise CommandExecutionFailure( + message=( + f"Could not run command {command} due to SSH failure: {coe}" + ), + exit_code=255, + stdout="", + stderr="SSH Broken Connection", + ) from coe + else: + return AttributeDict( + stdout=response.stdout, + stderr=response.stderr, + exit_code=response.exit_status, + ) diff --git 
a/tests/utilities_t/executors_t/test_sshexecutor.py b/tests/utilities_t/executors_t/test_sshexecutor.py index 37a44769..d47b65ce 100644 --- a/tests/utilities_t/executors_t/test_sshexecutor.py +++ b/tests/utilities_t/executors_t/test_sshexecutor.py @@ -1,6 +1,6 @@ from tests.utilities.utilities import async_return, run_async from tardis.utilities.attributedict import AttributeDict -from tardis.utilities.executors.sshexecutor import SSHExecutor +from tardis.utilities.executors.sshexecutor import SSHExecutor, probe_max_session from tardis.exceptions.executorexceptions import CommandExecutionFailure from asyncssh import ChannelOpenError, ConnectionLost, DisconnectError, ProcessError @@ -10,18 +10,63 @@ import asyncio import yaml +import contextlib +from asyncstdlib import contextmanager as asynccontextmanager + + +DEFAULT_MAX_SESSIONS = 10 class MockConnection(object): - def __init__(self, exception=None, **kwargs): + def __init__(self, exception=None, __max_sessions=DEFAULT_MAX_SESSIONS, **kwargs): self.exception = exception and exception(**kwargs) + self.max_sessions = __max_sessions + self.current_sessions = 0 + + @contextlib.contextmanager + def _multiplex_session(self): + if self.current_sessions >= self.max_sessions: + raise ChannelOpenError(code=2, reason="open failed") + self.current_sessions += 1 + try: + yield + finally: + self.current_sessions -= 1 async def run(self, command, input=None, **kwargs): - if self.exception: - raise self.exception - return AttributeDict( - stdout=input and input.decode(), stderr="TestError", exit_status=0 - ) + with self._multiplex_session(): + if self.exception: + raise self.exception + if command.startswith("sleep"): + _, duration = command.split() + await asyncio.sleep(float(duration)) + elif command != "Test": + raise ValueError(f"Unsupported mock command: {command}") + return AttributeDict( + stdout=input and input.decode(), stderr="TestError", exit_status=0 + ) + + async def create_process(self): + @asynccontextmanager + 
async def fake_process(): + with self._multiplex_session(): + yield + + return fake_process() + + +class TestSSHExecutorUtilities(TestCase): + def test_max_sessions(self): + with self.subTest(sessions="default"): + self.assertEqual( + DEFAULT_MAX_SESSIONS, run_async(probe_max_session, MockConnection()) + ) + for expected in (1, 9, 11, 20, 100): + with self.subTest(sessions=expected): + self.assertEqual( + expected, + run_async(probe_max_session, MockConnection(None, expected)), + ) class TestSSHExecutor(TestCase): @@ -80,23 +125,43 @@ def test_establish_connection(self): self.mock_asyncssh.connect.side_effect = None def test_connection_property(self): - async def helper_coroutine(): - return await self.executor.ssh_connection + async def force_connection(): + async with self.executor.bounded_connection as connection: + return connection self.assertIsNone(self.executor._ssh_connection) - run_async(helper_coroutine) - + run_async(force_connection) self.assertIsInstance(self.executor._ssh_connection, MockConnection) - current_ssh_connection = self.executor._ssh_connection - - run_async(helper_coroutine) - + run_async(force_connection) + # make sure the connection is not needlessly replaced self.assertEqual(self.executor._ssh_connection, current_ssh_connection) def test_lock(self): self.assertIsInstance(self.executor.lock, asyncio.Lock) + def test_connection_queueing(self): + async def is_queued(n: int): + """Check whether the n'th command is queued or runs immediately""" + background = [ + asyncio.ensure_future(self.executor.run_command("sleep 5")) + for _ in range(n - 1) + ] + # probe can only finish in time if it is not queued + probe = asyncio.ensure_future(self.executor.run_command("sleep 0.01")) + await asyncio.sleep(0.05) + queued = not probe.done() + for task in background + [probe]: + task.cancel() + return queued + + for sessions in (1, 8, 10, 12, 20): + with self.subTest(sessions=sessions): + self.assertEqual( + sessions > DEFAULT_MAX_SESSIONS, + 
run_async(is_queued, sessions), + ) + def test_run_command(self): self.assertIsNone(run_async(self.executor.run_command, command="Test").stdout) self.mock_asyncssh.connect.assert_called_with(