From 48e1c630e71d660890bee00a54c22d36db7c599c Mon Sep 17 00:00:00 2001 From: Yangtao-Hua Date: Fri, 13 Oct 2023 09:38:37 -0700 Subject: [PATCH] Pass StartSession response as env variable Pass the StartSession API response as environment variable to the session-manager-plugin --- .../enhancement-ssmSessionManager-47156.json | 5 + awscli/customizations/sessionmanager.py | 75 ++++- tests/functional/ssm/test_start_session.py | 115 ++++++- .../customizations/test_sessionmanager.py | 292 +++++++++++++++++- 4 files changed, 453 insertions(+), 34 deletions(-) create mode 100644 .changes/next-release/enhancement-ssmSessionManager-47156.json diff --git a/.changes/next-release/enhancement-ssmSessionManager-47156.json b/.changes/next-release/enhancement-ssmSessionManager-47156.json new file mode 100644 index 000000000000..d4870e19b6b3 --- /dev/null +++ b/.changes/next-release/enhancement-ssmSessionManager-47156.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``ssm`` Session Manager", + "description": "Pass StartSession API response as environment variable to session-manager-plugin" +} diff --git a/awscli/customizations/sessionmanager.py b/awscli/customizations/sessionmanager.py index c33aaca590c7..92a8f8ffbe8a 100644 --- a/awscli/customizations/sessionmanager.py +++ b/awscli/customizations/sessionmanager.py @@ -13,8 +13,10 @@ import logging import json import errno +import os +import re -from subprocess import check_call +from subprocess import check_call, check_output from awscli.compat import ignore_user_entered_signals from awscli.clidriver import ServiceOperation, CLIOperationCaller @@ -44,8 +46,43 @@ def add_custom_start_session(session, command_table, **kwargs): ) -class StartSessionCommand(ServiceOperation): +class VersionRequirement: + WHITESPACE_REGEX = re.compile(r"\s+") + SSM_SESSION_PLUGIN_VERSION_REGEX = re.compile(r"^\d+(\.\d+){0,3}$") + + def __init__(self, min_version): + self.min_version = min_version + + def meets_requirement(self, version): + ssm_plugin_version = self._sanitize_plugin_version(version) + if self._is_valid_version(ssm_plugin_version): + norm_version, norm_min_version = self._normalize( + ssm_plugin_version, self.min_version + ) + return norm_version > norm_min_version + else: + return False + + def _sanitize_plugin_version(self, plugin_version): + return re.sub(self.WHITESPACE_REGEX, "", plugin_version) + + def _is_valid_version(self, plugin_version): + return bool( + self.SSM_SESSION_PLUGIN_VERSION_REGEX.match(plugin_version) + ) + + def _normalize(self, v1, v2): + v1_parts = [int(v) for v in v1.split(".")] + v2_parts = [int(v) for v in v2.split(".")] + while len(v1_parts) != len(v2_parts): + if len(v1_parts) - len(v2_parts) > 0: + v2_parts.append(0) + else: + v1_parts.append(0) + return v1_parts, v2_parts + +class StartSessionCommand(ServiceOperation): def create_help_command(self): help_command = super( StartSessionCommand, self).create_help_command() @@ -55,6 +92,9 @@ def create_help_command(self): class StartSessionCaller(CLIOperationCaller): + LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR = "1.2.497.0" + DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE" + def invoke(self, service_name, operation_name, parameters, parsed_globals): client = self._session.create_client( @@ -70,8 +110,34 @@ def invoke(self, service_name, operation_name, parameters, profile_name = self._session.profile \ if self._session.profile is not None else '' endpoint_url = client.meta.endpoint_url + ssm_env_name = self.DEFAULT_SSM_ENV_NAME try: + session_parameters = { + "SessionId": response["SessionId"], + "TokenValue": response["TokenValue"], + "StreamUrl": response["StreamUrl"], + } + start_session_response = json.dumps(session_parameters) + + plugin_version = check_output( + ["session-manager-plugin", "--version"], text=True + ) + env = os.environ.copy() + + # Check if this plugin supports passing the start session response + # as an environment variable name. If it does, it will set the + # value to the response from the start_session operation to the env + # variable defined in DEFAULT_SSM_ENV_NAME. If the session plugin + # version is invalid or older than the version defined in + # LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR, it will fall back to + # passing the start_session response directly. + version_requirement = VersionRequirement( + min_version=self.LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR + ) + if version_requirement.meets_requirement(plugin_version): + env[ssm_env_name] = start_session_response + start_session_response = ssm_env_name # ignore_user_entered_signals ignores these signals # because if signals which kills the process are not # captured would kill the foreground process but not the @@ -81,12 +147,13 @@ def invoke(self, service_name, operation_name, parameters, with ignore_user_entered_signals(): # call executable with necessary input check_call(["session-manager-plugin", - json.dumps(response), + start_session_response, region_name, "StartSession", profile_name, json.dumps(parameters), - endpoint_url]) + endpoint_url], env=env) + return 0 except OSError as ex: if ex.errno == errno.ENOENT: diff --git a/tests/functional/ssm/test_start_session.py b/tests/functional/ssm/test_start_session.py index 8c391024eb24..2ed9c9176abe 100644 --- a/tests/functional/ssm/test_start_session.py +++ b/tests/functional/ssm/test_start_session.py @@ -15,38 +15,119 @@ from awscli.testutils import BaseAWSCommandParamsTest from awscli.testutils import BaseAWSHelpOutputTest -from awscli.testutils import mock +from awscli.testutils import mock -class TestSessionManager(BaseAWSCommandParamsTest): +class TestSessionManager(BaseAWSCommandParamsTest): @mock.patch('awscli.customizations.sessionmanager.check_call') - def test_start_session_success(self, mock_check_call): + @mock.patch("awscli.customizations.sessionmanager.check_output") + def test_start_session_success(self, mock_check_output, mock_check_call): cmdline = 'ssm start-session --target instance-id' mock_check_call.return_value = 0 - self.parsed_responses = [{ + mock_check_output.return_value = "1.2.0.0\n" + expected_response = { "SessionId": "session-id", "TokenValue": "token-value", - "StreamUrl": "stream-url" - }] + "StreamUrl": "stream-url", + } + self.parsed_responses = [expected_response] + start_session_params = {"Target": "instance-id"} + + self.run_cmd(cmdline, expected_rc=0) + + mock_check_call.assert_called_once_with( + [ + "session-manager-plugin", + json.dumps(expected_response), + mock.ANY, + "StartSession", + mock.ANY, + json.dumps(start_session_params), + mock.ANY, + ], + env=self.environ, + ) + + @mock.patch("awscli.customizations.sessionmanager.check_call") + @mock.patch("awscli.customizations.sessionmanager.check_output") + def test_start_session_with_new_version_plugin_success( + self, mock_check_output, mock_check_call + ): + cmdline = "ssm start-session --target instance-id" + mock_check_call.return_value = 0 + mock_check_output.return_value = "1.2.500.0\n" + expected_response = { + "SessionId": "session-id", + "TokenValue": "token-value", + "StreamUrl": "stream-url", + } + self.parsed_responses = [expected_response] + + ssm_env_name = "AWS_SSM_START_SESSION_RESPONSE" + start_session_params = {"Target": "instance-id"} + expected_env = self.environ.copy() + expected_env.update({ssm_env_name: json.dumps(expected_response)}) + self.run_cmd(cmdline, expected_rc=0) self.assertEqual(self.operations_called[0][0].name, 'StartSession') self.assertEqual(self.operations_called[0][1], {'Target': 'instance-id'}) - actual_response = json.loads(mock_check_call.call_args[0][0][1]) - self.assertEqual( - {"SessionId": "session-id", - "TokenValue": "token-value", - "StreamUrl": "stream-url"}, - actual_response) + + mock_check_call.assert_called_once_with( + [ + "session-manager-plugin", + ssm_env_name, + mock.ANY, + "StartSession", + mock.ANY, + json.dumps(start_session_params), + mock.ANY, + ], + env=expected_env, + ) @mock.patch('awscli.customizations.sessionmanager.check_call') - def test_start_session_fails(self, mock_check_call): + @mock.patch("awscli.customizations.sessionmanager.check_output") + def test_start_session_fails(self, mock_check_output, mock_check_call): + cmdline = "ssm start-session --target instance-id" + mock_check_output.return_value = "1.2.500.0\n" + mock_check_call.side_effect = OSError(errno.ENOENT, "some error") + self.parsed_responses = [ + { + "SessionId": "session-id", + "TokenValue": "token-value", + "StreamUrl": "stream-url", + } + ] + self.run_cmd(cmdline, expected_rc=255) + self.assertEqual( + self.operations_called[0][0].name, "StartSession" + ) + self.assertEqual( + self.operations_called[0][1], {"Target": "instance-id"} + ) + self.assertEqual( + self.operations_called[1][0].name, "TerminateSession" + ) + self.assertEqual( + self.operations_called[1][1], {"SessionId": "session-id"} + ) + + @mock.patch("awscli.customizations.sessionmanager.check_call") + @mock.patch("awscli.customizations.sessionmanager.check_output") + def test_start_session_when_get_plugin_version_fails( + self, mock_check_output, mock_check_call + ): cmdline = 'ssm start-session --target instance-id' - mock_check_call.side_effect = OSError(errno.ENOENT, 'some error') - self.parsed_responses = [{ - "SessionId": "session-id" - }] + mock_check_output.side_effect = OSError(errno.ENOENT, 'some error') + self.parsed_responses = [ + { + "SessionId": "session-id", + "TokenValue": "token-value", + "StreamUrl": "stream-url", + } + ] self.run_cmd(cmdline, expected_rc=255) self.assertEqual(self.operations_called[0][0].name, 'StartSession') diff --git a/tests/unit/customizations/test_sessionmanager.py b/tests/unit/customizations/test_sessionmanager.py index 85623dee8417..15cba75e8d9e 100644 --- a/tests/unit/customizations/test_sessionmanager.py +++ b/tests/unit/customizations/test_sessionmanager.py @@ -13,6 +13,9 @@ import errno import json import botocore.session +import subprocess + +import pytest from awscli.customizations import sessionmanager from awscli.testutils import mock, unittest @@ -38,8 +41,12 @@ def test_start_session_when_non_custom_start_session_fails(self): with self.assertRaisesRegex(Exception, 'some exception'): self.caller.invoke('ssm', 'StartSession', params, mock.Mock()) - @mock.patch('awscli.customizations.sessionmanager.check_call') - def test_start_session_success_scenario(self, mock_check_call): + @mock.patch("awscli.customizations.sessionmanager.check_output") + @mock.patch("awscli.customizations.sessionmanager.check_call") + def test_start_session_success_scenario( + self, mock_check_call, mock_check_output + ): + mock_check_output.return_value = "1.2.0.0\n" mock_check_call.return_value = 0 start_session_params = { @@ -58,6 +65,12 @@ def test_start_session_success_scenario(self, mock_check_call): start_session_params, mock.Mock()) self.assertEqual(rc, 0) self.client.start_session.assert_called_with(**start_session_params) + + mock_check_output_list = mock_check_output.call_args[0] + self.assertEqual( + mock_check_output_list[0], ["session-manager-plugin", "--version"] + ) + mock_check_call_list = mock_check_call.call_args[0][0] mock_check_call_list[1] = json.loads(mock_check_call_list[1]) self.assertEqual( @@ -71,8 +84,12 @@ def test_start_session_success_scenario(self, mock_check_call): self.endpoint_url] ) - @mock.patch('awscli.customizations.sessionmanager.check_call') - def test_start_session_when_check_call_fails(self, mock_check_call): + @mock.patch("awscli.customizations.sessionmanager.check_output") + @mock.patch("awscli.customizations.sessionmanager.check_call") + def test_start_session_when_check_call_fails( + self, mock_check_call, mock_check_output + ): + mock_check_output.return_value = "1.2.0.0\n" mock_check_call.side_effect = OSError(errno.ENOENT, 'some error') start_session_params = { @@ -104,17 +121,23 @@ def test_start_session_when_check_call_fails(self, mock_check_call): mock_check_call_list[1] = json.loads(mock_check_call_list[1]) self.assertEqual( mock_check_call_list, - ['session-manager-plugin', - start_session_response, - self.region, - 'StartSession', - self.profile, - json.dumps(start_session_params), - self.endpoint_url] + [ + "session-manager-plugin", + start_session_response, + self.region, + "StartSession", + self.profile, + json.dumps(start_session_params), + self.endpoint_url, + ], ) - @mock.patch('awscli.customizations.sessionmanager.check_call') - def test_start_session_when_no_profile_is_passed(self, mock_check_call): + @mock.patch("awscli.customizations.sessionmanager.check_call") + @mock.patch("awscli.customizations.sessionmanager.check_output") + def test_start_session_when_no_profile_is_passed( + self, mock_check_output, mock_check_call + ): + mock_check_output.return_value = "1.2.500.0\n" self.session.profile = None mock_check_call.return_value = 0 @@ -136,3 +159,246 @@ def test_start_session_when_no_profile_is_passed(self, mock_check_call): self.client.start_session.assert_called_with(**start_session_params) mock_check_call_list = mock_check_call.call_args[0][0] self.assertEqual(mock_check_call_list[4], '') + + @mock.patch("awscli.customizations.sessionmanager.check_call") + @mock.patch("awscli.customizations.sessionmanager.check_output") + def test_start_session_with_env_variable_success_scenario( + self, mock_check_output, mock_check_call + ): + mock_check_output.return_value = "1.2.500.0\n" + mock_check_call.return_value = 0 + + start_session_params = {"Target": "i-123456789"} + start_session_response = { + "SessionId": "session-id", + "TokenValue": "token-value", + "StreamUrl": "stream-url", + } + ssm_env_name = "AWS_SSM_START_SESSION_RESPONSE" + + self.client.start_session.return_value = start_session_response + rc = self.caller.invoke( + "ssm", "StartSession", start_session_params, mock.Mock() + ) + self.assertEqual(rc, 0) + self.client.start_session.assert_called_with(**start_session_params) + + mock_check_output_list = mock_check_output.call_args[0] + self.assertEqual( + mock_check_output_list[0], ["session-manager-plugin", "--version"] + ) + + mock_check_call_list = mock_check_call.call_args[0][0] + self.assertEqual( + mock_check_call_list, + [ + "session-manager-plugin", + ssm_env_name, + self.region, + "StartSession", + self.profile, + json.dumps(start_session_params), + self.endpoint_url, + ], + ) + env_variable = mock_check_call.call_args[1] + self.assertEqual( + env_variable["env"][ssm_env_name], + json.dumps(start_session_response) + ) + + @mock.patch("awscli.customizations.sessionmanager.check_call") + @mock.patch("awscli.customizations.sessionmanager.check_output") + def test_start_session_when_check_output_fails( + self, mock_check_output, mock_check_call + ): + mock_check_output.side_effect = subprocess.CalledProcessError( + returncode=1, cmd="session-manager-plugin", output="some error" + ) + + start_session_params = {"Target": "i-123456789"} + start_session_response = { + "SessionId": "session-id", + "TokenValue": "token-value", + "StreamUrl": "stream-url", + } + + self.client.start_session.return_value = start_session_response + with self.assertRaises(subprocess.CalledProcessError): + self.caller.invoke( + "ssm", "StartSession", start_session_params, mock.Mock() + ) + + self.client.start_session.assert_called_with(**start_session_params) + self.client.terminate_session.assert_not_called() + mock_check_output.assert_called_with( + ["session-manager-plugin", "--version"], text=True + ) + mock_check_call.assert_not_called() + + @mock.patch("awscli.customizations.sessionmanager.check_call") + @mock.patch("awscli.customizations.sessionmanager.check_output") + def test_start_session_when_response_not_json( + self, mock_check_output, mock_check_call + ): + mock_check_output.return_value = "1.2.500.0\n" + start_session_params = {"Target": "i-123456789"} + start_session_response = { + "SessionId": "session-id", + "TokenValue": "token-value", + "StreamUrl": "stream-url", + "para2": {"Not a json format"}, + } + expected_env_value = { + "SessionId": "session-id", + "TokenValue": "token-value", + "StreamUrl": "stream-url", + } + + ssm_env_name = "AWS_SSM_START_SESSION_RESPONSE" + + self.client.start_session.return_value = start_session_response + rc = self.caller.invoke( + "ssm", "StartSession", start_session_params, mock.Mock() + ) + self.assertEqual(rc, 0) + self.client.start_session.assert_called_with(**start_session_params) + + mock_check_output_list = mock_check_output.call_args[0] + self.assertEqual( + mock_check_output_list[0], ["session-manager-plugin", "--version"] + ) + + mock_check_call_list = mock_check_call.call_args[0][0] + self.assertEqual( + mock_check_call_list, + [ + "session-manager-plugin", + ssm_env_name, + self.region, + "StartSession", + self.profile, + json.dumps(start_session_params), + self.endpoint_url, + ], + ) + env_variable = mock_check_call.call_args[1] + self.assertEqual( + env_variable["env"][ssm_env_name], json.dumps(expected_env_value) + ) + + @mock.patch("awscli.customizations.sessionmanager.check_call") + @mock.patch("awscli.customizations.sessionmanager.check_output") + def test_start_session_when_invalid_plugin_version( + self, mock_check_output, mock_check_call + ): + mock_check_output.return_value = "InvalidVersion" + + start_session_params = {"Target": "i-123456789"} + start_session_response = { + "SessionId": "session-id", + "TokenValue": "token-value", + "StreamUrl": "stream-url", + } + + self.client.start_session.return_value = start_session_response + self.caller.invoke( + "ssm", "StartSession", start_session_params, mock.Mock() + ) + self.client.start_session.assert_called_with(**start_session_params) + self.client.terminate_session.assert_not_called() + mock_check_output.assert_called_with( + ["session-manager-plugin", "--version"], text=True + ) + + mock_check_call_list = mock_check_call.call_args[0][0] + self.assertEqual( + mock_check_call_list, + [ + "session-manager-plugin", + json.dumps(start_session_response), + self.region, + "StartSession", + self.profile, + json.dumps(start_session_params), + self.endpoint_url, + ], + ) + + +class TestVersionRequirement: + version_requirement = \ + sessionmanager.VersionRequirement(min_version="1.2.497.0") + + @pytest.mark.parametrize( + "version, expected_result", + [ + ("2.0.0.0", True), + ("2.1", True), + ("2", True), + ("1.3.1.1", True), + ("\r\n1. 3.1.1", True), + ("1.3.1.0", True), + ("1.3", True), + ("1.2.498.1", True), + ("1.2.498", True), + ("1.2.497.1", True), + ("1.2.497.0", False), + ("1.2.497", False), + ("1.2.1.1", False), + ("1.2.1", False), + ("1.2", False), + ("1.1.1.0", False), + ("1.1.1", False), + ("1.0.497.0", False), + ("1.0. 497.0\r\n", False), + ("1", False), + ("0.3.497.0", False), + ], + ) + def test_meets_requirement(self, version, expected_result): + assert expected_result == \ + self.version_requirement.meets_requirement(version) + + @pytest.mark.parametrize( + "version, expected_result", + [ + ("\r\n1.3.1.1", "1.3.1.1"), + ("\r1.3 .1.1", "1.3.1.1"), + ("1 .3.1.1", "1.3.1.1"), + (" 1.3.1.1", "1.3.1.1"), + ("1.3.1.1 ", "1.3.1.1"), + (" 1.3.1.1 ", "1.3.1.1"), + ("\n1.3.1.1 ", "1.3.1.1"), + ("1.3.1.1\n", "1.3.1.1"), + ("1.3\r\n.1.1", "1.3.1.1"), + (" 1\r. 3", "1.3"), + (" 1. 3. ", "1.3."), + ("1.1.1\r\n", "1.1.1"), + ("1\r", "1"), + ], + ) + def test_sanitize_plugin_version(self, version, expected_result): + assert expected_result == \ + self.version_requirement._sanitize_plugin_version(version) + + @pytest.mark.parametrize( + "version, expected_result", + [ + ("999.99999.99.9", True), + ("2", True), + ("1.1.1.1", True), + ("1.1.1", True), + ("1.1", True), + ("1.1.1.1.1", False), + ("1.1.1.1.0", False), + ("1.1.1.a", False), + ("1.a.1.1", False), + ("1-1.1.1", False), + ("1.1.", False), + ("invalid_version", False), + ], + ) + def test_is_valid_version(self, version, expected_result): + assert expected_result == \ + self.version_requirement._is_valid_version(version)