Skip to content

Commit

Permalink
Merge pull request #218 from MatterMiners/ssh-maxsessions
Browse files Browse the repository at this point in the history
SSHExecutor respects MaxSessions
  • Loading branch information
giffels authored Nov 19, 2021
2 parents 2afcbd4 + 2d31c69 commit c532a19
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 47 deletions.
10 changes: 10 additions & 0 deletions docs/source/changes/218.respect_ssh_maxsessions.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
category: changed
summary: "SSHExecutor respects the remote MaxSessions via queueing"
description: |
The SSHExecutor is now aware of the sshd MaxSessions setting, which limits the
number of concurrent operations per connection. If more operations are requested
at once, the excess operations are queued until a session becomes available.
pull requests:
- 218
issues:
- 217
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def get_cryptography_version():
"aioprometheus>=21.9.0",
"kubernetes_asyncio",
"pydantic",
"asyncstdlib",
],
extras_require={
"docs": [
Expand Down
110 changes: 78 additions & 32 deletions tardis/utilities/executors/sshexecutor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,44 @@
from typing import Optional
from ...configuration.utilities import enable_yaml_load
from ...exceptions.executorexceptions import CommandExecutionFailure
from ...interfaces.executor import Executor
from ..attributedict import AttributeDict

import asyncio
import asyncssh
from asyncstdlib import (
ExitStack as AsyncExitStack,
contextmanager as asynccontextmanager,
)


async def probe_max_session(connection: asyncssh.SSHClientConnection):
    """
    Probe the sshd `MaxSessions`, i.e. the multiplexing limit per connection
    """
    # Keep opening sessions until the server refuses one; the number of
    # successful opens is the limit. Any session type would do, but it should
    # stay open without a separate managing task and clean up reliably and
    # promptly once probing is done -- `create_process` is a bit heavy
    # but satisfies all of that.
    opened = 0
    async with AsyncExitStack() as stack:
        while True:
            try:
                process = await connection.create_process()
                await stack.enter_context(process)
            except asyncssh.ChannelOpenError:
                break
            opened += 1
        return opened


@enable_yaml_load("!SSHExecutor")
class SSHExecutor(Executor):
def __init__(self, **parameters):
self._parameters = parameters
self._ssh_connection = None
# the current SSH connection or None if it must be (re-)established
self._ssh_connection: Optional[asyncssh.SSHClientConnection] = None
# the bound on MaxSession running concurrently
self._session_bound: Optional[asyncio.Semaphore] = None
self._lock = None

async def _establish_connection(self):
Expand All @@ -28,48 +55,67 @@ async def _establish_connection(self):
return await asyncssh.connect(**self._parameters)

@property
async def ssh_connection(self):
@asynccontextmanager
async def bounded_connection(self):
"""
Get the current connection with a single reserved session slot
This is a context manager that guards the current
:py:class:`~asyncssh.SSHClientConnection`
so that only `MaxSessions` commands run at once.
"""
if self._ssh_connection is None:
async with self.lock:
# check that connection has not yet been initialize in a different task
# check that connection has not been initialized in a different task
while self._ssh_connection is None:
self._ssh_connection = await self._establish_connection()
return self._ssh_connection
max_session = await probe_max_session(self._ssh_connection)
self._session_bound = asyncio.Semaphore(value=max_session)
assert self._ssh_connection is not None
assert self._session_bound is not None
bound, session = self._session_bound, self._ssh_connection
async with bound:
yield session

@property
def lock(self):
"""Lock protecting the connection"""
# Create lock once tardis event loop is running.
# To avoid got Future <Future pending> attached to a different loop exception
if self._lock is None:
self._lock = asyncio.Lock()
return self._lock

async def run_command(self, command, stdin_input=None):
ssh_connection = await self.ssh_connection
try:
response = await ssh_connection.run(
command, check=True, input=stdin_input and stdin_input.encode()
)
except asyncssh.ProcessError as pe:
raise CommandExecutionFailure(
message=f"Run command {command} via SSHExecutor failed",
exit_code=pe.exit_status,
stdin=stdin_input,
stdout=pe.stdout,
stderr=pe.stderr,
) from pe
except asyncssh.ChannelOpenError as coe:
# Broken connection will be replaced by a new connection during next call
self._ssh_connection = None
raise CommandExecutionFailure(
message=f"Could not run command {command} due to SSH failure: {coe}",
exit_code=255,
stdout="",
stderr="SSH Broken Connection",
) from coe
else:
return AttributeDict(
stdout=response.stdout,
stderr=response.stderr,
exit_code=response.exit_status,
)
async with self.bounded_connection as ssh_connection:
try:
response = await ssh_connection.run(
command, check=True, input=stdin_input and stdin_input.encode()
)
except asyncssh.ProcessError as pe:
raise CommandExecutionFailure(
message=f"Run command {command} via SSHExecutor failed",
exit_code=pe.exit_status,
stdin=stdin_input,
stdout=pe.stdout,
stderr=pe.stderr,
) from pe
except asyncssh.ChannelOpenError as coe:
# clear broken connection to get it replaced
# by a new connection during next command
if ssh_connection is self._ssh_connection:
self._ssh_connection = None
raise CommandExecutionFailure(
message=(
f"Could not run command {command} due to SSH failure: {coe}"
),
exit_code=255,
stdout="",
stderr="SSH Broken Connection",
) from coe
else:
return AttributeDict(
stdout=response.stdout,
stderr=response.stderr,
exit_code=response.exit_status,
)
95 changes: 80 additions & 15 deletions tests/utilities_t/executors_t/test_sshexecutor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tests.utilities.utilities import async_return, run_async
from tardis.utilities.attributedict import AttributeDict
from tardis.utilities.executors.sshexecutor import SSHExecutor
from tardis.utilities.executors.sshexecutor import SSHExecutor, probe_max_session
from tardis.exceptions.executorexceptions import CommandExecutionFailure

from asyncssh import ChannelOpenError, ConnectionLost, DisconnectError, ProcessError
Expand All @@ -10,18 +10,63 @@

import asyncio
import yaml
import contextlib
from asyncstdlib import contextmanager as asynccontextmanager


DEFAULT_MAX_SESSIONS = 10


class MockConnection(object):
    """Mock of an :py:class:`asyncssh` connection enforcing a MaxSessions limit"""

    def __init__(self, exception=None, __max_sessions=DEFAULT_MAX_SESSIONS, **kwargs):
        # exception type to raise on `run`, or None for success
        self.exception = exception and exception(**kwargs)
        # leading double underscore avoids collisions with keys in **kwargs
        self.max_sessions = __max_sessions
        self.current_sessions = 0

    @contextlib.contextmanager
    def _multiplex_session(self):
        """Occupy one of the limited sessions for the duration of the context"""
        # mimic sshd refusing to open a channel beyond MaxSessions
        if self.current_sessions >= self.max_sessions:
            raise ChannelOpenError(code=2, reason="open failed")
        self.current_sessions += 1
        try:
            yield
        finally:
            self.current_sessions -= 1

    async def run(self, command, input=None, **kwargs):
        """Simulate running a command while holding one session slot"""
        with self._multiplex_session():
            if self.exception:
                raise self.exception
            # only the commands used by the test suite are supported
            if command.startswith("sleep"):
                _, duration = command.split()
                await asyncio.sleep(float(duration))
            elif command != "Test":
                raise ValueError(f"Unsupported mock command: {command}")
            return AttributeDict(
                stdout=input and input.decode(), stderr="TestError", exit_status=0
            )

    async def create_process(self):
        """Provide a context manager holding one session slot while entered"""

        @asynccontextmanager
        async def fake_process():
            with self._multiplex_session():
                yield

        return fake_process()


class TestSSHExecutorUtilities(TestCase):
    def test_max_sessions(self):
        """`probe_max_session` must report the mock connection's session limit"""
        with self.subTest(sessions="default"):
            probed = run_async(probe_max_session, MockConnection())
            self.assertEqual(DEFAULT_MAX_SESSIONS, probed)
        for limit in (1, 9, 11, 20, 100):
            with self.subTest(sessions=limit):
                probed = run_async(probe_max_session, MockConnection(None, limit))
                self.assertEqual(limit, probed)


class TestSSHExecutor(TestCase):
Expand Down Expand Up @@ -80,23 +125,43 @@ def test_establish_connection(self):
self.mock_asyncssh.connect.side_effect = None

def test_connection_property(self):
async def helper_coroutine():
return await self.executor.ssh_connection
async def force_connection():
async with self.executor.bounded_connection as connection:
return connection

self.assertIsNone(self.executor._ssh_connection)
run_async(helper_coroutine)

run_async(force_connection)
self.assertIsInstance(self.executor._ssh_connection, MockConnection)

current_ssh_connection = self.executor._ssh_connection

run_async(helper_coroutine)

run_async(force_connection)
# make sure the connection is not needlessly replaced
self.assertEqual(self.executor._ssh_connection, current_ssh_connection)

def test_lock(self):
self.assertIsInstance(self.executor.lock, asyncio.Lock)

    def test_connection_queueing(self):
        # Commands within the MaxSessions limit must start immediately,
        # while commands beyond it must be queued until a slot frees up.
        async def is_queued(n: int):
            """Check whether the n'th concurrent command is queued or runs immediately"""
            # occupy n - 1 session slots with long-running commands
            background = [
                asyncio.ensure_future(self.executor.run_command("sleep 5"))
                for _ in range(n - 1)
            ]
            # probe can only finish in time if it is not queued
            probe = asyncio.ensure_future(self.executor.run_command("sleep 0.01"))
            await asyncio.sleep(0.05)
            queued = not probe.done()
            # clean up so no stray "sleep" tasks outlive the check
            for task in background + [probe]:
                task.cancel()
            return queued

        # DEFAULT_MAX_SESSIONS is 10: below/at the limit nothing is queued,
        # above it the probe command must wait
        for sessions in (1, 8, 10, 12, 20):
            with self.subTest(sessions=sessions):
                self.assertEqual(
                    sessions > DEFAULT_MAX_SESSIONS,
                    run_async(is_queued, sessions),
                )

def test_run_command(self):
self.assertIsNone(run_async(self.executor.run_command, command="Test").stdout)
self.mock_asyncssh.connect.assert_called_with(
Expand Down

0 comments on commit c532a19

Please sign in to comment.