Skip to content

Commit

Permalink
Pass StartSession response as env variable
Browse files Browse the repository at this point in the history
Pass the StartSession API response as environment variable to the
session-manager-plugin
  • Loading branch information
Yangtao-Hua authored and kdaily committed Nov 17, 2023
1 parent a453709 commit 48e1c63
Show file tree
Hide file tree
Showing 4 changed files with 453 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``ssm`` Session Manager",
"description": "Pass StartSession API response as environment variable to session-manager-plugin"
}
75 changes: 71 additions & 4 deletions awscli/customizations/sessionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import logging
import json
import errno
import os
import re

from subprocess import check_call
from subprocess import check_call, check_output
from awscli.compat import ignore_user_entered_signals
from awscli.clidriver import ServiceOperation, CLIOperationCaller

Expand Down Expand Up @@ -44,8 +46,43 @@ def add_custom_start_session(session, command_table, **kwargs):
)


class StartSessionCommand(ServiceOperation):
class VersionRequirement:
WHITESPACE_REGEX = re.compile(r"\s+")
SSM_SESSION_PLUGIN_VERSION_REGEX = re.compile(r"^\d+(\.\d+){0,3}$")

def __init__(self, min_version):
self.min_version = min_version

def meets_requirement(self, version):
ssm_plugin_version = self._sanitize_plugin_version(version)
if self._is_valid_version(ssm_plugin_version):
norm_version, norm_min_version = self._normalize(
ssm_plugin_version, self.min_version
)
return norm_version > norm_min_version
else:
return False

def _sanitize_plugin_version(self, plugin_version):
return re.sub(self.WHITESPACE_REGEX, "", plugin_version)

def _is_valid_version(self, plugin_version):
return bool(
self.SSM_SESSION_PLUGIN_VERSION_REGEX.match(plugin_version)
)

def _normalize(self, v1, v2):
v1_parts = [int(v) for v in v1.split(".")]
v2_parts = [int(v) for v in v2.split(".")]
while len(v1_parts) != len(v2_parts):
if len(v1_parts) - len(v2_parts) > 0:
v2_parts.append(0)
else:
v1_parts.append(0)
return v1_parts, v2_parts


class StartSessionCommand(ServiceOperation):
def create_help_command(self):
help_command = super(
StartSessionCommand, self).create_help_command()
Expand All @@ -55,6 +92,9 @@ def create_help_command(self):


class StartSessionCaller(CLIOperationCaller):
LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR = "1.2.497.0"
DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE"

def invoke(self, service_name, operation_name, parameters,
parsed_globals):
client = self._session.create_client(
Expand All @@ -70,8 +110,34 @@ def invoke(self, service_name, operation_name, parameters,
profile_name = self._session.profile \
if self._session.profile is not None else ''
endpoint_url = client.meta.endpoint_url
ssm_env_name = self.DEFAULT_SSM_ENV_NAME

try:
session_parameters = {
"SessionId": response["SessionId"],
"TokenValue": response["TokenValue"],
"StreamUrl": response["StreamUrl"],
}
start_session_response = json.dumps(session_parameters)

plugin_version = check_output(
["session-manager-plugin", "--version"], text=True
)
env = os.environ.copy()

# Check if this plugin supports passing the start session response
# as an environment variable name. If it does, it will set the
# value to the response from the start_session operation to the env
# variable defined in DEFAULT_SSM_ENV_NAME. If the session plugin
# version is invalid or older than the version defined in
# LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR, it will fall back to
# passing the start_session response directly.
version_requirement = VersionRequirement(
min_version=self.LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR
)
if version_requirement.meets_requirement(plugin_version):
env[ssm_env_name] = start_session_response
start_session_response = ssm_env_name
# ignore_user_entered_signals ignores these signals
# because if signals which kills the process are not
# captured would kill the foreground process but not the
Expand All @@ -81,12 +147,13 @@ def invoke(self, service_name, operation_name, parameters,
with ignore_user_entered_signals():
# call executable with necessary input
check_call(["session-manager-plugin",
json.dumps(response),
start_session_response,
region_name,
"StartSession",
profile_name,
json.dumps(parameters),
endpoint_url])
endpoint_url], env=env)

return 0
except OSError as ex:
if ex.errno == errno.ENOENT:
Expand Down
115 changes: 98 additions & 17 deletions tests/functional/ssm/test_start_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,119 @@

from awscli.testutils import BaseAWSCommandParamsTest
from awscli.testutils import BaseAWSHelpOutputTest
from awscli.testutils import mock
from awscli.testutils import mock

class TestSessionManager(BaseAWSCommandParamsTest):

class TestSessionManager(BaseAWSCommandParamsTest):
@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_success(self, mock_check_call):
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_success(self, mock_check_output, mock_check_call):
cmdline = 'ssm start-session --target instance-id'
mock_check_call.return_value = 0
self.parsed_responses = [{
mock_check_output.return_value = "1.2.0.0\n"
expected_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"
}]
"StreamUrl": "stream-url",
}
self.parsed_responses = [expected_response]
start_session_params = {"Target": "instance-id"}

self.run_cmd(cmdline, expected_rc=0)

mock_check_call.assert_called_once_with(
[
"session-manager-plugin",
json.dumps(expected_response),
mock.ANY,
"StartSession",
mock.ANY,
json.dumps(start_session_params),
mock.ANY,
],
env=self.environ,
)

@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_with_new_version_plugin_success(
self, mock_check_output, mock_check_call
):
cmdline = "ssm start-session --target instance-id"
mock_check_call.return_value = 0
mock_check_output.return_value = "1.2.500.0\n"
expected_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
self.parsed_responses = [expected_response]

ssm_env_name = "AWS_SSM_START_SESSION_RESPONSE"
start_session_params = {"Target": "instance-id"}
expected_env = self.environ.copy()
expected_env.update({ssm_env_name: json.dumps(expected_response)})

self.run_cmd(cmdline, expected_rc=0)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
self.assertEqual(self.operations_called[0][1],
{'Target': 'instance-id'})
actual_response = json.loads(mock_check_call.call_args[0][0][1])
self.assertEqual(
{"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"},
actual_response)

mock_check_call.assert_called_once_with(
[
"session-manager-plugin",
ssm_env_name,
mock.ANY,
"StartSession",
mock.ANY,
json.dumps(start_session_params),
mock.ANY,
],
env=expected_env,
)

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_fails(self, mock_check_call):
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_fails(self, mock_check_output, mock_check_call):
cmdline = "ssm start-session --target instance-id"
mock_check_output.return_value = "1.2.500.0\n"
mock_check_call.side_effect = OSError(errno.ENOENT, "some error")
self.parsed_responses = [
{
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
]
self.run_cmd(cmdline, expected_rc=255)
self.assertEqual(
self.operations_called[0][0].name, "StartSession"
)
self.assertEqual(
self.operations_called[0][1], {"Target": "instance-id"}
)
self.assertEqual(
self.operations_called[1][0].name, "TerminateSession"
)
self.assertEqual(
self.operations_called[1][1], {"SessionId": "session-id"}
)

@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_when_get_plugin_version_fails(
self, mock_check_output, mock_check_call
):
cmdline = 'ssm start-session --target instance-id'
mock_check_call.side_effect = OSError(errno.ENOENT, 'some error')
self.parsed_responses = [{
"SessionId": "session-id"
}]
mock_check_output.side_effect = OSError(errno.ENOENT, 'some error')
self.parsed_responses = [
{
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
]
self.run_cmd(cmdline, expected_rc=255)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
Expand Down
Loading

0 comments on commit 48e1c63

Please sign in to comment.