diff --git a/samcli/lib/utils/stream_writer.py b/samcli/lib/utils/stream_writer.py index dc2c2b9b5c..78bae0d8ff 100644 --- a/samcli/lib/utils/stream_writer.py +++ b/samcli/lib/utils/stream_writer.py @@ -1,12 +1,11 @@ """ This class acts like a wrapper around output streams to provide any flexibility with output we need """ -from io import BytesIO, TextIOWrapper -from typing import Optional, TextIO, Union +from typing import TextIO class StreamWriter: - def __init__(self, stream: TextIO, stream_bytes: Optional[Union[TextIO, BytesIO]] = None, auto_flush: bool = False): + def __init__(self, stream: TextIO, auto_flush: bool = False): """ Instatiates new StreamWriter to the specified stream @@ -14,40 +13,16 @@ def __init__(self, stream: TextIO, stream_bytes: Optional[Union[TextIO, BytesIO] ---------- stream io.RawIOBase Stream to wrap - stream_bytes io.TextIO | io.BytesIO - Stream to wrap if bytes are being written auto_flush bool Whether to autoflush the stream upon writing """ self._stream = stream - self._stream_bytes = stream if isinstance(stream, TextIOWrapper) else stream_bytes self._auto_flush = auto_flush @property def stream(self) -> TextIO: return self._stream - def write_bytes(self, output: bytes): - """ - Writes specified text to the underlying stream - Parameters - ---------- - output bytes-like object - Bytes to write into buffer - """ - # all these ifs are to satisfy the linting/type checking - if not self._stream_bytes: - return - if isinstance(self._stream_bytes, TextIOWrapper): - self._stream_bytes.buffer.write(output) - if self._auto_flush: - self._stream_bytes.flush() - - elif isinstance(self._stream_bytes, BytesIO): - self._stream_bytes.write(output) - if self._auto_flush: - self._stream_bytes.flush() - def write_str(self, output: str): """ Writes specified text to the underlying stream @@ -64,5 +39,3 @@ def write_str(self, output: str): def flush(self): self._stream.flush() - if self._stream_bytes: - self._stream_bytes.flush() diff --git a/samcli/local/apigw/authorizers/lambda_authorizer.py b/samcli/local/apigw/authorizers/lambda_authorizer.py index 8b7b92c6ea..ed3483eee5 100644 --- a/samcli/local/apigw/authorizers/lambda_authorizer.py +++ b/samcli/local/apigw/authorizers/lambda_authorizer.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from json import JSONDecodeError, loads -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Type, cast from urllib.parse import parse_qsl from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator @@ -321,13 +321,13 @@ def _parse_identity_sources(self, identity_sources: List[str]) -> None: break - def is_valid_response(self, response: Union[str, bytes], method_arn: str) -> bool: + def is_valid_response(self, response: str, method_arn: str) -> bool: """ Validates whether a Lambda authorizer request is authenticated or not. Parameters ---------- - response: Union[str, bytes] + response: str JSON string containing the output from a Lambda authorizer method_arn: str The method ARN of the route that invoked the Lambda authorizer @@ -418,13 +418,13 @@ def _validate_simple_response(self, response: dict) -> bool: return cast(bool, is_authorized) - def get_context(self, response: Union[str, bytes]) -> Dict[str, Any]: + def get_context(self, response: str) -> Dict[str, Any]: """ Returns the context (if set) from the authorizer response and appends the principalId to it. Parameters ---------- - response: Union[str, bytes] + response: str Output from Lambda authorizer Returns diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 1e0f871fcd..df82b68ae4 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -6,7 +6,7 @@ from datetime import datetime from io import StringIO from time import time -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple from flask import Flask, Request, request from werkzeug.datastructures import Headers @@ -594,7 +594,7 @@ def _valid_identity_sources(self, request: Request, route: Route) -> bool: return True - def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> Union[str, bytes]: + def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> str: """ Helper method to invoke a function and setup stdout+stderr @@ -607,8 +607,8 @@ def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> Uni Returns ------- - Union[str, bytes] - A string or bytes containing the output from the Lambda function + str + A string containing the output from the Lambda function """ with StringIO() as stdout: event_str = json.dumps(event, sort_keys=True) diff --git a/samcli/local/docker/container.py b/samcli/local/docker/container.py index bef505d87c..6e54ed2531 100644 --- a/samcli/local/docker/container.py +++ b/samcli/local/docker/container.py @@ -363,7 +363,7 @@ def start(self, input_data=None): raise ex @retry(exc=requests.exceptions.RequestException, exc_raise=ContainerResponseException) - def wait_for_http_response(self, name, event, stdout) -> Tuple[Union[str, bytes], bool]: + def wait_for_http_response(self, name, event, stdout) -> Union[str, bytes]: # TODO(sriram-mv): `aws-lambda-rie` is in a mode where the function_name is always "function" # NOTE(sriram-mv): There is a connection timeout set on the http call to `aws-lambda-rie`, however there is not # a read time out for the response received from the server. @@ -374,13 +374,10 @@ def wait_for_http_response(self, name, event, stdout) -> Tuple[Union[str, bytes] timeout=(self.RAPID_CONNECTION_TIMEOUT, None), ) try: - # if response is an image then json.loads/dumps will throw a UnicodeDecodeError so return raw content - if "image" in resp.headers["Content-Type"]: - return resp.content, True - return json.dumps(json.loads(resp.content), ensure_ascii=False), False + return json.dumps(json.loads(resp.content), ensure_ascii=False) except json.JSONDecodeError: LOG.debug("Failed to deserialize response from RIE, returning the raw response as is") - return resp.content, False + return resp.content def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None): # NOTE(sriram-mv): Let logging happen in its own thread, so that a http request can be sent. @@ -403,15 +400,13 @@ def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None): # start the timer for function timeout right before executing the function, as waiting for the socket # can take some time timer = start_timer() if start_timer else None - response, is_image = self.wait_for_http_response(full_path, event, stdout) + response = self.wait_for_http_response(full_path, event, stdout) if timer: timer.cancel() self._logs_thread_event.wait(timeout=1) if isinstance(response, str): stdout.write_str(response) - elif isinstance(response, bytes) and is_image: - stdout.write_bytes(response) elif isinstance(response, bytes): stdout.write_str(response.decode("utf-8")) stdout.flush() diff --git a/samcli/local/lambda_service/local_lambda_invoke_service.py b/samcli/local/lambda_service/local_lambda_invoke_service.py index a847802b3c..1e46a10507 100644 --- a/samcli/local/lambda_service/local_lambda_invoke_service.py +++ b/samcli/local/lambda_service/local_lambda_invoke_service.py @@ -165,9 +165,8 @@ def _invoke_request_handler(self, function_name): request_data = request_data.decode("utf-8") - stdout_stream_string = io.StringIO() - stdout_stream_bytes = io.BytesIO() - stdout_stream_writer = StreamWriter(stdout_stream_string, stdout_stream_bytes, auto_flush=True) + stdout_stream = io.StringIO() + stdout_stream_writer = StreamWriter(stdout_stream, auto_flush=True) try: self.lambda_runner.invoke(function_name, request_data, stdout=stdout_stream_writer, stderr=self.stderr) @@ -179,9 +178,7 @@ def _invoke_request_handler(self, function_name): "Inline code is not supported for sam local commands. Please write your code in a separate file." ) - lambda_response, is_lambda_user_error_response = LambdaOutputParser.get_lambda_output( - stdout_stream_string, stdout_stream_bytes - ) + lambda_response, is_lambda_user_error_response = LambdaOutputParser.get_lambda_output(stdout_stream) if is_lambda_user_error_response: return self.service_response( diff --git a/samcli/local/services/base_local_service.py b/samcli/local/services/base_local_service.py index 573c24b445..5de5beb7dd 100644 --- a/samcli/local/services/base_local_service.py +++ b/samcli/local/services/base_local_service.py @@ -2,7 +2,7 @@ import io import json import logging -from typing import Optional, Tuple, Union +from typing import Tuple from flask import Response @@ -85,9 +85,7 @@ def service_response(body, headers, status_code): class LambdaOutputParser: @staticmethod - def get_lambda_output( - stdout_stream_str: io.StringIO, stdout_stream_bytes: Optional[io.BytesIO] = None - ) -> Tuple[Union[str, bytes], bool]: + def get_lambda_output(stdout_stream: io.StringIO) -> Tuple[str, bool]: """ This method will extract read the given stream and return the response from Lambda function separated out from any log statements it might have outputted. Logs end up in the stdout stream if the Lambda function @@ -95,12 +93,9 @@ def get_lambda_output( Parameters ---------- - stdout_stream_str : io.BaseIO + stdout_stream : io.BaseIO Stream to fetch data from - stdout_stream_bytes : Optional[io.BytesIO], optional - Stream to fetch raw bytes data from - Returns ------- str @@ -108,9 +103,7 @@ def get_lambda_output( bool If the response is an error/exception from the container """ - lambda_response: Union[str, bytes] = stdout_stream_str.getvalue() - if stdout_stream_bytes and not lambda_response: - lambda_response = stdout_stream_bytes.getvalue() + lambda_response = stdout_stream.getvalue() # When the Lambda Function returns an Error/Exception, the output is added to the stdout of the container. From # our perspective, the container returned some value, which is not always true. Since the output is the only diff --git a/tests/integration/local/invoke/test_integrations_cli.py b/tests/integration/local/invoke/test_integrations_cli.py index 3c33a25cea..c1a114b5d4 100644 --- a/tests/integration/local/invoke/test_integrations_cli.py +++ b/tests/integration/local/invoke/test_integrations_cli.py @@ -1197,41 +1197,6 @@ def test_invoke_inline_code_function(self): self.assertEqual(process.returncode, 1) -class TestInvokeFunctionWithImageBytesAsReturn(InvokeIntegBase): - template = Path("template-return-image.yaml") - - @pytest.mark.flaky(reruns=3) - def test_invoke_returncode_is_zero(self): - command_list = InvokeIntegBase.get_command_list( - "GetImageFunction", template_path=self.template_path, event_path=self.event_path - ) - - process = Popen(command_list, stdout=PIPE) - try: - process.communicate(timeout=TIMEOUT) - except TimeoutExpired: - process.kill() - raise - - self.assertEqual(process.returncode, 0) - - @pytest.mark.flaky(reruns=3) - def test_invoke_image_is_returned(self): - command_list = InvokeIntegBase.get_command_list( - "GetImageFunction", template_path=self.template_path, event_path=self.event_path - ) - - process = Popen(command_list, stdout=PIPE) - try: - stdout, _ = process.communicate(timeout=TIMEOUT) - except TimeoutExpired: - process.kill() - raise - - # The first byte of a png image file is \x89 so we can check that to verify that it returned an image - self.assertEqual(stdout[0:1], b"\x89") - - class TestInvokeFunctionWithError(InvokeIntegBase): template = Path("template.yml") diff --git a/tests/integration/testdata/invoke/image-for-lambda.png b/tests/integration/testdata/invoke/image-for-lambda.png deleted file mode 100644 index 56a3af614b..0000000000 Binary files a/tests/integration/testdata/invoke/image-for-lambda.png and /dev/null differ diff --git a/tests/integration/testdata/invoke/main.py b/tests/integration/testdata/invoke/main.py index 0f2753a6ff..e33635ccb4 100644 --- a/tests/integration/testdata/invoke/main.py +++ b/tests/integration/testdata/invoke/main.py @@ -60,10 +60,3 @@ def execute_git(event, context): def no_response(event, context): print("lambda called") - - -def image_handler(event, context): - with open("image-for-lambda.png", "rb") as f: - image_bytes = f.read() - - return image_bytes \ No newline at end of file diff --git a/tests/integration/testdata/invoke/template-return-image.yaml b/tests/integration/testdata/invoke/template-return-image.yaml deleted file mode 100644 index 9ec9c783c5..0000000000 --- a/tests/integration/testdata/invoke/template-return-image.yaml +++ /dev/null @@ -1,12 +0,0 @@ -AWSTemplateFormatVersion : '2010-09-09' -Transform: AWS::Serverless-2016-10-31 -Description: A hello world application. - -Resources: - GetImageFunction: - Type: AWS::Serverless::Function - Properties: - Handler: main.image_handler - Runtime: python3.12 - CodeUri: . - Timeout: 30 \ No newline at end of file diff --git a/tests/unit/lib/utils/test_stream_writer.py b/tests/unit/lib/utils/test_stream_writer.py index c586e48a42..0459a44c0e 100644 --- a/tests/unit/lib/utils/test_stream_writer.py +++ b/tests/unit/lib/utils/test_stream_writer.py @@ -1,8 +1,8 @@ """ Tests for StreamWriter """ +import io -from io import BytesIO, TextIOWrapper from unittest import TestCase from samcli.lib.utils.stream_writer import StreamWriter @@ -20,35 +20,6 @@ def test_must_write_to_stream(self): stream_mock.write.assert_called_once_with(buffer.decode("utf-8")) - def test_must_write_to_stream_bytes(self): - img_bytes = b"\xff\xab\x11" - stream_mock = Mock() - byte_stream_mock = Mock(spec=BytesIO) - - writer = StreamWriter(stream_mock, byte_stream_mock) - writer.write_bytes(img_bytes) - - byte_stream_mock.write.assert_called_once_with(img_bytes) - - def test_must_write_to_stream_bytes_for_stdout(self): - img_bytes = b"\xff\xab\x11" - stream_mock = Mock() - byte_stream_mock = Mock(spec=TextIOWrapper) - - writer = StreamWriter(stream_mock, byte_stream_mock) - writer.write_bytes(img_bytes) - - byte_stream_mock.buffer.write.assert_called_once_with(img_bytes) - - def test_must_not_write_to_stream_bytes_if_not_defined(self): - img_bytes = b"\xff\xab\x11" - stream_mock = Mock() - - writer = StreamWriter(stream_mock) - writer.write_bytes(img_bytes) - - stream_mock.write.assert_not_called() - def test_must_flush_underlying_stream(self): stream_mock = Mock() writer = StreamWriter(stream_mock) @@ -73,7 +44,7 @@ def test_when_auto_flush_on_flush_after_each_write(self): lines = ["first", "second", "third"] - writer = StreamWriter(stream_mock, auto_flush=True) + writer = StreamWriter(stream_mock, True) for line in lines: writer.write_str(line) diff --git a/tests/unit/local/docker/test_container.py b/tests/unit/local/docker/test_container.py index ddf02f91e3..cdcf22c5d9 100644 --- a/tests/unit/local/docker/test_container.py +++ b/tests/unit/local/docker/test_container.py @@ -1,7 +1,6 @@ """ Unit test for Container class """ -import base64 import json from unittest import TestCase from unittest.mock import MagicMock, Mock, call, patch, ANY @@ -585,77 +584,22 @@ def setUp(self): self.socket_mock = Mock() self.socket_mock.connect_ex.return_value = 0 - @patch("socket.socket") - @patch("samcli.local.docker.container.requests") - def test_wait_for_result_no_error_image_response(self, mock_requests, patched_socket): - self.container.is_created.return_value = True - - rie_response = b"\xff\xab" - resp_headers = { - "Date": "Tue, 02 Jan 2024 21:23:31 GMT", - "Content-Type": "image/jpeg", - "Transfer-Encoding": "chunked", - } - - real_container_mock = Mock() - self.mock_docker_client.containers.get.return_value = real_container_mock - - output_itr = Mock() - real_container_mock.attach.return_value = output_itr - self.container._write_container_output = Mock() - self.container._create_threading_event = Mock() - self.container._create_threading_event.return_value = Mock() - - stdout_mock = Mock() - stdout_mock.write_bytes = Mock() - stderr_mock = Mock() - response = Mock() - response.content = rie_response - response.headers = resp_headers - mock_requests.post.return_value = response - - patched_socket.return_value = self.socket_mock - - start_timer = Mock() - timer = Mock() - start_timer.return_value = timer - - self.container.wait_for_result( - event=self.event, full_path=self.name, stdout=stdout_mock, stderr=stderr_mock, start_timer=start_timer - ) - - # since we passed in a start_timer function, ensure it's called and - # the timer is cancelled once execution is done - start_timer.assert_called() - timer.cancel.assert_called() - - # make sure we wait for the same host+port that we make the post request to - host = self.container._container_host - port = self.container.rapid_port_host - self.socket_mock.connect_ex.assert_called_with((host, port)) - mock_requests.post.assert_called_with( - self.container.URL.format(host=host, port=port, function_name="function"), - data=b"{}", - timeout=(self.container.RAPID_CONNECTION_TIMEOUT, None), - ) - stdout_mock.write_bytes.assert_called_with(rie_response) - @parameterized.expand( [ - (True, b'{"hello":"world"}', {"Date": "Tue, 02 Jan 2024 21:23:31 GMT", "Content-Type": "text"}), + ( + True, + b'{"hello":"world"}', + ), ( False, b"non-json-deserializable", - {"Date": "Tue, 02 Jan 2024 21:23:31 GMT", "Content-Type": "text/plain"}, ), - (False, b"", {"Date": "Tue, 02 Jan 2024 21:23:31 GMT", "Content-Type": "text/plain"}), + (False, b""), ] ) @patch("socket.socket") @patch("samcli.local.docker.container.requests") - def test_wait_for_result_no_error( - self, response_deserializable, rie_response, resp_headers, mock_requests, patched_socket - ): + def test_wait_for_result_no_error(self, response_deserializable, rie_response, mock_requests, patched_socket): self.container.is_created.return_value = True real_container_mock = Mock() @@ -672,7 +616,6 @@ def test_wait_for_result_no_error( stderr_mock = Mock() response = Mock() response.content = rie_response - response.headers = resp_headers mock_requests.post.return_value = response patched_socket.return_value = self.socket_mock diff --git a/tests/unit/local/lambda_service/test_local_lambda_invoke_service.py b/tests/unit/local/lambda_service/test_local_lambda_invoke_service.py index e338762db9..684e72f00c 100644 --- a/tests/unit/local/lambda_service/test_local_lambda_invoke_service.py +++ b/tests/unit/local/lambda_service/test_local_lambda_invoke_service.py @@ -135,7 +135,7 @@ def test_request_handler_returns_process_stdout_when_making_response( result = service._invoke_request_handler(function_name="HelloWorld") self.assertEqual(result, "request response") - lambda_output_parser_mock.get_lambda_output.assert_called_with(ANY, ANY) + lambda_output_parser_mock.get_lambda_output.assert_called_with(ANY) @patch("samcli.local.lambda_service.local_lambda_invoke_service.LambdaErrorResponses") def test_construct_error_handling(self, lambda_error_response_mock):