From 378d44be3b51ea84e4370478e3fc8e284889a58d Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Tue, 28 Jan 2025 21:58:11 +0100 Subject: [PATCH] feat(smtp): support html_content and subject templates from SMTP connection --- .../connections/smtp.rst | 2 + .../src/airflow/providers/smtp/hooks/smtp.py | 47 ++++++++++++-- .../providers/smtp/notifications/smtp.py | 61 +++++++++++-------- .../airflow/providers/smtp/operators/smtp.py | 34 ++++++++++- .../tests/smtp/notifications/test_smtp.py | 42 +++++++++++++ providers/tests/smtp/operators/test_smtp.py | 56 +++++++++++------ 6 files changed, 194 insertions(+), 48 deletions(-) diff --git a/docs/apache-airflow-providers-smtp/connections/smtp.rst b/docs/apache-airflow-providers-smtp/connections/smtp.rst index 80016ec6c4d1e..62e8548d0af87 100644 --- a/docs/apache-airflow-providers-smtp/connections/smtp.rst +++ b/docs/apache-airflow-providers-smtp/connections/smtp.rst @@ -62,6 +62,8 @@ Extra (optional) * ``ssl_context``: Can be "default" or "none". Only valid when SSL is used. The "default" context provides a balance between security and compatibility, "none" is not recommended as it disables validation of certificates and allow MITM attacks, and is only needed in case your certificates are wrongly configured in your system. If not specified, defaults are taken from the "smtp_provider", "ssl_context" configuration with the fallback to "email". "ssl_context" configuration. If none of it is specified, "default" is used. + * ``subject_template``: A path to a file containing the email subject template. + * ``html_content_template``: A path to a file containing the email html content template. When specifying the connection in environment variable you should specify it using URI syntax. diff --git a/providers/src/airflow/providers/smtp/hooks/smtp.py b/providers/src/airflow/providers/smtp/hooks/smtp.py index f6374e7fec91e..74707fad5f0a7 100644 --- a/providers/src/airflow/providers/smtp/hooks/smtp.py +++ b/providers/src/airflow/providers/smtp/hooks/smtp.py @@ -33,6 +33,7 @@ from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from email.utils import formatdate +from pathlib import Path from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException, AirflowNotFoundException @@ -160,6 +161,12 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: ), "disable_tls": BooleanField(lazy_gettext("Disable TLS"), default=False), "disable_ssl": BooleanField(lazy_gettext("Disable SSL"), default=False), + "subject_template": StringField( + lazy_gettext("Path to the subject template"), widget=BS3TextFieldWidget() + ), + "html_content_template": StringField( + lazy_gettext("Path to the html content template"), widget=BS3TextFieldWidget() + ), } def test_connection(self) -> tuple[bool, str]: @@ -178,8 +185,8 @@ def send_email_smtp( self, *, to: str | Iterable[str], - subject: str, - html_content: str, + subject: str | None = None, + html_content: str | None = None, from_email: str | None = None, files: list[str] | None = None, dryrun: bool = False, @@ -194,8 +201,10 @@ def send_email_smtp( Send an email with html content. :param to: Recipient email address or list of addresses. - :param subject: Email subject. - :param html_content: Email body in HTML format. + :param subject: Email subject. If it's None, the hook will check if there is a path to a subject + file provided in the connection, and raises an exception if not. + :param html_content: Email body in HTML format. If it's None, the hook will check if there is a path + to a html content file provided in the connection, and raises an exception if not. :param from_email: Sender email address. If it's None, the hook will check if there is an email provided in the connection, and raises an exception if not. :param files: List of file paths to attach to the email. @@ -216,6 +225,18 @@ def send_email_smtp( from_email = from_email or self.from_email if not from_email: raise AirflowException("You should provide `from_email` or define it in the connection.") + if not subject: + if self.subject_template is None: + raise AirflowException( + "You should provide `subject` or define `subject_template` in the connection." + ) + subject = self._read_template(self.subject_template) + if not html_content: + if self.html_content_template is None: + raise AirflowException( + "You should provide `html_content` or define `html_content_template` in the connection." + ) + html_content = self._read_template(self.html_content_template) mime_msg, recipients = self._build_mime_message( mail_from=from_email, @@ -382,6 +403,24 @@ def timeout(self) -> int: def use_ssl(self) -> bool: return not bool(self.conn.extra_dejson.get("disable_ssl", False)) + @property + def subject_template(self) -> str | None: + return self.conn.extra_dejson.get("subject_template") + + @property + def html_content_template(self) -> str | None: + return self.conn.extra_dejson.get("html_content_template") + + @staticmethod + def _read_template(template_path: str) -> str: + """ + Read the content of a template file. + + :param template_path: The path to the template file. + :return: The content of the template file. + """ + return Path(template_path).read_text() + @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: """Return custom field behaviour.""" diff --git a/providers/src/airflow/providers/smtp/notifications/smtp.py b/providers/src/airflow/providers/smtp/notifications/smtp.py index fb043e59c071a..01ca19f80e7a2 100644 --- a/providers/src/airflow/providers/smtp/notifications/smtp.py +++ b/providers/src/airflow/providers/smtp/notifications/smtp.py @@ -46,10 +46,7 @@ class SmtpNotifier(BaseNotifier): ), ) - Default template can be overridden via the following provider configuration data: - - templated_email_subject_path - - templated_html_content_path - + You can define a default template for subject and html_content in the SMTP connection configuration. :param smtp_conn_id: The :ref:`smtp connection id ` that contains the information used to authenticate the client. @@ -96,28 +93,14 @@ def __init__( self.mime_subtype = mime_subtype self.mime_charset = mime_charset self.custom_headers = custom_headers + self.subject = subject + self.html_content = html_content + if self.html_content is None and template is not None: + self.html_content = self._read_template(template) - smtp_default_templated_subject_path = conf.get( - "smtp", - "templated_email_subject_path", - fallback=(Path(__file__).parent / "templates" / "email_subject.jinja2").as_posix(), - ) - self.subject = ( - subject or Path(smtp_default_templated_subject_path).read_text().replace("\n", "").strip() - ) - # If html_content is passed, prioritize it. Otherwise, if template is passed, use - # it to populate html_content. Else, fall back to defaults defined in settings - if html_content is not None: - self.html_content = html_content - elif template is not None: - self.html_content = Path(template).read_text() - else: - smtp_default_templated_html_content_path = conf.get( - "smtp", - "templated_html_content_path", - fallback=(Path(__file__).parent / "templates" / "email.html").as_posix(), - ) - self.html_content = Path(smtp_default_templated_html_content_path).read_text() + @staticmethod + def _read_template(template_path: str) -> str: + return Path(template_path).read_text().replace("\n", "").strip() @cached_property def hook(self) -> SmtpHook: @@ -126,6 +109,34 @@ def hook(self) -> SmtpHook: def notify(self, context): """Send a email via smtp server.""" + fields_to_re_render = [] + if self.subject is None: + smtp_default_templated_subject_path: str + if self.hook.subject_template: + smtp_default_templated_subject_path = self.hook.subject_template + else: + smtp_default_templated_subject_path = conf.get( + "smtp", + "templated_email_subject_path", + fallback=(Path(__file__).parent / "templates" / "email_subject.jinja2").as_posix(), + ) + self.subject = self._read_template(smtp_default_templated_subject_path) + fields_to_re_render.append("subject") + if self.html_content is None: + smtp_default_templated_html_content_path: str + if self.hook.html_content_template: + smtp_default_templated_html_content_path = self.hook.html_content_template + else: + smtp_default_templated_html_content_path = conf.get( + "smtp", + "templated_html_content_path", + fallback=(Path(__file__).parent / "templates" / "email.html").as_posix(), + ) + self.html_content = self._read_template(smtp_default_templated_html_content_path) + fields_to_re_render.append("html_content") + if fields_to_re_render: + jinja_env = self.get_template_env(dag=context["dag"]) + self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set()) with self.hook as smtp: smtp.send_email_smtp( smtp_conn_id=self.smtp_conn_id, diff --git a/providers/src/airflow/providers/smtp/operators/smtp.py b/providers/src/airflow/providers/smtp/operators/smtp.py index 2c097e8aa8b84..d9880f595fe7b 100644 --- a/providers/src/airflow/providers/smtp/operators/smtp.py +++ b/providers/src/airflow/providers/smtp/operators/smtp.py @@ -18,8 +18,10 @@ from __future__ import annotations from collections.abc import Sequence +from pathlib import Path from typing import TYPE_CHECKING, Any +from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.smtp.hooks.smtp import SmtpHook @@ -58,8 +60,8 @@ def __init__( self, *, to: list[str] | str, - subject: str, - html_content: str, + subject: str | None = None, + html_content: str | None = None, from_email: str | None = None, files: list | None = None, cc: list[str] | str | None = None, @@ -83,8 +85,36 @@ def __init__( self.conn_id = conn_id self.custom_headers = custom_headers + @staticmethod + def _read_template(template_path: str) -> str: + return Path(template_path).read_text().replace("\n", "").strip() + def execute(self, context: Context): with SmtpHook(smtp_conn_id=self.conn_id) as smtp_hook: + fields_to_re_render = [] + if self.from_email is None: + if smtp_hook.from_email is None: + raise AirflowException("You should provide `from_email` or define it in the connection.") + self.from_email = smtp_hook.from_email + fields_to_re_render.append("from_email") + if self.subject is None: + if smtp_hook.subject_template is None: + raise AirflowException( + "You should provide `subject` or define `subject_template` in the connection." + ) + self.subject = self._read_template(smtp_hook.subject_template) + fields_to_re_render.append("subject") + if self.html_content is None: + if smtp_hook.html_content_template is None: + raise AirflowException( + "You should provide `html_content` or define `html_content_template` in the connection." + ) + self.html_content = self._read_template(smtp_hook.html_content_template) + fields_to_re_render.append("html_content") + if fields_to_re_render: + self._do_render_template_fields( + self, fields_to_re_render, context, self.get_template_env(), set() + ) return smtp_hook.send_email_smtp( to=self.to, subject=self.subject, diff --git a/providers/tests/smtp/notifications/test_smtp.py b/providers/tests/smtp/notifications/test_smtp.py index 8ce28637d8fb3..828c671c84252 100644 --- a/providers/tests/smtp/notifications/test_smtp.py +++ b/providers/tests/smtp/notifications/test_smtp.py @@ -128,6 +128,8 @@ def test_notifier_with_defaults(self, mock_smtphook_hook, create_task_instance): from_email=conf.get("smtp", "smtp_mail_from"), to="test_reciver@test.com", ) + mock_smtphook_hook.return_value.subject_template = None + mock_smtphook_hook.return_value.html_content_template = None notifier(context) mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( from_email=conf.get("smtp", "smtp_mail_from"), @@ -150,6 +152,10 @@ def test_notifier_with_nondefault_conf_vars(self, mock_smtphook_hook, create_tas ti = create_task_instance(dag_id="dag", task_id="op", logical_date=timezone.datetime(2018, 1, 1)) context = {"dag": ti.dag_run.dag, "ti": ti} + mock_smtphook_hook.return_value.from_email = None + mock_smtphook_hook.return_value.subject_template = None + mock_smtphook_hook.return_value.html_content_template = None + with ( tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, @@ -184,3 +190,39 @@ def test_notifier_with_nondefault_conf_vars(self, mock_smtphook_hook, create_tas mime_charset="utf-8", custom_headers=None, ) + + @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") + def test_notifier_with_nondefault_connection_extra(self, mock_smtphook_hook, create_task_instance): + ti = create_task_instance(dag_id="dag", task_id="op", logical_date=timezone.datetime(2018, 1, 1)) + context = {"dag": ti.dag_run.dag, "ti": ti} + + with ( + tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, + tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, + ): + f_subject.write("Task {{ ti.task_id }} failed") + f_subject.flush() + + f_content.write("Mock content goes here") + f_content.flush() + + mock_smtphook_hook.return_value.subject_template = f_subject.name + mock_smtphook_hook.return_value.html_content_template = f_content.name + notifier = SmtpNotifier( + from_email=conf.get("smtp", "smtp_mail_from"), + to="test_reciver@test.com", + ) + notifier(context) + mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( + from_email=conf.get("smtp", "smtp_mail_from"), + to="test_reciver@test.com", + subject="Task op failed", + html_content="Mock content goes here", + smtp_conn_id="smtp_default", + files=None, + cc=None, + bcc=None, + mime_subtype="mixed", + mime_charset="utf-8", + custom_headers=None, + ) diff --git a/providers/tests/smtp/operators/test_smtp.py b/providers/tests/smtp/operators/test_smtp.py index 6239c058dadf5..3d662b1d7349c 100644 --- a/providers/tests/smtp/operators/test_smtp.py +++ b/providers/tests/smtp/operators/test_smtp.py @@ -17,8 +17,10 @@ # under the License. from __future__ import annotations +import base64 import json -from unittest.mock import patch +import tempfile +from unittest.mock import MagicMock, patch from airflow.models import Connection from airflow.providers.smtp.operators.smtp import EmailOperator @@ -37,22 +39,42 @@ def test_loading_sender_email_from_connection(self, mock_smtplib, mock_hook_conn custom_retry_limit = 10 custom_timeout = 60 sender_email = "sender_email" - mock_hook_conn.return_value = Connection( - conn_id="mock_conn", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=465, - extra=json.dumps( - dict(from_email=sender_email, timeout=custom_timeout, retry_limit=custom_retry_limit) - ), - ) - smtp_client_mock = mock_smtplib.SMTP_SSL() - op = EmailOperator(task_id="test_email", **self.default_op_kwargs) - op.execute({}) - call_args = smtp_client_mock.sendmail.call_args.kwargs - assert call_args["from_addr"] == sender_email + with ( + tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, + tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, + ): + f_subject.write("Task {{ ti.task_id }} failed") + f_subject.flush() + + f_content.write("Mock content goes here") + f_content.flush() + + mock_hook_conn.return_value = Connection( + conn_id="mock_conn", + conn_type="smtp", + host="smtp_server_address", + login="smtp_user", + password="smtp_password", + port=465, + extra=json.dumps( + dict( + from_email=sender_email, + timeout=custom_timeout, + retry_limit=custom_retry_limit, + subject_template=f_subject.name, + html_content_template=f_content.name, + ) + ), + ) + smtp_client_mock = mock_smtplib.SMTP_SSL() + op = EmailOperator(task_id="test_email", to="to") + op.execute({"ti": MagicMock(task_id="some_id")}) + call_args = smtp_client_mock.sendmail.call_args.kwargs + assert call_args["from_addr"] == sender_email + assert "Subject: Task some_id failed" in call_args["msg"] + assert ( + base64.b64encode("Mock content goes here".encode("ascii")).decode("ascii") in call_args["msg"] + ) def test_assert_templated_fields(self): """Test expected templated fields."""