Skip to content

Commit

Permalink
Pass StartSession response as env variable
Browse files Browse the repository at this point in the history
Pass the StartSession API response as environment variable to the
session-manager-plugin
  • Loading branch information
Yangtao-Hua authored and kdaily committed Nov 17, 2023
1 parent a453709 commit da9039a
Show file tree
Hide file tree
Showing 4 changed files with 511 additions and 98 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``ssm`` Session Manager",
"description": "Pass StartSession API response as environment variable to session-manager-plugin"
}
100 changes: 85 additions & 15 deletions awscli/customizations/sessionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@
import logging
import json
import errno
import os
import re

from subprocess import check_call
from subprocess import check_call, check_output
from awscli.compat import ignore_user_entered_signals
from awscli.clidriver import ServiceOperation, CLIOperationCaller

logger = logging.getLogger(__name__)

ERROR_MESSAGE = (
'SessionManagerPlugin is not found. ',
'Please refer to SessionManager Documentation here: ',
'http://docs.aws.amazon.com/console/systems-manager/',
'session-manager-plugin-not-found'
"SessionManagerPlugin is not found. ",
"Please refer to SessionManager Documentation here: ",
"http://docs.aws.amazon.com/console/systems-manager/",
"session-manager-plugin-not-found",
)


Expand All @@ -44,8 +46,43 @@ def add_custom_start_session(session, command_table, **kwargs):
)


class StartSessionCommand(ServiceOperation):
class VersionRequirement:
WHITESPACE_REGEX = re.compile(r"\s+")
SSM_SESSION_PLUGIN_VERSION_REGEX = re.compile(r"^\d+(\.\d+){0,3}$")

def __init__(self, min_version):
self.min_version = min_version

def meets_requirement(self, version):
ssm_plugin_version = self._sanitize_plugin_version(version)
if self._is_valid_version(ssm_plugin_version):
norm_version, norm_min_version = self._normalize(
ssm_plugin_version, self.min_version
)
return norm_version > norm_min_version
else:
return False

def _sanitize_plugin_version(self, plugin_version):
return re.sub(self.WHITESPACE_REGEX, "", plugin_version)

def _is_valid_version(self, plugin_version):
return bool(
self.SSM_SESSION_PLUGIN_VERSION_REGEX.match(plugin_version)
)

def _normalize(self, v1, v2):
v1_parts = [int(v) for v in v1.split(".")]
v2_parts = [int(v) for v in v2.split(".")]
while len(v1_parts) != len(v2_parts):
if len(v1_parts) - len(v2_parts) > 0:
v2_parts.append(0)
else:
v1_parts.append(0)
return v1_parts, v2_parts


class StartSessionCommand(ServiceOperation):
def create_help_command(self):
help_command = super(
StartSessionCommand, self).create_help_command()
Expand All @@ -55,8 +92,10 @@ def create_help_command(self):


class StartSessionCaller(CLIOperationCaller):
def invoke(self, service_name, operation_name, parameters,
parsed_globals):
LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR = "1.2.497.0"
DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE"

def invoke(self, service_name, operation_name, parameters, parsed_globals):
client = self._session.create_client(
service_name, region_name=parsed_globals.region,
endpoint_url=parsed_globals.endpoint_url,
Expand All @@ -70,8 +109,34 @@ def invoke(self, service_name, operation_name, parameters,
profile_name = self._session.profile \
if self._session.profile is not None else ''
endpoint_url = client.meta.endpoint_url
ssm_env_name = self.DEFAULT_SSM_ENV_NAME

try:
session_parameters = {
"SessionId": response["SessionId"],
"TokenValue": response["TokenValue"],
"StreamUrl": response["StreamUrl"],
}
start_session_response = json.dumps(session_parameters)

plugin_version = check_output(
["session-manager-plugin", "--version"], text=True
)
env = os.environ.copy()

# Check if this plugin supports passing the start session response
# as an environment variable name. If it does, it will set the
# value to the response from the start_session operation to the env
# variable defined in DEFAULT_SSM_ENV_NAME. If the session plugin
# version is invalid or older than the version defined in
# LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR, it will fall back to
# passing the start_session response directly.
version_requirement = VersionRequirement(
min_version=self.LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR
)
if version_requirement.meets_requirement(plugin_version):
env[ssm_env_name] = start_session_response
start_session_response = ssm_env_name
# ignore_user_entered_signals ignores these signals
# because if signals which kills the process are not
# captured would kill the foreground process but not the
Expand All @@ -80,13 +145,18 @@ def invoke(self, service_name, operation_name, parameters,
# and handling in there
with ignore_user_entered_signals():
# call executable with necessary input
check_call(["session-manager-plugin",
json.dumps(response),
region_name,
"StartSession",
profile_name,
json.dumps(parameters),
endpoint_url])
check_call(
[
"session-manager-plugin",
start_session_response,
region_name,
"StartSession",
profile_name,
json.dumps(parameters),
endpoint_url,
],
env=env,
)
return 0
except OSError as ex:
if ex.errno == errno.ENOENT:
Expand Down
121 changes: 99 additions & 22 deletions tests/functional/ssm/test_start_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,115 @@

from awscli.testutils import BaseAWSCommandParamsTest
from awscli.testutils import BaseAWSHelpOutputTest
from awscli.testutils import mock
from awscli.testutils import mock


class TestSessionManager(BaseAWSCommandParamsTest):
@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_success(self, mock_check_output, mock_check_call):
cmdline = "ssm start-session --target instance-id"
mock_check_call.return_value = 0
mock_check_output.return_value = "1.2.0.0\n"
expected_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
self.parsed_responses = [expected_response]
start_session_params = {"Target": "instance-id"}

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_success(self, mock_check_call):
cmdline = 'ssm start-session --target instance-id'
self.run_cmd(cmdline, expected_rc=0)

mock_check_call.assert_called_once_with(
[
"session-manager-plugin",
json.dumps(expected_response),
mock.ANY,
"StartSession",
mock.ANY,
json.dumps(start_session_params),
mock.ANY,
],
env=self.environ,
)

@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_with_new_version_plugin_success(
self, mock_check_output, mock_check_call
):
cmdline = "ssm start-session --target instance-id"
mock_check_call.return_value = 0
self.parsed_responses = [{
mock_check_output.return_value = "1.2.500.0\n"
expected_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"
}]
"StreamUrl": "stream-url",
}
self.parsed_responses = [expected_response]

ssm_env_name = "AWS_SSM_START_SESSION_RESPONSE"
start_session_params = {"Target": "instance-id"}
expected_env = self.environ.copy()
expected_env.update({ssm_env_name: json.dumps(expected_response)})

self.run_cmd(cmdline, expected_rc=0)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
self.assertEqual(self.operations_called[0][1],
{'Target': 'instance-id'})
actual_response = json.loads(mock_check_call.call_args[0][0][1])

mock_check_call.assert_called_once_with(
[
"session-manager-plugin",
ssm_env_name,
mock.ANY,
"StartSession",
mock.ANY,
json.dumps(start_session_params),
mock.ANY,
],
env=expected_env,
)

@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_fails(self, mock_check_output, mock_check_call):
cmdline = "ssm start-session --target instance-id"
mock_check_output.return_value = "1.2.500.0\n"
mock_check_call.side_effect = OSError(errno.ENOENT, "some error")
self.parsed_responses = [
{
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
]
self.run_cmd(cmdline, expected_rc=255)
self.assertEqual(
self.operations_called[0][0].name, "StartSession"
)
self.assertEqual(
self.operations_called[0][1], {"Target": "instance-id"}
)
self.assertEqual(
self.operations_called[1][0].name, "TerminateSession"
)
self.assertEqual(
{"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"},
actual_response)
self.operations_called[1][1], {"SessionId": "session-id"}
)

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_fails(self, mock_check_call):
@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_when_get_plugin_version_fails(
self, mock_check_output, mock_check_call
):
cmdline = 'ssm start-session --target instance-id'
mock_check_call.side_effect = OSError(errno.ENOENT, 'some error')
self.parsed_responses = [{
"SessionId": "session-id"
}]
mock_check_output.side_effect = OSError(errno.ENOENT, 'some error')
self.parsed_responses = [
{
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
]
self.run_cmd(cmdline, expected_rc=255)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
Expand Down
Loading

0 comments on commit da9039a

Please sign in to comment.