Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(smtp): support html_content and subject templates from SMTP connection #46212

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/apache-airflow-providers-smtp/connections/smtp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ Extra (optional)
* ``ssl_context``: Can be "default" or "none". Only valid when SSL is used. The "default" context provides a balance between security and compatibility, "none" is not recommended
as it disables validation of certificates and allow MITM attacks, and is only needed in case your certificates are wrongly configured in your system. If not specified, defaults are taken from the
"smtp_provider", "ssl_context" configuration with the fallback to "email". "ssl_context" configuration. If none of it is specified, "default" is used.
* ``subject_template``: A path to a file containing the email subject template.
* ``html_content_template``: A path to a file containing the email html content template.

When specifying the connection in environment variable you should specify
it using URI syntax.
Expand Down
47 changes: 43 additions & 4 deletions providers/src/airflow/providers/smtp/hooks/smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from pathlib import Path
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException, AirflowNotFoundException
Expand Down Expand Up @@ -160,6 +161,12 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
),
"disable_tls": BooleanField(lazy_gettext("Disable TLS"), default=False),
"disable_ssl": BooleanField(lazy_gettext("Disable SSL"), default=False),
"subject_template": StringField(
lazy_gettext("Path to the subject template"), widget=BS3TextFieldWidget()
),
"html_content_template": StringField(
lazy_gettext("Path to the html content template"), widget=BS3TextFieldWidget()
),
}

def test_connection(self) -> tuple[bool, str]:
Expand All @@ -178,8 +185,8 @@ def send_email_smtp(
self,
*,
to: str | Iterable[str],
subject: str,
html_content: str,
subject: str | None = None,
html_content: str | None = None,
from_email: str | None = None,
files: list[str] | None = None,
dryrun: bool = False,
Expand All @@ -194,8 +201,10 @@ def send_email_smtp(
Send an email with html content.

:param to: Recipient email address or list of addresses.
:param subject: Email subject.
:param html_content: Email body in HTML format.
:param subject: Email subject. If it's None, the hook will check if there is a path to a subject
file provided in the connection, and raises an exception if not.
:param html_content: Email body in HTML format. If it's None, the hook will check if there is a path
to a html content file provided in the connection, and raises an exception if not.
:param from_email: Sender email address. If it's None, the hook will check if there is an email
provided in the connection, and raises an exception if not.
:param files: List of file paths to attach to the email.
Expand All @@ -216,6 +225,18 @@ def send_email_smtp(
from_email = from_email or self.from_email
if not from_email:
raise AirflowException("You should provide `from_email` or define it in the connection.")
if not subject:
if self.subject_template is None:
raise AirflowException(
"You should provide `subject` or define `subject_template` in the connection."
)
subject = self._read_template(self.subject_template)
if not html_content:
if self.html_content_template is None:
raise AirflowException(
"You should provide `html_content` or define `html_content_template` in the connection."
)
html_content = self._read_template(self.html_content_template)

mime_msg, recipients = self._build_mime_message(
mail_from=from_email,
Expand Down Expand Up @@ -382,6 +403,24 @@ def timeout(self) -> int:
def use_ssl(self) -> bool:
return not bool(self.conn.extra_dejson.get("disable_ssl", False))

@property
def subject_template(self) -> str | None:
return self.conn.extra_dejson.get("subject_template")

@property
def html_content_template(self) -> str | None:
return self.conn.extra_dejson.get("html_content_template")

@staticmethod
def _read_template(template_path: str) -> str:
"""
Read the content of a template file.

:param template_path: The path to the template file.
:return: The content of the template file.
"""
return Path(template_path).read_text()

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Return custom field behaviour."""
Expand Down
61 changes: 36 additions & 25 deletions providers/src/airflow/providers/smtp/notifications/smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@ class SmtpNotifier(BaseNotifier):
),
)

Default template can be overridden via the following provider configuration data:
- templated_email_subject_path
- templated_html_content_path

You can define a default template for subject and html_content in the SMTP connection configuration.

:param smtp_conn_id: The :ref:`smtp connection id <howto/connection:smtp>`
that contains the information used to authenticate the client.
Expand Down Expand Up @@ -96,28 +93,14 @@ def __init__(
self.mime_subtype = mime_subtype
self.mime_charset = mime_charset
self.custom_headers = custom_headers
self.subject = subject
self.html_content = html_content
if self.html_content is None and template is not None:
self.html_content = self._read_template(template)

smtp_default_templated_subject_path = conf.get(
"smtp",
"templated_email_subject_path",
fallback=(Path(__file__).parent / "templates" / "email_subject.jinja2").as_posix(),
)
self.subject = (
subject or Path(smtp_default_templated_subject_path).read_text().replace("\n", "").strip()
)
# If html_content is passed, prioritize it. Otherwise, if template is passed, use
# it to populate html_content. Else, fall back to defaults defined in settings
if html_content is not None:
self.html_content = html_content
elif template is not None:
self.html_content = Path(template).read_text()
else:
smtp_default_templated_html_content_path = conf.get(
"smtp",
"templated_html_content_path",
fallback=(Path(__file__).parent / "templates" / "email.html").as_posix(),
)
self.html_content = Path(smtp_default_templated_html_content_path).read_text()
@staticmethod
def _read_template(template_path: str) -> str:
return Path(template_path).read_text().replace("\n", "").strip()

@cached_property
def hook(self) -> SmtpHook:
Expand All @@ -126,6 +109,34 @@ def hook(self) -> SmtpHook:

def notify(self, context):
"""Send a email via smtp server."""
fields_to_re_render = []
if self.subject is None:
smtp_default_templated_subject_path: str
if self.hook.subject_template:
smtp_default_templated_subject_path = self.hook.subject_template
else:
smtp_default_templated_subject_path = conf.get(
"smtp",
"templated_email_subject_path",
fallback=(Path(__file__).parent / "templates" / "email_subject.jinja2").as_posix(),
)
self.subject = self._read_template(smtp_default_templated_subject_path)
fields_to_re_render.append("subject")
if self.html_content is None:
smtp_default_templated_html_content_path: str
if self.hook.html_content_template:
smtp_default_templated_html_content_path = self.hook.html_content_template
else:
smtp_default_templated_html_content_path = conf.get(
"smtp",
"templated_html_content_path",
fallback=(Path(__file__).parent / "templates" / "email.html").as_posix(),
)
self.html_content = self._read_template(smtp_default_templated_html_content_path)
fields_to_re_render.append("html_content")
if fields_to_re_render:
jinja_env = self.get_template_env(dag=context["dag"])
self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set())
with self.hook as smtp:
smtp.send_email_smtp(
smtp_conn_id=self.smtp_conn_id,
Expand Down
34 changes: 32 additions & 2 deletions providers/src/airflow/providers/smtp/operators/smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.smtp.hooks.smtp import SmtpHook

Expand Down Expand Up @@ -58,8 +60,8 @@ def __init__(
self,
*,
to: list[str] | str,
subject: str,
html_content: str,
subject: str | None = None,
html_content: str | None = None,
from_email: str | None = None,
files: list | None = None,
cc: list[str] | str | None = None,
Expand All @@ -83,8 +85,36 @@ def __init__(
self.conn_id = conn_id
self.custom_headers = custom_headers

@staticmethod
def _read_template(template_path: str) -> str:
return Path(template_path).read_text().replace("\n", "").strip()

def execute(self, context: Context):
with SmtpHook(smtp_conn_id=self.conn_id) as smtp_hook:
fields_to_re_render = []
if self.from_email is None:
if smtp_hook.from_email is None:
raise AirflowException("You should provide `from_email` or define it in the connection.")
self.from_email = smtp_hook.from_email
fields_to_re_render.append("from_email")
if self.subject is None:
if smtp_hook.subject_template is None:
raise AirflowException(
"You should provide `subject` or define `subject_template` in the connection."
)
self.subject = self._read_template(smtp_hook.subject_template)
fields_to_re_render.append("subject")
if self.html_content is None:
if smtp_hook.html_content_template is None:
raise AirflowException(
"You should provide `html_content` or define `html_content_template` in the connection."
)
self.html_content = self._read_template(smtp_hook.html_content_template)
fields_to_re_render.append("html_content")
if fields_to_re_render:
self._do_render_template_fields(
self, fields_to_re_render, context, self.get_template_env(), set()
)
return smtp_hook.send_email_smtp(
to=self.to,
subject=self.subject,
Expand Down
42 changes: 42 additions & 0 deletions providers/tests/smtp/notifications/test_smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def test_notifier_with_defaults(self, mock_smtphook_hook, create_task_instance):
from_email=conf.get("smtp", "smtp_mail_from"),
to="[email protected]",
)
mock_smtphook_hook.return_value.subject_template = None
mock_smtphook_hook.return_value.html_content_template = None
notifier(context)
mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with(
from_email=conf.get("smtp", "smtp_mail_from"),
Expand All @@ -150,6 +152,10 @@ def test_notifier_with_nondefault_conf_vars(self, mock_smtphook_hook, create_tas
ti = create_task_instance(dag_id="dag", task_id="op", logical_date=timezone.datetime(2018, 1, 1))
context = {"dag": ti.dag_run.dag, "ti": ti}

mock_smtphook_hook.return_value.from_email = None
mock_smtphook_hook.return_value.subject_template = None
mock_smtphook_hook.return_value.html_content_template = None

with (
tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject,
tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content,
Expand Down Expand Up @@ -184,3 +190,39 @@ def test_notifier_with_nondefault_conf_vars(self, mock_smtphook_hook, create_tas
mime_charset="utf-8",
custom_headers=None,
)

@mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook")
def test_notifier_with_nondefault_connection_extra(self, mock_smtphook_hook, create_task_instance):
ti = create_task_instance(dag_id="dag", task_id="op", logical_date=timezone.datetime(2018, 1, 1))
context = {"dag": ti.dag_run.dag, "ti": ti}

with (
tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject,
tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content,
):
f_subject.write("Task {{ ti.task_id }} failed")
f_subject.flush()

f_content.write("Mock content goes here")
f_content.flush()

mock_smtphook_hook.return_value.subject_template = f_subject.name
mock_smtphook_hook.return_value.html_content_template = f_content.name
notifier = SmtpNotifier(
from_email=conf.get("smtp", "smtp_mail_from"),
to="[email protected]",
)
notifier(context)
mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with(
from_email=conf.get("smtp", "smtp_mail_from"),
to="[email protected]",
subject="Task op failed",
html_content="Mock content goes here",
smtp_conn_id="smtp_default",
files=None,
cc=None,
bcc=None,
mime_subtype="mixed",
mime_charset="utf-8",
custom_headers=None,
)
56 changes: 39 additions & 17 deletions providers/tests/smtp/operators/test_smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
# under the License.
from __future__ import annotations

import base64
import json
from unittest.mock import patch
import tempfile
from unittest.mock import MagicMock, patch

from airflow.models import Connection
from airflow.providers.smtp.operators.smtp import EmailOperator
Expand All @@ -37,22 +39,42 @@ def test_loading_sender_email_from_connection(self, mock_smtplib, mock_hook_conn
custom_retry_limit = 10
custom_timeout = 60
sender_email = "sender_email"
mock_hook_conn.return_value = Connection(
conn_id="mock_conn",
conn_type="smtp",
host="smtp_server_address",
login="smtp_user",
password="smtp_password",
port=465,
extra=json.dumps(
dict(from_email=sender_email, timeout=custom_timeout, retry_limit=custom_retry_limit)
),
)
smtp_client_mock = mock_smtplib.SMTP_SSL()
op = EmailOperator(task_id="test_email", **self.default_op_kwargs)
op.execute({})
call_args = smtp_client_mock.sendmail.call_args.kwargs
assert call_args["from_addr"] == sender_email
with (
tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject,
tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content,
):
f_subject.write("Task {{ ti.task_id }} failed")
f_subject.flush()

f_content.write("Mock content goes here")
f_content.flush()

mock_hook_conn.return_value = Connection(
conn_id="mock_conn",
conn_type="smtp",
host="smtp_server_address",
login="smtp_user",
password="smtp_password",
port=465,
extra=json.dumps(
dict(
from_email=sender_email,
timeout=custom_timeout,
retry_limit=custom_retry_limit,
subject_template=f_subject.name,
html_content_template=f_content.name,
)
),
)
smtp_client_mock = mock_smtplib.SMTP_SSL()
op = EmailOperator(task_id="test_email", to="to")
op.execute({"ti": MagicMock(task_id="some_id")})
call_args = smtp_client_mock.sendmail.call_args.kwargs
assert call_args["from_addr"] == sender_email
assert "Subject: Task some_id failed" in call_args["msg"]
assert (
base64.b64encode("Mock content goes here".encode("ascii")).decode("ascii") in call_args["msg"]
)

def test_assert_templated_fields(self):
"""Test expected templated fields."""
Expand Down
Loading