diff --git a/moto/core/common_types.py b/moto/core/common_types.py new file mode 100644 index 000000000000..1686d475fdad --- /dev/null +++ b/moto/core/common_types.py @@ -0,0 +1,5 @@ +from typing import Any, Dict, Tuple, TypeVar, Union + + +TYPE_RESPONSE = Tuple[int, Dict[str, Any], Union[str, bytes]] +TYPE_IF_NONE = TypeVar("TYPE_IF_NONE") \ No newline at end of file diff --git a/moto/core/exceptions.py b/moto/core/exceptions.py index 7c4b3db5b4cd..11c030560024 100644 --- a/moto/core/exceptions.py +++ b/moto/core/exceptions.py @@ -73,6 +73,11 @@ def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument def get_body(self, *args, **kwargs): # pylint: disable=unused-argument return self.description + def to_json(self) -> "JsonRESTError": + err = JsonRESTError(error_type=self.error_type, message=self.message) + err.code = self.code + return err + class DryRunClientError(RESTError): code = 412 diff --git a/moto/sqs/models.py b/moto/sqs/models.py index d92a0400fd8d..9486bd746e4b 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -153,6 +153,10 @@ def utf8(string): def body(self): return escape(self._body).replace('"', """).replace("\r", " ") + @property + def original_body(self) -> str: + return self._body + def mark_sent(self, delay_seconds=None): self.sent_timestamp = int(unix_time_millis()) if delay_seconds: diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index c6dea26019ba..72c5358d3374 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -1,54 +1,107 @@ +import json import re +from functools import wraps +from typing import Union, Callable, Dict, Any -from moto.core.exceptions import RESTError +from moto.core.common_types import TYPE_RESPONSE +from moto.core.exceptions import JsonRESTError from moto.core.responses import BaseResponse from moto.core.utils import ( amz_crc32, amzn_request_id, underscores_to_camelcase, - camelcase_to_pascal, + camelcase_to_pascal, camelcase_to_underscores, ) +from moto.utilities.constants import JSON_TYPES from urllib.parse import urlparse from .exceptions import ( + RESTError, EmptyBatchRequest, InvalidAttributeName, ReceiptHandleIsInvalid, BatchEntryIdsNotDistinct, ) from .models import sqs_backends -from .utils import parse_message_attributes, extract_input_message_attributes +from .utils import ( + parse_message_attributes, + extract_input_message_attributes, + validate_message_attributes, +) MAXIMUM_VISIBILTY_TIMEOUT = 43200 MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB DEFAULT_RECEIVED_MESSAGES = 1 +def jsonify_error( + method: Callable[["SQSResponse"], Union[str, TYPE_RESPONSE]] +) -> Callable[["SQSResponse"], Union[str, TYPE_RESPONSE]]: + """ + The decorator to convert an RESTError to JSON, if necessary + """ + + @wraps(method) + def f(self: "SQSResponse") -> Union[str, TYPE_RESPONSE]: + try: + return method(self) + except RESTError as e: + if self.is_json(): + raise e.to_json() + raise e + + return f + + class SQSResponse(BaseResponse): region_regex = re.compile(r"://(.+?)\.queue\.amazonaws\.com") + def is_json(self) -> bool: + """ + botocore 1.29.127 changed the wire-format to SQS + This means three things: + - The Content-Type is set to JSON + - The input-parameters are in different formats + - The output is in a different format + The change has been reverted for now, but it will be re-introduced later: + https://github.com/boto/botocore/pull/2931 + """ + return self.headers.get("Content-Type") in JSON_TYPES + @property def sqs_backend(self): return sqs_backends[self.region] @property def attribute(self): - if not hasattr(self, "_attribute"): - self._attribute = self._get_map_prefix( - "Attribute", key_end=".Name", value_end=".Value" - ) - return self._attribute + try: + assert self.is_json() + return json.loads(self.body).get("Attributes", {}) + except: # noqa: E722 Do not use bare except + if not hasattr(self, "_attribute"): + self._attribute = self._get_map_prefix( + "Attribute", key_end=".Name", value_end=".Value" + ) + return self._attribute @property def tags(self): if not hasattr(self, "_tags"): - self._tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") + if self.is_json(): + self._tags = self._get_param("tags") + else: + self._tags = self._get_map_prefix( + "Tag", key_end=".Key", value_end=".Value" + ) return self._tags def _get_queue_name(self): try: - queue_url = self.querystring.get("QueueUrl")[0] + if self.is_json(): + queue_url = self._get_param("QueueUrl") + else: + queue_url = self.querystring.get("QueueUrl")[0] # type: ignore if queue_url.startswith("http://") or queue_url.startswith("https://"): return queue_url.split("/")[-1] else: @@ -66,7 +119,10 @@ def _get_validated_visibility_timeout(self, timeout=None): if timeout is not None: visibility_timeout = int(timeout) else: - visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0]) + if self.is_json(): + visibility_timeout = self._get_param("VisibilityTimeout") + else: + visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0]) # type: if visibility_timeout > MAXIMUM_VISIBILTY_TIMEOUT: raise ValueError @@ -85,34 +141,57 @@ def call_action(self): return status_code, headers, body def _error(self, code, message, status=400): + if self.is_json(): + err = JsonRESTError(error_type=code, message=message) + err.code = status + raise err template = self.response_template(ERROR_TEMPLATE) - return template.render(code=code, message=message), dict(status=status) + return status, {"status": status}, template.render(code=code, message=message) + @jsonify_error def create_queue(self): request_url = urlparse(self.uri) queue_name = self._get_param("QueueName") queue = self.sqs_backend.create_queue(queue_name, self.tags, **self.attribute) + if self.is_json(): + return json.dumps({"QueueUrl": queue.url(request_url)}) + template = self.response_template(CREATE_QUEUE_RESPONSE) return template.render(queue_url=queue.url(request_url)) + @jsonify_error def get_queue_url(self): request_url = urlparse(self.uri) queue_name = self._get_param("QueueName") queue = self.sqs_backend.get_queue_url(queue_name) + if self.is_json(): + return json.dumps({"QueueUrl": queue.url(request_url)}) + template = self.response_template(GET_QUEUE_URL_RESPONSE) return template.render(queue_url=queue.url(request_url)) + @jsonify_error def list_queues(self): request_url = urlparse(self.uri) queue_name_prefix = self._get_param("QueueNamePrefix") queues = self.sqs_backend.list_queues(queue_name_prefix) + + if self.is_json(): + if queues: + return json.dumps( + {"QueueUrls": [queue.url(request_url) for queue in queues]} + ) + else: + return "{}" + template = self.response_template(LIST_QUEUES_RESPONSE) return template.render(queues=queues, request_url=request_url) + @jsonify_error def change_message_visibility(self): queue_name = self._get_queue_name() receipt_handle = self._get_param("ReceiptHandle") @@ -120,7 +199,7 @@ def change_message_visibility(self): try: visibility_timeout = self._get_validated_visibility_timeout() except ValueError: - return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400) + return 400, {}, ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE self.sqs_backend.change_message_visibility( queue_name=queue_name, @@ -128,12 +207,22 @@ def change_message_visibility(self): visibility_timeout=visibility_timeout, ) + if self.is_json(): + return "{}" + template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE) return template.render() + @jsonify_error def change_message_visibility_batch(self): queue_name = self._get_queue_name() - entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry") + if self.is_json(): + entries = [ + {camelcase_to_underscores(key): value for key, value in entr.items()} + for entr in self._get_param("Entries") + ] + else: + entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry") success = [] error = [] @@ -170,26 +259,50 @@ def change_message_visibility_batch(self): } ) + if self.is_json(): + return json.dumps( + {"Successful": [{"Id": _id} for _id in success], "Failed": error} + ) + template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE) return template.render(success=success, errors=error) + @jsonify_error def get_queue_attributes(self): queue_name = self._get_queue_name() - if self.querystring.get("AttributeNames"): + if not self.is_json() and self.querystring.get("AttributeNames"): raise InvalidAttributeName("") - attribute_names = self._get_multi_param("AttributeName") - - # if connecting to AWS via boto, then 'AttributeName' is just a normal parameter - if not attribute_names: - attribute_names = self.querystring.get("AttributeName") + if self.is_json(): + attribute_names = self._get_param("AttributeNames") + if attribute_names == [] or (attribute_names and "" in attribute_names): + raise InvalidAttributeName("") + else: + # if connecting to AWS via boto, then 'AttributeName' is just a normal parameter + attribute_names = self._get_multi_param( + "AttributeName" + ) or self.querystring.get("AttributeName") attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) + if self.is_json(): + if len(attributes) == 0: + return "{}" + return json.dumps( + { + "Attributes": { + key: str(value) + for key, value in attributes.items() + if value is not None + } + } + ) + template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE) return template.render(attributes=attributes) + @jsonify_error def set_queue_attributes(self): # TODO validate self.get_param('QueueUrl') attribute = self.attribute @@ -206,6 +319,7 @@ def set_queue_attributes(self): return SET_QUEUE_ATTRIBUTE_RESPONSE + @jsonify_error def delete_queue(self): # TODO validate self.get_param('QueueUrl') queue_name = self._get_queue_name() @@ -215,6 +329,7 @@ def delete_queue(self): template = self.response_template(DELETE_QUEUE_RESPONSE) return template.render() + @jsonify_error def send_message(self): message = self._get_param("MessageBody") delay_seconds = int(self._get_param("DelaySeconds", 0)) @@ -222,12 +337,24 @@ def send_message(self): message_dedupe_id = self._get_param("MessageDeduplicationId") if len(message) > MAXIMUM_MESSAGE_LENGTH: - return ERROR_TOO_LONG_RESPONSE, dict(status=400) + return self._error( + "InvalidParameterValue", + message="One or more parameters are invalid. Reason: Message must be shorter than 262144 bytes.", + ) - message_attributes = parse_message_attributes(self.querystring) - system_message_attributes = parse_message_attributes( - self.querystring, key="MessageSystemAttribute" - ) + if self.is_json(): + message_attributes = self._get_param("MessageAttributes") + self.normalize_json_msg_attributes(message_attributes) + else: + message_attributes = parse_message_attributes(self.querystring) + + if self.is_json(): + system_message_attributes = self._get_param("MessageSystemAttributes") + self.normalize_json_msg_attributes(system_message_attributes) + else: + system_message_attributes = parse_message_attributes( + self.querystring, key="MessageSystemAttribute" + ) queue_name = self._get_queue_name() @@ -244,9 +371,30 @@ def send_message(self): except RESTError as err: return self._error(err.error_type, err.message) + if self.is_json(): + resp = { + "MD5OfMessageBody": message.body_md5, + "MessageId": message.id, + } + if len(message.message_attributes) > 0: + resp["MD5OfMessageAttributes"] = message.attribute_md5 + return json.dumps(resp) + template = self.response_template(SEND_MESSAGE_RESPONSE) return template.render(message=message, message_attributes=message_attributes) + def normalize_json_msg_attributes(self, message_attributes: Dict[str, Any]) -> None: + for key, value in (message_attributes or {}).items(): + if "BinaryValue" in value: + message_attributes[key]["binary_value"] = value.pop("BinaryValue") + if "StringValue" in value: + message_attributes[key]["string_value"] = value.pop("StringValue") + if "DataType" in value: + message_attributes[key]["data_type"] = value.pop("DataType") + + validate_message_attributes(message_attributes) + + @jsonify_error def send_message_batch(self): """ The querystring comes like this @@ -263,57 +411,70 @@ def send_message_batch(self): self.sqs_backend.get_queue(queue_name) - if self.querystring.get("Entries"): + if not self.is_json() and self.querystring.get("Entries"): raise EmptyBatchRequest() - entries = {} - for key, value in self.querystring.items(): - match = re.match(r"^SendMessageBatchRequestEntry\.(\d+)\.Id", key) - if match: - index = match.group(1) - - message_attributes = parse_message_attributes( - self.querystring, - base="SendMessageBatchRequestEntry.{}.".format(index), + if self.is_json(): + entries = { + str(idx): entry for idx, entry in enumerate(self._get_param("Entries")) + } + else: + entries = { + str(idx): entry + for idx, entry in enumerate( + self._get_multi_param("SendMessageBatchRequestEntry") ) + } + for entry in entries.values(): + if "MessageAttribute" in entry: + entry["MessageAttributes"] = { + val["Name"]: val["Value"] + for val in entry.pop("MessageAttribute") + } - entries[index] = { - "Id": value[0], - "MessageBody": self.querystring.get( - "SendMessageBatchRequestEntry.{}.MessageBody".format(index) - )[0], - "DelaySeconds": self.querystring.get( - "SendMessageBatchRequestEntry.{}.DelaySeconds".format(index), - [None], - )[0], - "MessageAttributes": message_attributes, - "MessageGroupId": self.querystring.get( - "SendMessageBatchRequestEntry.{}.MessageGroupId".format(index), - [None], - )[0], - "MessageDeduplicationId": self.querystring.get( - "SendMessageBatchRequestEntry.{}.MessageDeduplicationId".format( - index - ), - [None], - )[0], - } + for entry in entries.values(): + if "MessageAttributes" in entry: + self.normalize_json_msg_attributes(entry["MessageAttributes"]) + else: + entry["MessageAttributes"] = {} + if "DelaySeconds" not in entry: + entry["DelaySeconds"] = None if entries == {}: raise EmptyBatchRequest() messages = self.sqs_backend.send_message_batch(queue_name, entries) + if self.is_json(): + resp: Dict[str, Any] = {"Successful": [], "Failed": []} + for msg in messages: + msg_dict = { + "Id": msg.user_id, # type: ignore + "MessageId": msg.id, + "MD5OfMessageBody": msg.body_md5, + } + if len(msg.message_attributes) > 0: + msg_dict["MD5OfMessageAttributes"] = msg.attribute_md5 + resp["Successful"].append(msg_dict) + return json.dumps(resp) + template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE) return template.render(messages=messages) + @jsonify_error def delete_message(self): queue_name = self._get_queue_name() - receipt_handle = self.querystring.get("ReceiptHandle")[0] + + if self.is_json(): + receipt_handle = self._get_param("ReceiptHandle") + else: + receipt_handle = self.querystring.get("ReceiptHandle")[0] # type: ignore + self.sqs_backend.delete_message(queue_name, receipt_handle) template = self.response_template(DELETE_MESSAGE_RESPONSE) return template.render() + @jsonify_error def delete_message_batch(self): """ The querystring comes like this @@ -326,23 +487,17 @@ def delete_message_batch(self): """ queue_name = self._get_queue_name() - receipts = [] + if self.is_json(): + receipts = self._get_param("Entries") + else: + receipts = self._get_multi_param("DeleteMessageBatchRequestEntry") - for index in range(1, 11): - # Loop through looking for messages - receipt_key = "DeleteMessageBatchRequestEntry.{0}.ReceiptHandle".format( - index - ) - receipt_handle = self.querystring.get(receipt_key) - if not receipt_handle: - # Found all messages - break - - message_user_id_key = "DeleteMessageBatchRequestEntry.{0}.Id".format(index) - message_user_id = self.querystring.get(message_user_id_key)[0] - receipts.append( - {"receipt_handle": receipt_handle[0], "msg_user_id": message_user_id} - ) + for r in receipts: + for key in list(r.keys()): + if key == "Id": + r["msg_user_id"] = r.pop(key) + else: + r[camelcase_to_underscores(key)] = r.pop(key) receipt_seen = set() for receipt_and_id in receipts: @@ -369,27 +524,45 @@ def delete_message_batch(self): } ) + if self.is_json(): + return json.dumps( + {"Successful": [{"Id": _id} for _id in success], "Failed": errors} + ) + template = self.response_template(DELETE_MESSAGE_BATCH_RESPONSE) return template.render(success=success, errors=errors) + @jsonify_error def purge_queue(self): queue_name = self._get_queue_name() self.sqs_backend.purge_queue(queue_name) template = self.response_template(PURGE_QUEUE_RESPONSE) return template.render() + @jsonify_error def receive_message(self): queue_name = self._get_queue_name() - message_attributes = self._get_multi_param("message_attributes") + if self.is_json(): + message_attributes = self._get_param("MessageAttributeNames") + else: + message_attributes = self._get_multi_param("message_attributes") if not message_attributes: message_attributes = extract_input_message_attributes(self.querystring) - attribute_names = self._get_multi_param("AttributeName") + if self.is_json(): + attribute_names = self._get_param("AttributeNames", []) + else: + attribute_names = self._get_multi_param("AttributeName") queue = self.sqs_backend.get_queue(queue_name) try: - message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) + if self.is_json(): + message_count = self._get_param( + "MaxNumberOfMessages", DEFAULT_RECEIVED_MESSAGES + ) + else: + message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) # type: ignore except TypeError: message_count = DEFAULT_RECEIVED_MESSAGES @@ -403,7 +576,11 @@ def receive_message(self): ) try: - wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) + if self.is_json(): + wait_time = int(self._get_param("WaitTimeSeconds")) + else: + wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) # type: ignore + except TypeError: wait_time = int(queue.receive_message_wait_time_seconds) @@ -421,7 +598,7 @@ def receive_message(self): except TypeError: visibility_timeout = queue.visibility_timeout except ValueError: - return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400) + return 400, {}, ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE messages = self.sqs_backend.receive_message( queue_name, message_count, wait_time, visibility_timeout, message_attributes @@ -441,23 +618,98 @@ def receive_message(self): pascalcase_name = camelcase_to_pascal(underscores_to_camelcase(attribute)) if any(x in ["All", pascalcase_name] for x in attribute_names): attributes[attribute] = True + if self.is_json(): + msgs = [] + for message in messages: + msg: Dict[str, Any] = { + "MessageId": message.id, + "ReceiptHandle": message.receipt_handle, + "MD5OfBody": message.body_md5, + "Body": message.original_body, + "Attributes": {}, + "MessageAttributes": {}, + } + if len(message.message_attributes) > 0: + msg["MD5OfMessageAttributes"] = message.attribute_md5 + if attributes["sender_id"]: + msg["Attributes"]["SenderId"] = message.sender_id + if attributes["sent_timestamp"]: + msg["Attributes"]["SentTimestamp"] = str(message.sent_timestamp) + if attributes["approximate_receive_count"]: + msg["Attributes"]["ApproximateReceiveCount"] = str( + message.approximate_receive_count + ) + if attributes["approximate_first_receive_timestamp"]: + msg["Attributes"]["ApproximateFirstReceiveTimestamp"] = str( + message.approximate_first_receive_timestamp + ) + if attributes["message_deduplication_id"]: + msg["Attributes"][ + "MessageDeduplicationId" + ] = message.deduplication_id + if attributes["message_group_id"] and message.group_id is not None: + msg["Attributes"]["MessageGroupId"] = message.group_id + if message.system_attributes and message.system_attributes.get( + "AWSTraceHeader" + ): + msg["Attributes"]["AWSTraceHeader"] = message.system_attributes[ + "AWSTraceHeader" + ].get("string_value") + if ( + attributes["sequence_number"] + and message.sequence_number is not None + ): + msg["Attributes"]["SequenceNumber"] = message.sequence_number + for name, value in message.message_attributes.items(): + msg["MessageAttributes"][name] = {"DataType": value["data_type"]} + if "Binary" in value["data_type"]: + msg["MessageAttributes"][name]["BinaryValue"] = value[ + "binary_value" + ] + else: + msg["MessageAttributes"][name]["StringValue"] = value[ + "string_value" + ] + + if len(msg["Attributes"]) == 0: + msg.pop("Attributes") + if len(msg["MessageAttributes"]) == 0: + msg.pop("MessageAttributes") + msgs.append(msg) + + return json.dumps({"Messages": msgs} if msgs else {}) template = self.response_template(RECEIVE_MESSAGE_RESPONSE) return template.render(messages=messages, attributes=attributes) + @jsonify_error def list_dead_letter_source_queues(self): request_url = urlparse(self.uri) queue_name = self._get_queue_name() - source_queue_urls = self.sqs_backend.list_dead_letter_source_queues(queue_name) + queues = self.sqs_backend.list_dead_letter_source_queues(queue_name) + + if self.is_json(): + return json.dumps( + {"queueUrls": [queue.url(request_url) for queue in queues]} + ) template = self.response_template(LIST_DEAD_LETTER_SOURCE_QUEUES_RESPONSE) - return template.render(queues=source_queue_urls, request_url=request_url) + return template.render(queues=queues, request_url=request_url) + @jsonify_error def add_permission(self): queue_name = self._get_queue_name() - actions = self._get_multi_param("ActionName") - account_ids = self._get_multi_param("AWSAccountId") + actions = ( + self._get_param("Actions") + if self.is_json() + else self._get_multi_param("ActionName") + ) + account_ids = ( + self._get_param("AWSAccountIds") + if self.is_json() + else self._get_multi_param("AWSAccountId") + ) label = self._get_param("Label") self.sqs_backend.add_permission(queue_name, actions, account_ids, label) @@ -465,6 +717,7 @@ def add_permission(self): template = self.response_template(ADD_PERMISSION_RESPONSE) return template.render() + @jsonify_error def remove_permission(self): queue_name = self._get_queue_name() label = self._get_param("Label") @@ -474,29 +727,48 @@ def remove_permission(self): template = self.response_template(REMOVE_PERMISSION_RESPONSE) return template.render() + @jsonify_error def tag_queue(self): queue_name = self._get_queue_name() - tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") + if self.is_json(): + tags = self._get_param("Tags") + else: + tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") self.sqs_backend.tag_queue(queue_name, tags) + if self.is_json(): + return "{}" + template = self.response_template(TAG_QUEUE_RESPONSE) return template.render() + @jsonify_error def untag_queue(self): queue_name = self._get_queue_name() - tag_keys = self._get_multi_param("TagKey") + tag_keys = ( + self._get_param("TagKeys") + if self.is_json() + else self._get_multi_param("TagKey") + ) self.sqs_backend.untag_queue(queue_name, tag_keys) + if self.is_json(): + return "{}" + template = self.response_template(UNTAG_QUEUE_RESPONSE) return template.render() + @jsonify_error def list_queue_tags(self): queue_name = self._get_queue_name() queue = self.sqs_backend.list_queue_tags(queue_name) + if self.is_json(): + return json.dumps({"Tags": queue.tags}) + template = self.response_template(LIST_QUEUE_TAGS_RESPONSE) return template.render(tags=queue.tags) @@ -782,16 +1054,6 @@ def list_queue_tags(self): """ -ERROR_TOO_LONG_RESPONSE = """ - - Sender - InvalidParameterValue - One or more parameters are invalid. Reason: Message must be shorter than 262144 bytes. - - - 6fde8d1e-52cd-4581-8cd9-c512f4c64223 -""" - ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE = ( f"Invalid request, maximum visibility timeout is {MAXIMUM_VISIBILTY_TIMEOUT}" ) diff --git a/moto/sqs/utils.py b/moto/sqs/utils.py index 24d3742489e8..0f93dd045b9c 100644 --- a/moto/sqs/utils.py +++ b/moto/sqs/utils.py @@ -1,5 +1,6 @@ import random import string +from typing import Dict, Any from .exceptions import MessageAttributesInvalid @@ -39,46 +40,52 @@ def parse_message_attributes( break data_type_key = base + "{0}.{1}.{2}DataType".format(key, index, value_namespace) - data_type = querystring.get(data_type_key) - if not data_type: - raise MessageAttributesInvalid( - "The message attribute '{0}' must contain non-empty message attribute value.".format( - name[0] - ) - ) + data_type = querystring.get(data_type_key, [None])[0] - data_type_parts = data_type[0].split(".") - if data_type_parts[0] not in [ - "String", - "Binary", - "Number", - ]: - raise MessageAttributesInvalid( - "The message attribute '{0}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String.".format( - name[0] - ) - ) + data_type_parts = (data_type or "").split(".")[0] type_prefix = "String" - if data_type_parts[0] == "Binary": + if data_type_parts == "Binary": type_prefix = "Binary" value_key = base + "{0}.{1}.{2}{3}Value".format( key, index, value_namespace, type_prefix ) - value = querystring.get(value_key) - if not value: - raise MessageAttributesInvalid( - "The message attribute '{0}' must contain non-empty message attribute value for message attribute type '{1}'.".format( - name[0], data_type[0] - ) - ) + value = querystring.get(value_key, [None])[0] message_attributes[name[0]] = { - "data_type": data_type[0], - type_prefix.lower() + "_value": value[0], + "data_type": data_type, + type_prefix.lower() + "_value": value, } index += 1 + validate_message_attributes(message_attributes) + return message_attributes + +def validate_message_attributes(message_attributes: Dict[str, Any]) -> None: + for name, value in (message_attributes or {}).items(): + data_type = value["data_type"] + + if not data_type: + raise MessageAttributesInvalid( + f"The message attribute '{name}' must contain non-empty message attribute value." + ) + + data_type_parts = data_type.split(".")[0] + if data_type_parts not in [ + "String", + "Binary", + "Number", + ]: + raise MessageAttributesInvalid( + f"The message attribute '{name}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String." + ) + + possible_value_fields = ["string_value", "binary_value"] + for field in possible_value_fields: + if field in value and value[field] is None: + raise MessageAttributesInvalid( + f"The message attribute '{name}' must contain non-empty message attribute value for message attribute type '{data_type}'." + ) \ No newline at end of file diff --git a/moto/utilities/constants.py b/moto/utilities/constants.py new file mode 100644 index 000000000000..3cf86f61b8c4 --- /dev/null +++ b/moto/utilities/constants.py @@ -0,0 +1,5 @@ +APPLICATION_AMZ_JSON_1_0 = "application/x-amz-json-1.0" +APPLICATION_AMZ_JSON_1_1 = "application/x-amz-json-1.1" +APPLICATION_JSON = "application/json" + +JSON_TYPES = [APPLICATION_JSON, APPLICATION_AMZ_JSON_1_0, APPLICATION_AMZ_JSON_1_1] \ No newline at end of file