Skip to content

Commit

Permalink
Add MessageDeduplicationId support to AWS SqsPublishOperator (apache#45051)
Browse files Browse the repository at this point in the history

* add MessageDeduplicationId support to AWS SqsPublishOperator

* modified test

---------

Co-authored-by: pratiksha rajendrabhai badheka <pratiksha@DESKTOP-T5HUA05>
  • Loading branch information
2 people authored and got686-yandex committed Jan 30, 2025
1 parent 05ccd21 commit ce137c1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
4 changes: 4 additions & 0 deletions providers/src/airflow/providers/amazon/aws/hooks/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def send_message(
delay_seconds: int = 0,
message_attributes: dict | None = None,
message_group_id: str | None = None,
message_deduplication_id: str | None = None,
) -> dict:
"""
Send message to the queue.
Expand All @@ -71,6 +72,7 @@ def send_message(
:param delay_seconds: seconds to delay the message
:param message_attributes: additional attributes for the message (default: None)
:param message_group_id: This applies only to FIFO (first-in-first-out) queues. (default: None)
:param message_deduplication_id: This applies only to FIFO (first-in-first-out) queues.
:return: dict with the information about the message sent
"""
params = {
Expand All @@ -81,5 +83,7 @@ def send_message(
}
if message_group_id:
params["MessageGroupId"] = message_group_id
if message_deduplication_id:
params["MessageDeduplicationId"] = message_deduplication_id

return self.get_conn().send_message(**params)
6 changes: 6 additions & 0 deletions providers/src/airflow/providers/amazon/aws/operators/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]):
:param delay_seconds: message delay (templated) (default: 1 second)
:param message_group_id: This parameter applies only to FIFO (first-in-first-out) queues. (default: None)
For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
:param message_deduplication_id: This applies only to FIFO (first-in-first-out) queues.
For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
Expand All @@ -63,6 +65,7 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]):
"delay_seconds",
"message_attributes",
"message_group_id",
"message_deduplication_id",
)
template_fields_renderers = {"message_attributes": "json"}
ui_color = "#6ad3fa"
Expand All @@ -75,6 +78,7 @@ def __init__(
message_attributes: dict | None = None,
delay_seconds: int = 0,
message_group_id: str | None = None,
message_deduplication_id: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -83,6 +87,7 @@ def __init__(
self.delay_seconds = delay_seconds
self.message_attributes = message_attributes or {}
self.message_group_id = message_group_id
self.message_deduplication_id = message_deduplication_id

def execute(self, context: Context) -> dict:
"""
Expand All @@ -98,6 +103,7 @@ def execute(self, context: Context) -> dict:
delay_seconds=self.delay_seconds,
message_attributes=self.message_attributes,
message_group_id=self.message_group_id,
message_deduplication_id=self.message_deduplication_id,
)

self.log.info("send_message result: %s", result)
Expand Down
31 changes: 28 additions & 3 deletions providers/tests/amazon/aws/operators/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,27 +103,52 @@ def test_execute_failure_fifo_queue(self, mocked_context):
with pytest.raises(ClientError, match=error_message):
op.execute(mocked_context)

@mock_aws
def test_deduplication_failure(self, mocked_context):
    """Publishing to a FIFO queue with no deduplication source must fail.

    The queue is created with ``ContentBasedDeduplication`` disabled and the
    operator is built without a ``message_deduplication_id``, so executing it
    is expected to raise a ``ClientError`` with the message checked below.
    """
    queue_attributes = {"FifoQueue": "true", "ContentBasedDeduplication": "false"}
    self.sqs_client.create_queue(QueueName=FIFO_QUEUE_NAME, Attributes=queue_attributes)

    publish_op = SqsPublishOperator(
        **self.default_op_kwargs,
        sqs_queue=FIFO_QUEUE_NAME,
        message_group_id="abc",
    )
    # Regex matched against the raised error text; parentheses are escaped.
    expected_error = (
        r"An error occurred \(InvalidParameterValue\) when calling the SendMessage operation: "
        r"The queue should either have ContentBasedDeduplication enabled or MessageDeduplicationId provided explicitly"
    )
    with pytest.raises(ClientError, match=expected_error):
        publish_op.execute(mocked_context)

@mock_aws
def test_execute_success_fifo_queue(self, mocked_context):
    """Publish to a FIFO queue with explicit group and deduplication ids.

    Checks that the operator's result contains the send metadata and that the
    message read back from the mocked queue carries the same ``MessageId``,
    the expected body, and both FIFO attributes.
    """
    self.sqs_client.create_queue(
        QueueName=FIFO_QUEUE_NAME, Attributes={"FifoQueue": "true", "ContentBasedDeduplication": "true"}
    )

    # Send SQS Message into the FIFO Queue
    op = SqsPublishOperator(
        **self.default_op_kwargs,
        sqs_queue=FIFO_QUEUE_NAME,
        message_group_id="abc",
        message_deduplication_id="abc",
    )
    result = op.execute(mocked_context)
    assert "MD5OfMessageBody" in result
    assert "MessageId" in result

    # Validate message through moto.
    # Receive exactly once and request both FIFO attributes in that single
    # call: an earlier receive would take the only message off the queue
    # (presumably hidden by the visibility timeout), leaving nothing here.
    message = self.sqs_client.receive_message(
        QueueUrl=FIFO_QUEUE_URL, AttributeNames=["MessageGroupId", "MessageDeduplicationId"]
    )
    assert len(message["Messages"]) == 1
    assert message["Messages"][0]["MessageId"] == result["MessageId"]
    assert message["Messages"][0]["Body"] == "hello"
    assert message["Messages"][0]["Attributes"]["MessageGroupId"] == "abc"
    assert message["Messages"][0]["Attributes"]["MessageDeduplicationId"] == "abc"

def test_template_fields(self):
    """Validate the operator's template fields with ``message_deduplication_id`` set."""
    # Each keyword is passed exactly once; a repeated ``**self.default_op_kwargs``
    # or repeated keyword would make the call invalid.
    operator = SqsPublishOperator(
        **self.default_op_kwargs,
        sqs_queue=FIFO_QUEUE_NAME,
        message_group_id="abc",
        message_deduplication_id="abc",
    )
    validate_template_fields(operator)

0 comments on commit ce137c1

Please sign in to comment.