Skip to content

Commit

Permalink
Pass StartSession response as env variable
Browse files Browse the repository at this point in the history
Pass the StartSession API response as
environment variable to the session-manager-plugin
  • Loading branch information
Yangtao-Hua authored and kdaily committed Nov 17, 2023
1 parent 5be9161 commit 1c0e28f
Show file tree
Hide file tree
Showing 4 changed files with 440 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``ssm`` Session Manager",
"description": "Pass StartSession API response as environment variable to session-manager-plugin"
}
75 changes: 71 additions & 4 deletions awscli/customizations/sessionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import logging
import json
import errno
import os
import re

from subprocess import check_call
from subprocess import check_call, check_output
from awscli.compat import ignore_user_entered_signals
from awscli.clidriver import ServiceOperation, CLIOperationCaller

Expand Down Expand Up @@ -44,8 +46,43 @@ def add_custom_start_session(session, command_table, **kwargs):
)


class StartSessionCommand(ServiceOperation):
class VersionRequirement:
WHITESPACE_REGEX = re.compile(r"\s+")
SSM_SESSION_PLUGIN_VERSION_REGEX = re.compile(r"^\d+(\.\d+){0,3}$")

def __init__(self, min_version):
self.min_version = min_version

def meets_requirement(self, version):
ssm_plugin_version = self._sanitize_plugin_version(version)
if self._is_valid_version(ssm_plugin_version):
norm_version, norm_min_version = self._normalize(
ssm_plugin_version, self.min_version
)
return norm_version > norm_min_version
else:
return False

def _sanitize_plugin_version(self, plugin_version):
return re.sub(self.WHITESPACE_REGEX, "", plugin_version)

def _is_valid_version(self, plugin_version):
return bool(
self.SSM_SESSION_PLUGIN_VERSION_REGEX.match(plugin_version)
)

def _normalize(self, v1, v2):
v1_parts = [int(v) for v in v1.split(".")]
v2_parts = [int(v) for v in v2.split(".")]
while len(v1_parts) != len(v2_parts):
if len(v1_parts) - len(v2_parts) > 0:
v2_parts.append(0)
else:
v1_parts.append(0)
return v1_parts, v2_parts


class StartSessionCommand(ServiceOperation):
def create_help_command(self):
help_command = super(
StartSessionCommand, self).create_help_command()
Expand All @@ -55,6 +92,9 @@ def create_help_command(self):


class StartSessionCaller(CLIOperationCaller):
LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR = "1.2.497.0"
DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE"

def invoke(self, service_name, operation_name, parameters,
parsed_globals):
client = self._session.create_client(
Expand All @@ -70,8 +110,34 @@ def invoke(self, service_name, operation_name, parameters,
profile_name = self._session.profile \
if self._session.profile is not None else ''
endpoint_url = client.meta.endpoint_url
ssm_env_name = self.DEFAULT_SSM_ENV_NAME

try:
session_parameters = {
"SessionId": response["SessionId"],
"TokenValue": response["TokenValue"],
"StreamUrl": response["StreamUrl"],
}
start_session_response = json.dumps(session_parameters)

plugin_version = check_output(
["session-manager-plugin", "--version"], text=True
)
env = os.environ.copy()

# Check if this plugin supports passing the start session response
# as an environment variable name. If it does, it will set the
# value to the response from the start_session operation to the env
# variable defined in DEFAULT_SSM_ENV_NAME. If the session plugin
# version is invalid or older than the version defined in
# LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR, it will fall back to
# passing the start_session response directly.
version_requirement = VersionRequirement(
min_version=self.LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR
)
if version_requirement.meets_requirement(plugin_version):
env[ssm_env_name] = start_session_response
start_session_response = ssm_env_name
# ignore_user_entered_signals ignores these signals
# because if signals which kills the process are not
# captured would kill the foreground process but not the
Expand All @@ -81,12 +147,13 @@ def invoke(self, service_name, operation_name, parameters,
with ignore_user_entered_signals():
# call executable with necessary input
check_call(["session-manager-plugin",
json.dumps(response),
start_session_response,
region_name,
"StartSession",
profile_name,
json.dumps(parameters),
endpoint_url])
endpoint_url], env=env)

return 0
except OSError as ex:
if ex.errno == errno.ENOENT:
Expand Down
114 changes: 97 additions & 17 deletions tests/functional/ssm/test_start_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,44 +10,124 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import mock
import errno
import json

from awscli.testutils import BaseAWSCommandParamsTest
from awscli.testutils import BaseAWSHelpOutputTest
from awscli.testutils import mock


class TestSessionManager(BaseAWSCommandParamsTest):

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_success(self, mock_check_call):
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_success(self, mock_check_output, mock_check_call):
cmdline = 'ssm start-session --target instance-id'
mock_check_call.return_value = 0
self.parsed_responses = [{
mock_check_output.return_value = "1.2.0.0\n"
expected_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"
}]
"StreamUrl": "stream-url",
}
self.parsed_responses = [expected_response]
start_session_params = {"Target": "instance-id"}

self.run_cmd(cmdline, expected_rc=0)

mock_check_call.assert_called_once_with(
[
"session-manager-plugin",
json.dumps(expected_response),
mock.ANY,
"StartSession",
mock.ANY,
json.dumps(start_session_params),
mock.ANY,
],
env=self.environ,
)

@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_with_new_version_plugin_success(
self, mock_check_output, mock_check_call
):
cmdline = "ssm start-session --target instance-id"
mock_check_call.return_value = 0
mock_check_output.return_value = "1.2.500.0\n"
expected_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
self.parsed_responses = [expected_response]

ssm_env_name = "AWS_SSM_START_SESSION_RESPONSE"
start_session_params = {"Target": "instance-id"}
expected_env = self.environ.copy()
expected_env.update({ssm_env_name: json.dumps(expected_response)})

self.run_cmd(cmdline, expected_rc=0)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
self.assertEqual(self.operations_called[0][1],
{'Target': 'instance-id'})
actual_response = json.loads(mock_check_call.call_args[0][0][1])
self.assertEqual(
{"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"},
actual_response)

mock_check_call.assert_called_once_with(
[
"session-manager-plugin",
ssm_env_name,
mock.ANY,
"StartSession",
mock.ANY,
json.dumps(start_session_params),
mock.ANY,
],
env=expected_env,
)

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_fails(self, mock_check_call):
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_fails(self, mock_check_output, mock_check_call):
cmdline = "ssm start-session --target instance-id"
mock_check_output.return_value = "1.2.500.0\n"
mock_check_call.side_effect = OSError(errno.ENOENT, "some error")
self.parsed_responses = [
{
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
]
self.run_cmd(cmdline, expected_rc=255)
self.assertEqual(
self.operations_called[0][0].name, "StartSession"
)
self.assertEqual(
self.operations_called[0][1], {"Target": "instance-id"}
)
self.assertEqual(
self.operations_called[1][0].name, "TerminateSession"
)
self.assertEqual(
self.operations_called[1][1], {"SessionId": "session-id"}
)

@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_when_get_plugin_version_fails(
self, mock_check_output, mock_check_call
):
cmdline = 'ssm start-session --target instance-id'
mock_check_call.side_effect = OSError(errno.ENOENT, 'some error')
self.parsed_responses = [{
"SessionId": "session-id"
}]
mock_check_output.side_effect = OSError(errno.ENOENT, 'some error')
self.parsed_responses = [
{
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
]
self.run_cmd(cmdline, expected_rc=255)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
Expand Down
Loading

0 comments on commit 1c0e28f

Please sign in to comment.