diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index e2eefdd5fcfdc..06b3a9404065f 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "httpx>=0.27.0", "jinja2>=3.1.4", "methodtools>=0.4.7", - "msgspec>=0.18.6", + "msgspec>=0.19.0", "psutil>=6.1.0", "structlog>=24.4.0", "retryhttp>=1.2.0", diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py index 3290d0ff7f3e7..fa5b113588bf5 100644 --- a/task_sdk/src/airflow/sdk/log.py +++ b/task_sdk/src/airflow/sdk/log.py @@ -196,13 +196,11 @@ def logging_processors( else: exc_group_processor = None - encoder = msgspec.json.Encoder() - def json_dumps(msg, default): - return encoder.encode(msg) + return msgspec.json.encode(msg, enc_hook=default) def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str: - return encoder.encode(event_dict).decode("utf-8") + return msgspec.json.encode(event_dict).decode("utf-8") json = structlog.processors.JSONRenderer(serializer=json_dumps) diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index 3fc7fc18015c7..e24f6e397d3e5 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -62,18 +62,19 @@ def pytest_runtest_setup(item): class LogCapture: # Like structlog.typing.LogCapture, but that doesn't add log_level in to the event dict - entries: list[EventDict] + entries: list[EventDict | bytes] def __init__(self) -> None: self.entries = [] - def __call__(self, _: WrappedLogger, method_name: str, event_dict: EventDict) -> NoReturn: + def __call__(self, _: WrappedLogger, method_name: str, event: EventDict | bytes) -> NoReturn: from structlog.exceptions import DropEvent - if "level" not in event_dict: - event_dict["_log_level"] = method_name + if isinstance(event, dict): + if "level" not in event: + event["_log_level"] = method_name - self.entries.append(event_dict) + self.entries.append(event) raise DropEvent @@ -93,20 +94,29 @@ def captured_logs(request): reset_logging() configure_logging(enable_pretty_log=False) - # Get log level from test parameter, defaulting to INFO if not provided - log_level = getattr(request, "param", logging.INFO) + # Get log level from test parameter, which can either be a single log level or a + # tuple of log level and desired output type, defaulting to INFO if not provided + log_level = logging.INFO + output = "dict" + param = getattr(request, "param", logging.INFO) + if isinstance(param, int): + log_level = param + elif isinstance(param, tuple): + log_level = param[0] + output = param[1] # We want to capture all logs, but we don't want to see them in the test output structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(log_level)) - # But we need to replace remove the last processor (the one that turns JSON into text, as we want the - # event dict for tests) cur_processors = structlog.get_config()["processors"] processors = cur_processors.copy() - proc = processors.pop() - assert isinstance( - proc, (structlog.dev.ConsoleRenderer, structlog.processors.JSONRenderer) - ), "Pre-condition" + if output == "dict": + # We need to replace remove the last processor (the one that turns JSON into text, as we want the + # event dict for tests) + proc = processors.pop() + assert isinstance( + proc, (structlog.dev.ConsoleRenderer, structlog.processors.JSONRenderer) + ), "Pre-condition" try: cap = LogCapture() processors.append(cap) diff --git a/task_sdk/tests/test_log.py b/task_sdk/tests/test_log.py new file mode 100644 index 0000000000000..bf00f33e9a7ec --- /dev/null +++ b/task_sdk/tests/test_log.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import logging +import unittest.mock + +import pytest +import structlog +from uuid6 import UUID + +from airflow.sdk.api.datamodels._generated import TaskInstance + + +@pytest.mark.parametrize( + "captured_logs", [(logging.INFO, "json")], indirect=True, ids=["log_level=info,formatter=json"] +) +def test_json_rendering(captured_logs): + """ + Test that the JSON formatter renders correctly. + """ + logger = structlog.get_logger() + logger.info( + "A test message with a Pydantic class", + pydantic_class=TaskInstance( + id=UUID("ffec3c8e-2898-46f8-b7d5-3cc571577368"), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + ), + ) + assert captured_logs + assert isinstance(captured_logs[0], bytes) + assert json.loads(captured_logs[0]) == { + "event": "A test message with a Pydantic class", + "pydantic_class": "TaskInstance(id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), task_id='test_task', dag_id='test_dag', run_id='test_run', try_number=1, map_index=-1, hostname=None)", + "timestamp": unittest.mock.ANY, + "level": "info", + }