openlineage: isolate metadata extraction by executing OL methods in separate, forked process #40078

Merged: 1 commit, Jun 14, 2024
4 changes: 4 additions & 0 deletions airflow/providers/google/cloud/openlineage/utils.py
@@ -158,9 +158,13 @@ def get_from_nullable_chain(source: Any, chain: list[str]) -> Any | None:
        if not result:
            return None
    """
    # chain.pop modifies passed list, this can be unexpected
    chain = chain.copy()
    chain.reverse()
    try:
        while chain:
            while isinstance(source, list) and len(source) == 1:
                source = source[0]
            next_key = chain.pop()
            if isinstance(source, dict):
                source = source.get(next_key)
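For context (not part of the diff), a minimal standalone sketch of the behaviour the added chain.copy() guards against: popping keys directly from the caller's list would leave that list empty after the first lookup, which is surprising when the same chain is reused.

def get_nested(source, chain):
    chain = chain.copy()  # without the copy, the caller's list is drained by pop()
    chain.reverse()
    while chain:
        key = chain.pop()
        if isinstance(source, dict):
            source = source.get(key)
        else:
            return None
    return source

payload = {"run": {"facets": {"parent": "run_1"}}}
keys = ["run", "facets", "parent"]
print(get_nested(payload, keys))  # run_1
print(keys)  # still ["run", "facets", "parent"] thanks to the copy
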
17 changes: 16 additions & 1 deletion airflow/providers/openlineage/conf.py
@@ -33,7 +33,15 @@
import os
from typing import Any

from airflow.compat.functools import cache
# Disable caching if we're inside tests - this makes config easier to mock.
if os.getenv("PYTEST_VERSION"):

    def decorator(func):
        return func

    cache = decorator
else:
    from airflow.compat.functools import cache
from airflow.configuration import conf

_CONFIG_SECTION = "openlineage"
@@ -130,3 +138,10 @@ def dag_state_change_process_pool_size() -> int:
"""[openlineage] dag_state_change_process_pool_size."""
option = conf.get(_CONFIG_SECTION, "dag_state_change_process_pool_size", fallback="")
return _safe_int_convert(str(option).strip(), default=1)


@cache
def execution_timeout() -> int:
"""[openlineage] execution_timeout."""
option = conf.get(_CONFIG_SECTION, "execution_timeout", fallback="")
return _safe_int_convert(str(option).strip(), default=10)
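
The block above swaps the cached config getters for a pass-through decorator whenever PYTEST_VERSION is set, so mocked config values are re-read on every call instead of requiring cache_clear() in each test (which is why the cache-clearing fixture is removed from test_adapter.py further down). A rough standalone sketch of the same idea, with illustrative names and the standard AIRFLOW__SECTION__KEY environment override standing in for conf.get:

import os
from functools import cache as _real_cache

if os.getenv("PYTEST_VERSION"):
    def cache(func):  # no-op decorator: every call re-reads the (possibly mocked) config
        return func
else:
    cache = _real_cache

@cache
def execution_timeout() -> int:
    # Mirrors the provider's getter: fall back to 10 seconds on missing or invalid values.
    raw = os.getenv("AIRFLOW__OPENLINEAGE__EXECUTION_TIMEOUT", "")
    try:
        return int(raw.strip())
    except ValueError:
        return 10
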
51 changes: 48 additions & 3 deletions airflow/providers/openlineage/plugins/listener.py
@@ -17,12 +17,15 @@
from __future__ import annotations

import logging
import os
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from typing import TYPE_CHECKING

import psutil
from openlineage.client.serde import Serde
from packaging.version import Version
from setproctitle import getproctitle, setproctitle

from airflow import __version__ as AIRFLOW_VERSION, settings
from airflow.listeners import hookimpl
@@ -38,6 +41,7 @@
    is_selective_lineage_enabled,
    print_warning,
)
from airflow.settings import configure_orm
from airflow.stats import Stats
from airflow.utils.timeout import timeout

@@ -156,7 +160,7 @@ def on_running():
                len(Serde.to_json(redacted_event).encode("utf-8")),
            )

        on_running()
        self._execute(on_running, "on_running", use_fork=True)

    @hookimpl
    def on_task_instance_success(
@@ -223,7 +227,7 @@ def on_success():
                len(Serde.to_json(redacted_event).encode("utf-8")),
            )

        on_success()
        self._execute(on_success, "on_success", use_fork=True)

    if _IS_AIRFLOW_2_10_OR_HIGHER:

@@ -318,10 +322,51 @@ def on_failure():
                len(Serde.to_json(redacted_event).encode("utf-8")),
            )

        on_failure()
        self._execute(on_failure, "on_failure", use_fork=True)

    def _execute(self, callable, callable_name: str, use_fork: bool = False):
        if use_fork:
            self._fork_execute(callable, callable_name)
        else:
            callable()

    def _terminate_with_wait(self, process: psutil.Process):
        process.terminate()
        try:
            # Waiting for max 3 seconds to make sure process can clean up before being killed.
            process.wait(timeout=3)
        except psutil.TimeoutExpired:
            # If it's not dead by then, then force kill.
            process.kill()

    def _fork_execute(self, callable, callable_name: str):
        self.log.debug("Will fork to execute OpenLineage process.")
        pid = os.fork()
        if pid:
            process = psutil.Process(pid)
            try:
                self.log.debug("Waiting for process %s", pid)
                process.wait(conf.execution_timeout())
            except psutil.TimeoutExpired:
                self.log.warning(
                    "OpenLineage process %s expired. This should not affect process execution.", pid
                )
                self._terminate_with_wait(process)
            except BaseException:
                # Kill the process directly.
                self._terminate_with_wait(process)
            self.log.warning("Process with pid %s finished - parent", pid)
        else:
            setproctitle(getproctitle() + " - OpenLineage - " + callable_name)
            configure_orm(disable_connection_pool=True)
            self.log.debug("Executing OpenLineage process - %s - pid %s", callable_name, os.getpid())
            callable()
            self.log.debug("Process with current pid finishes after %s", callable_name)
            os._exit(0)

    @property
    def executor(self) -> ProcessPoolExecutor:
        # Executor for dag_run listener
        def initializer():
            # Re-configure the ORM engine as there are issues with multiple processes
            # if process calls Airflow DB.
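The listener change is the core of this PR: each OpenLineage hook now runs in a forked child process, and the parent waits at most [openlineage] execution_timeout seconds before terminating it, so slow or hanging metadata extraction no longer blocks the task's main process. A simplified, self-contained sketch of the same fork-and-wait pattern (POSIX only, psutil required; not the provider code itself):

import os

import psutil


def run_in_fork(fn, timeout_s: int = 10) -> None:
    pid = os.fork()
    if pid:
        # Parent: wait up to timeout_s, then terminate the child, escalating to SIGKILL.
        child = psutil.Process(pid)
        try:
            child.wait(timeout_s)
        except psutil.TimeoutExpired:
            child.terminate()
            try:
                child.wait(timeout=3)  # grace period for cleanup
            except psutil.TimeoutExpired:
                child.kill()
    else:
        # Child: do the work and exit immediately, skipping parent cleanup handlers.
        try:
            fn()
        finally:
            os._exit(0)


run_in_fork(lambda: print("extracting lineage in pid", os.getpid()), timeout_s=5)
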
11 changes: 9 additions & 2 deletions airflow/providers/openlineage/provider.yaml
@@ -45,8 +45,8 @@ dependencies:
- apache-airflow>=2.7.0
- apache-airflow-providers-common-sql>=1.6.0
- attrs>=22.2
- openlineage-integration-common>=1.15.0
- openlineage-python>=1.15.0
- openlineage-integration-common>=1.16.0
- openlineage-python>=1.16.0

integrations:
- integration-name: OpenLineage
@@ -144,3 +144,10 @@ config:
        example: ~
        type: integer
        version_added: 1.8.0
      execution_timeout:
        description: |
          Maximum amount of time (in seconds) that OpenLineage can spend executing metadata extraction.
        default: "10"
        example: ~
        type: integer
        version_added: 1.9.0
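
The new option lives in the [openlineage] section, so (assuming standard Airflow configuration handling) it can be set in airflow.cfg or through the usual environment-variable override; for example, to allow 30 seconds for extraction:

[openlineage]
execution_timeout = 30

# or equivalently, as an environment variable
AIRFLOW__OPENLINEAGE__EXECUTION_TIMEOUT=30
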
4 changes: 2 additions & 2 deletions generated/provider_dependencies.json
@@ -913,8 +913,8 @@
"apache-airflow-providers-common-sql>=1.6.0",
"apache-airflow>=2.7.0",
"attrs>=22.2",
"openlineage-integration-common>=1.15.0",
"openlineage-python>=1.15.0"
"openlineage-integration-common>=1.16.0",
"openlineage-python>=1.16.0"
],
"devel-deps": [],
"plugins": [
60 changes: 60 additions & 0 deletions tests/dags/test_openlineage_execution.py
@@ -0,0 +1,60 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import time

from openlineage.client.generated.base import Dataset

from airflow.models.dag import DAG
from airflow.models.operator import BaseOperator
from airflow.providers.openlineage.extractors import OperatorLineage


class OpenLineageExecutionOperator(BaseOperator):
    def __init__(self, *, stall_amount=0, **kwargs) -> None:
        super().__init__(**kwargs)
        self.stall_amount = stall_amount

    def execute(self, context):
        self.log.error("STALL AMOUNT %s", self.stall_amount)
        time.sleep(1)

    def get_openlineage_facets_on_start(self):
        return OperatorLineage(inputs=[Dataset(namespace="test", name="on-start")])

    def get_openlineage_facets_on_complete(self, task_instance):
        self.log.error("STALL AMOUNT %s", self.stall_amount)
        time.sleep(self.stall_amount)
        return OperatorLineage(inputs=[Dataset(namespace="test", name="on-complete")])


with DAG(
    dag_id="test_openlineage_execution",
    default_args={"owner": "airflow", "retries": 3, "start_date": datetime.datetime(2022, 1, 1)},
    schedule="0 0 * * *",
    dagrun_timeout=datetime.timedelta(minutes=60),
):
    no_stall = OpenLineageExecutionOperator(task_id="execute_no_stall")

    short_stall = OpenLineageExecutionOperator(task_id="execute_short_stall", stall_amount=5)

    mid_stall = OpenLineageExecutionOperator(task_id="execute_mid_stall", stall_amount=15)

    long_stall = OpenLineageExecutionOperator(task_id="execute_long_stall", stall_amount=30)
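
With the default execution_timeout of 10 seconds, this DAG presumably exercises both sides of the new limit: mid_stall (15 s) and long_stall (30 s) stall their on-complete extraction past the timeout so the forked OpenLineage process gets terminated, while no_stall and short_stall (5 s) should finish within it.
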
42 changes: 0 additions & 42 deletions tests/providers/openlineage/plugins/test_adapter.py
@@ -44,13 +44,7 @@
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.providers.openlineage.conf import (
    config_path,
    custom_extractors,
    disabled_operators,
    is_disabled,
    is_source_enabled,
    namespace,
    transport,
)
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.plugins.adapter import _PRODUCER, OpenLineageAdapter
@@ -64,27 +58,6 @@
pytestmark = pytest.mark.db_test


@pytest.fixture(autouse=True)
def clear_cache():
    config_path.cache_clear()
    is_source_enabled.cache_clear()
    disabled_operators.cache_clear()
    custom_extractors.cache_clear()
    namespace.cache_clear()
    transport.cache_clear()
    is_disabled.cache_clear()
    try:
        yield
    finally:
        config_path.cache_clear()
        is_source_enabled.cache_clear()
        disabled_operators.cache_clear()
        custom_extractors.cache_clear()
        namespace.cache_clear()
        transport.cache_clear()
        is_disabled.cache_clear()


@patch.dict(
    os.environ,
    {"OPENLINEAGE_URL": "http://ol-api:5000", "OPENLINEAGE_API_KEY": "api-key"},
@@ -155,9 +128,6 @@ def test_create_client_overrides_env_vars():
assert client.transport.kind == "http"
assert client.transport.url == "http://localhost:5050"

transport.cache_clear()
config_path.cache_clear()

with conf_vars({("openlineage", "transport"): '{"type": "console"}'}):
client = OpenLineageAdapter().get_or_create_openlineage_client()

@@ -893,9 +863,6 @@ def test_configuration_precedence_when_creating_ol_client():
        assert client.transport.config.endpoint == "api/v1/lineage"
        assert client.transport.config.auth.api_key == "random_token"

    config_path.cache_clear()
    transport.cache_clear()

    # Second, check transport in Airflow configuration (airflow.cfg or env variable)
    with patch.dict(
        os.environ,
@@ -917,9 +884,6 @@
        assert client.transport.kafka_config.topic == "test"
        assert client.transport.kafka_config.config == {"acks": "all"}

    config_path.cache_clear()
    transport.cache_clear()

    # Third, check legacy OPENLINEAGE_CONFIG env variable
    with patch.dict(
        os.environ,
@@ -942,9 +906,6 @@
        assert client.transport.config.endpoint == "api/v1/lineage"
        assert client.transport.config.auth.api_key == "random_token"

    config_path.cache_clear()
    transport.cache_clear()

    # Fourth, check legacy OPENLINEAGE_URL env variable
    with patch.dict(
        os.environ,
@@ -967,9 +928,6 @@
        assert client.transport.config.endpoint == "api/v1/lineage"
        assert client.transport.config.auth.api_key == "test_api_key"

    config_path.cache_clear()
    transport.cache_clear()

    # If all else fails, use console transport
    with patch.dict(os.environ, {}, clear=True):
        with conf_vars(