diff --git a/.github/boring-cyborg.yml b/.github/boring-cyborg.yml index 045a49b4f6e77..cf37c187c6ba9 100644 --- a/.github/boring-cyborg.yml +++ b/.github/boring-cyborg.yml @@ -43,9 +43,7 @@ labelPRBasedOnFilePath: - providers/apache/druid/** provider:apache-flink: - - providers/src/airflow/providers/apache/flink/**/* - - docs/apache-airflow-providers-apache-flink/**/* - - providers/tests/apache/flink/**/* + - providers/apache/flink/** provider:apache-hdfs: - providers/src/airflow/providers/apache/hdfs/**/* @@ -59,9 +57,7 @@ labelPRBasedOnFilePath: - providers/apache/iceberg/** provider:apache-impala: - - providers/src/airflow/providers/apache/impala/**/* - - docs/apache-airflow-providers-apache-impala/**/* - - providers/tests/apache/impala/**/* + - providers/apache/impala/** provider:apache-kafka: - providers/apache/kafka/** @@ -99,19 +95,10 @@ labelPRBasedOnFilePath: - providers/celery/** provider:cloudant: - - providers/src/airflow/providers/cloudant/**/* - - docs/apache-airflow-providers-cloudant/**/* - - providers/tests/cloudant/**/* + - providers/cloudant/** provider:cncf-kubernetes: - - airflow/example_dags/example_kubernetes_executor.py - - airflow/example_dags/example_local_kubernetes_executor.py - - providers/src/airflow/providers/cncf/kubernetes/**/* - - providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py - - docs/apache-airflow-providers-cncf-kubernetes/**/* - - kubernetes_tests/**/* - - providers/tests/cncf/kubernetes/**/* - - providers/tests/system/cncf/kubernetes/**/* + - providers/cncf/kubernetes/** provider:cohere: - providers/cohere/** @@ -207,10 +194,7 @@ labelPRBasedOnFilePath: - providers/microsoft/psrp/** provider:microsoft-winrm: - - providers/src/airflow/providers/microsoft/winrm/**/* - - docs/apache-airflow-providers-microsoft-winrm/**/* - - providers/tests/microsoft/winrm/**/* - - providers/tests/system/microsoft/winrm/**/* + - providers/microsoft/winrm/** provider:mongo: - providers/mongo/** @@ -320,10 +304,7 @@ labelPRBasedOnFilePath: - providers/weaviate/** provider:yandex: - - providers/src/airflow/providers/yandex/**/* - - docs/apache-airflow-providers-yandex/**/* - - providers/tests/yandex/**/* - - providers/tests/system/yandex/**/* + - providers/yandex/** provider:ydb: - providers/ydb/** diff --git a/.github/workflows/release_dockerhub_image.yml b/.github/workflows/release_dockerhub_image.yml index b8758146cc1b1..7b9adcf60dba5 100644 --- a/.github/workflows/release_dockerhub_image.yml +++ b/.github/workflows/release_dockerhub_image.yml @@ -177,7 +177,7 @@ jobs: ${SKIP_LATEST} ${LIMIT_PLATFORM} --limit-python ${PYTHON_VERSION} - --chicken-egg-providers ${CHICKEN_EGG_PROVIDERS} + --chicken-egg-providers "${CHICKEN_EGG_PROVIDERS}" - name: > Release slim images: ${{ github.event.inputs.airflowVersion }}, ${{ matrix.python-version }} env: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e59a6d1c1f451..e58166f9d383d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -641,11 +641,11 @@ repos: ^providers/google/src/airflow/providers/google/cloud/operators/dataproc.py$| ^providers/google/src/airflow/providers/google/cloud/operators/mlengine.py$| ^providers/src/airflow/providers/microsoft/azure/hooks/cosmos.py$| - ^providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py$| + ^providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py$| ^airflow/www/fab_security/manager.py$| ^docs/.*commits.rst$| ^docs/apache-airflow-providers-apache-cassandra/connections/cassandra.rst$| 
- ^providers/src/airflow/providers/microsoft/winrm/operators/winrm.py$| + ^providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py$| ^providers/opsgenie/src/airflow/providers/opsgenie/hooks/opsgenie.py$| ^providers/redis/src/airflow/providers/redis/provider.yaml$| ^airflow/serialization/serialized_objects.py$| @@ -665,7 +665,7 @@ repos: ^providers/fab/docs/auth-manager/webserver-authentication.rst$| ^providers/google/docs/operators/cloud/kubernetes_engine.rst$| ^docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst$| - ^docs/apache-airflow-providers-cncf-kubernetes/operators.rst$| + ^providers/cncf/kubernetes/docs/operators.rst$| ^docs/conf.py$| ^docs/exts/removemarktransform.py$| ^newsfragments/41761.significant.rst$| @@ -1223,12 +1223,12 @@ repos: ^airflow/serialization/serde.py$ | ^airflow/utils/file.py$ | ^airflow/utils/helpers.py$ | - ^airflow/utils/log/secrets_masker.py$ | ^providers/ | ^tests/ | ^providers/tests/ | ^providers/.*/tests/ | ^task_sdk/src/airflow/sdk/definitions/dag.py$ | + ^task_sdk/src/airflow/sdk/execution_time/secrets_masker.py$ | ^task_sdk/src/airflow/sdk/definitions/_internal/node.py$ | ^dev/.*\.py$ | ^scripts/.*\.py$ | diff --git a/Dockerfile.ci b/Dockerfile.ci index 27ec994edd5bf..baf31cab0314e 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -930,17 +930,12 @@ function check_boto_upgrade() { echo echo "${COLOR_BLUE}Upgrading boto3, botocore to latest version to run Amazon tests with them${COLOR_RESET}" echo - # shellcheck disable=SC2086 - ${PACKAGING_TOOL_CMD} uninstall ${EXTRA_UNINSTALL_FLAGS} aiobotocore s3fs yandexcloud opensearch-py || true - # We need to include few dependencies to pass pip check with other dependencies: - # * oss2 as dependency as otherwise jmespath will be bumped (sync with alibaba provider) - # * cryptography is kept for snowflake-connector-python limitation (sync with snowflake provider) set -x # shellcheck disable=SC2086 - ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade boto3 botocore \ - "oss2>=2.14.0" "cryptography<43.0.0" "opensearch-py" + ${PACKAGING_TOOL_CMD} uninstall ${EXTRA_UNINSTALL_FLAGS} aiobotocore s3fs || true + # shellcheck disable=SC2086 + ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade boto3 botocore set +x - pip check } function check_downgrade_sqlalchemy() { diff --git a/airflow/api_connexion/schemas/connection_schema.py b/airflow/api_connexion/schemas/connection_schema.py index 4288ce079c554..40085469c2338 100644 --- a/airflow/api_connexion/schemas/connection_schema.py +++ b/airflow/api_connexion/schemas/connection_schema.py @@ -53,7 +53,7 @@ class ConnectionSchema(ConnectionCollectionItemSchema): def serialize_extra(obj: Connection): if obj.extra is None: return - from airflow.utils.log.secrets_masker import redact + from airflow.sdk.execution_time.secrets_masker import redact try: extra = json.loads(obj.extra) diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py index ff74deb2fee4a..4323fcd5017ca 100644 --- a/airflow/api_fastapi/app.py +++ b/airflow/api_fastapi/app.py @@ -79,6 +79,7 @@ def create_app(apps: str = "all") -> FastAPI: if "execution" in apps_list or "all" in apps_list: task_exec_api_app = create_task_execution_api_app(app) + init_error_handlers(task_exec_api_app) app.mount("/execution", task_exec_api_app) init_config(app) diff --git a/airflow/api_fastapi/core_api/datamodels/assets.py b/airflow/api_fastapi/core_api/datamodels/assets.py index fd8f7cef2d415..e25713f02c1c9 100644 --- 
a/airflow/api_fastapi/core_api/datamodels/assets.py +++ b/airflow/api_fastapi/core_api/datamodels/assets.py @@ -22,7 +22,7 @@ from pydantic import Field, field_validator from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel -from airflow.utils.log.secrets_masker import redact +from airflow.sdk.execution_time.secrets_masker import redact class DagScheduleAssetReference(StrictBaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow/api_fastapi/core_api/datamodels/connections.py index 4650e1354dc1e..19a50eac79bd7 100644 --- a/airflow/api_fastapi/core_api/datamodels/connections.py +++ b/airflow/api_fastapi/core_api/datamodels/connections.py @@ -23,7 +23,7 @@ from pydantic_core.core_schema import ValidationInfo from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel -from airflow.utils.log.secrets_masker import redact +from airflow.sdk.execution_time.secrets_masker import redact # Response Models diff --git a/airflow/api_fastapi/core_api/datamodels/dag_tags.py b/airflow/api_fastapi/core_api/datamodels/dag_tags.py index 8d5014fdf3488..5b712c086009a 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_tags.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_tags.py @@ -17,16 +17,12 @@ from __future__ import annotations -from pydantic import ConfigDict - from airflow.api_fastapi.core_api.base import BaseModel class DagTagResponse(BaseModel): """DAG Tag serializer for responses.""" - model_config = ConfigDict(populate_by_name=True, from_attributes=True) - name: str dag_id: str diff --git a/airflow/api_fastapi/core_api/datamodels/event_logs.py b/airflow/api_fastapi/core_api/datamodels/event_logs.py index 8ea88f363e947..26a1364a2db98 100644 --- a/airflow/api_fastapi/core_api/datamodels/event_logs.py +++ b/airflow/api_fastapi/core_api/datamodels/event_logs.py @@ -19,7 +19,7 @@ from datetime import datetime -from pydantic import ConfigDict, Field +from pydantic import Field from airflow.api_fastapi.core_api.base import BaseModel @@ -27,8 +27,6 @@ class EventLogResponse(BaseModel): """Event Log Response.""" - model_config = ConfigDict(populate_by_name=True, from_attributes=True) - id: int = Field(alias="event_log_id") dttm: datetime = Field(alias="when") dag_id: str | None diff --git a/airflow/api_fastapi/core_api/datamodels/import_error.py b/airflow/api_fastapi/core_api/datamodels/import_error.py index baf1ffa4fb7f1..ccb72b95d6507 100644 --- a/airflow/api_fastapi/core_api/datamodels/import_error.py +++ b/airflow/api_fastapi/core_api/datamodels/import_error.py @@ -18,7 +18,7 @@ from datetime import datetime -from pydantic import ConfigDict, Field +from pydantic import Field from airflow.api_fastapi.core_api.base import BaseModel @@ -26,8 +26,6 @@ class ImportErrorResponse(BaseModel): """Import Error Response.""" - model_config = ConfigDict(populate_by_name=True, from_attributes=True) - id: int = Field(alias="import_error_id") timestamp: datetime filename: str diff --git a/airflow/api_fastapi/core_api/datamodels/pools.py b/airflow/api_fastapi/core_api/datamodels/pools.py index 096e357dfaf1d..2e7ae13cfcdb2 100644 --- a/airflow/api_fastapi/core_api/datamodels/pools.py +++ b/airflow/api_fastapi/core_api/datamodels/pools.py @@ -19,7 +19,7 @@ from typing import Annotated, Callable -from pydantic import BeforeValidator, ConfigDict, Field +from pydantic import BeforeValidator, Field from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel @@ -63,8 +63,6 @@ class PoolCollectionResponse(BaseModel): class 
PoolPatchBody(StrictBaseModel): """Pool serializer for patch bodies.""" - model_config = ConfigDict(populate_by_name=True, from_attributes=True) - name: str | None = Field(default=None, alias="pool") slots: int | None = None description: str | None = None diff --git a/airflow/api_fastapi/core_api/datamodels/task_instances.py b/airflow/api_fastapi/core_api/datamodels/task_instances.py index 505b9da207c39..d209d46bd5edc 100644 --- a/airflow/api_fastapi/core_api/datamodels/task_instances.py +++ b/airflow/api_fastapi/core_api/datamodels/task_instances.py @@ -23,7 +23,6 @@ AliasPath, AwareDatetime, BeforeValidator, - ConfigDict, Field, NonNegativeInt, StringConstraints, @@ -42,8 +41,6 @@ class TaskInstanceResponse(BaseModel): """TaskInstance serializer for responses.""" - model_config = ConfigDict(populate_by_name=True, from_attributes=True) - id: str task_id: str dag_id: str @@ -126,8 +123,6 @@ class TaskInstancesBatchBody(StrictBaseModel): class TaskInstanceHistoryResponse(BaseModel): """TaskInstanceHistory serializer for responses.""" - model_config = ConfigDict(populate_by_name=True, from_attributes=True) - task_id: str dag_id: str @@ -154,6 +149,7 @@ class TaskInstanceHistoryResponse(BaseModel): pid: int | None executor: str | None executor_config: Annotated[str, BeforeValidator(str)] + dag_version: DagVersionResponse | None class TaskInstanceHistoryCollectionResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/variables.py b/airflow/api_fastapi/core_api/datamodels/variables.py index 82cfdbb130523..2317d8a168b82 100644 --- a/airflow/api_fastapi/core_api/datamodels/variables.py +++ b/airflow/api_fastapi/core_api/datamodels/variables.py @@ -19,19 +19,17 @@ import json -from pydantic import ConfigDict, Field, model_validator +from pydantic import Field, model_validator from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel from airflow.models.base import ID_LEN +from airflow.sdk.execution_time.secrets_masker import redact from airflow.typing_compat import Self -from airflow.utils.log.secrets_masker import redact class VariableResponse(BaseModel): """Variable serializer for responses.""" - model_config = ConfigDict(populate_by_name=True, from_attributes=True) - key: str val: str = Field(alias="value") description: str | None diff --git a/airflow/api_fastapi/core_api/datamodels/xcom.py b/airflow/api_fastapi/core_api/datamodels/xcom.py index f874f8bdeed5a..1acccb702efaf 100644 --- a/airflow/api_fastapi/core_api/datamodels/xcom.py +++ b/airflow/api_fastapi/core_api/datamodels/xcom.py @@ -65,3 +65,10 @@ class XComCreateBody(StrictBaseModel): key: str value: Any map_index: int = -1 + + +class XComUpdateBody(StrictBaseModel): + """Payload serializer for updating an XCom entry.""" + + value: Any + map_index: int = -1 diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index cab7daffbbe3b..486c371682fdb 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -4380,6 +4380,80 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + patch: + tags: + - XCom + summary: Update Xcom Entry + description: Update an existing XCom entry. 
+ operationId: update_xcom_entry + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: xcom_key + in: path + required: true + schema: + type: string + title: Xcom Key + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/XComUpdateBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/XComResponseNative' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries: get: tags: @@ -9794,6 +9868,10 @@ components: executor_config: type: string title: Executor Config + dag_version: + anyOf: + - $ref: '#/components/schemas/DagVersionResponse' + - type: 'null' type: object required: - task_id @@ -9819,6 +9897,7 @@ components: - pid - executor - executor_config + - dag_version title: TaskInstanceHistoryResponse description: TaskInstanceHistory serializer for responses. TaskInstanceResponse: @@ -10742,3 +10821,17 @@ components: - value title: XComResponseString description: XCom response serializer with string return type. + XComUpdateBody: + properties: + value: + title: Value + map_index: + type: integer + title: Map Index + default: -1 + additionalProperties: false + type: object + required: + - value + title: XComUpdateBody + description: Payload serializer for updating an XCom entry. 
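For reviewers who want to exercise the new `update_xcom_entry` operation described above, a minimal client sketch follows. The base URL, identifiers, and auth handling are placeholders (not part of this change), and the path is inferred from the `dag_id`/`dag_run_id`/`task_id`/`xcom_key` parameters in the generated spec.

```python
import requests

# Hypothetical values -- replace with a real DAG id, run id, task id, and XCom key.
BASE_URL = "http://localhost:8080/public"
url = (
    f"{BASE_URL}/dags/example_dag/dagRuns/manual__2025-01-01T00:00:00+00:00"
    f"/taskInstances/extract_task/xcomEntries/my_key"
)

# The payload mirrors the XComUpdateBody schema: "value" is required, "map_index" defaults to -1.
response = requests.patch(url, json={"value": {"rows": 42}, "map_index": -1}, timeout=10)
response.raise_for_status()
print(response.json())  # XComResponseNative payload for the updated entry
```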
diff --git a/airflow/api_fastapi/core_api/routes/public/job.py b/airflow/api_fastapi/core_api/routes/public/job.py index b0c99e3ab32b9..71c3857a5e727 100644 --- a/airflow/api_fastapi/core_api/routes/public/job.py +++ b/airflow/api_fastapi/core_api/routes/public/job.py @@ -37,7 +37,6 @@ from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.job import ( JobCollectionResponse, - JobResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.jobs.job import Job @@ -124,12 +123,6 @@ def get_jobs( jobs = [job for job in jobs if job.is_alive()] return JobCollectionResponse( - jobs=[ - JobResponse.model_validate( - job, - from_attributes=True, - ) - for job in jobs - ], + jobs=jobs, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index d3a67b9804911..927f62028a85d 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -272,11 +272,15 @@ def get_task_instance_tries( """Get list of task instances history.""" def _query(orm_object: Base) -> Select: - query = select(orm_object).where( - orm_object.dag_id == dag_id, - orm_object.run_id == dag_run_id, - orm_object.task_id == task_id, - orm_object.map_index == map_index, + query = ( + select(orm_object) + .where( + orm_object.dag_id == dag_id, + orm_object.run_id == dag_run_id, + orm_object.task_id == task_id, + orm_object.map_index == map_index, + ) + .options(joinedload(orm_object.dag_version)) ) return query @@ -291,7 +295,6 @@ def _query(orm_object: Base) -> Select: status.HTTP_404_NOT_FOUND, f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}` and map_index: `{map_index}` was not found", ) - return TaskInstanceHistoryCollectionResponse( task_instances=cast(list[TaskInstanceHistoryResponse], task_instances), total_entries=len(task_instances), @@ -659,13 +662,7 @@ def post_clear_task_instances( ) return TaskInstanceCollectionResponse( - task_instances=[ - TaskInstanceResponse.model_validate( - ti, - from_attributes=True, - ) - for ti in task_instances - ], + task_instances=task_instances, total_entries=len(task_instances), ) @@ -775,7 +772,6 @@ def patch_task_instance_dry_run( task_instances=[ TaskInstanceResponse.model_validate( ti, - from_attributes=True, ) for ti in tis ], @@ -840,4 +836,4 @@ def patch_task_instance( ti.task_instance_note.user_id = None session.commit() - return TaskInstanceResponse.model_validate(ti, from_attributes=True) + return TaskInstanceResponse.model_validate(ti) diff --git a/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow/api_fastapi/core_api/routes/public/xcom.py index e1ef40685d39f..3da163f3e4033 100644 --- a/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -30,6 +30,7 @@ XComCreateBody, XComResponseNative, XComResponseString, + XComUpdateBody, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.exceptions import TaskNotFound @@ -222,3 +223,48 @@ def create_xcom_entry( ) return XComResponseNative.model_validate(xcom) + + +@xcom_router.patch( + "/{xcom_key}", + status_code=status.HTTP_200_OK, + responses=create_openapi_http_exception_doc( + [ + status.HTTP_400_BAD_REQUEST, + status.HTTP_404_NOT_FOUND, + ] + ), +) +def update_xcom_entry( + 
dag_id: str, + task_id: str, + dag_run_id: str, + xcom_key: str, + patch_body: XComUpdateBody, + session: SessionDep, +) -> XComResponseNative: + """Update an existing XCom entry.""" + # Check if XCom entry exists + xcom_new_value = XCom.serialize_value(patch_body.value) + xcom_entry = session.scalar( + select(XCom) + .where( + XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.run_id == dag_run_id, + XCom.key == xcom_key, + XCom.map_index == patch_body.map_index, + ) + .limit(1) + ) + + if not xcom_entry: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + f"The XCom with key: `{xcom_key}` for the specified task instance was not found.", + ) + + # Update the XCom entry with the already-serialized value + xcom_entry.value = xcom_new_value + + return XComResponseNative.model_validate(xcom_entry) diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py index 61283dc2cf87f..2b85be363f25f 100644 --- a/airflow/api_fastapi/execution_api/app.py +++ b/airflow/api_fastapi/execution_api/app.py @@ -77,8 +77,14 @@ def custom_openapi() -> dict: def get_extra_schemas() -> dict[str, dict]: """Get all the extra schemas that are not part of the main FastAPI app.""" - from airflow.api_fastapi.execution_api.datamodels import taskinstance + from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance + from airflow.executors.workloads import BundleInfo + from airflow.utils.state import TerminalTIState return { - "TaskInstance": taskinstance.TaskInstance.model_json_schema(), + "TaskInstance": TaskInstance.model_json_schema(), + "BundleInfo": BundleInfo.model_json_schema(), + # Include the combined state enum too. In the datamodels we separate out SUCCESS from the other states + # as that has different payload requirements + "TerminalTIState": {"type": "string", "enum": list(TerminalTIState)}, } diff --git a/airflow/api_fastapi/execution_api/datamodels/asset.py b/airflow/api_fastapi/execution_api/datamodels/asset.py index 29b260c291c2b..28d352aa23101 100644 --- a/airflow/api_fastapi/execution_api/datamodels/asset.py +++ b/airflow/api_fastapi/execution_api/datamodels/asset.py @@ -17,7 +17,7 @@ from __future__ import annotations -from airflow.api_fastapi.core_api.base import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel class AssetResponse(BaseModel): @@ -36,7 +36,7 @@ class AssetAliasResponse(BaseModel): group: str -class AssetProfile(BaseModel): +class AssetProfile(StrictBaseModel): """ Profile of an Asset. diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index c3ec79c9ddd11..0c8e5eb1b6a3f 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License.
- from __future__ import annotations import uuid from datetime import timedelta +from enum import Enum from typing import Annotated, Any, Literal, Union from pydantic import ( @@ -60,15 +60,20 @@ class TIEnterRunningPayload(StrictBaseModel): """When the task started executing""" +# Create an enum to give a nice name in the generated datamodels +class TerminalStateNonSuccess(str, Enum): + """TaskInstance states that can be reported without extra information.""" + + FAILED = TerminalTIState.FAILED + SKIPPED = TerminalTIState.SKIPPED + REMOVED = TerminalTIState.REMOVED + FAIL_WITHOUT_RETRY = TerminalTIState.FAIL_WITHOUT_RETRY + + class TITerminalStatePayload(StrictBaseModel): """Schema for updating TaskInstance to a terminal state except SUCCESS state.""" - state: Literal[ - TerminalTIState.FAILED, - TerminalTIState.SKIPPED, - TerminalTIState.REMOVED, - TerminalTIState.FAIL_WITHOUT_RETRY, - ] + state: TerminalStateNonSuccess end_date: UtcDateTime """When the task completed executing""" @@ -216,7 +221,7 @@ class DagRun(StrictBaseModel): dag_id: str run_id: str - logical_date: UtcDateTime + logical_date: UtcDateTime | None data_interval_start: UtcDateTime | None data_interval_end: UtcDateTime | None run_after: UtcDateTime @@ -242,6 +247,8 @@ class TIRunContext(BaseModel): connections: Annotated[list[ConnectionResponse], Field(default_factory=list)] """Connections that can be accessed by the task instance.""" + upstream_map_indexes: dict[str, int] | None = None + class PrevSuccessfulDagRunResponse(BaseModel): """Schema for response with previous successful DagRun information for Task Template Context.""" @@ -252,7 +259,7 @@ class PrevSuccessfulDagRunResponse(BaseModel): end_date: UtcDateTime | None = None -class TIRuntimeCheckPayload(BaseModel): +class TIRuntimeCheckPayload(StrictBaseModel): """Payload for performing Runtime checks on the TaskInstance model as requested by the SDK.""" inlets: list[AssetProfile] | None = None diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 3cc70dac9c5b5..1c605b5cadcf0 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -30,7 +30,6 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( - DagRun, PrevSuccessfulDagRunResponse, TIDeferredStatePayload, TIEnterRunningPayload, @@ -66,6 +65,7 @@ status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, }, + response_model_exclude_unset=True, ) def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep @@ -150,6 +150,7 @@ def ti_run( DR.run_type, DR.conf, DR.logical_date, + DR.external_trigger, ).filter_by(dag_id=dag_id, run_id=run_id) ).one_or_none() @@ -171,7 +172,7 @@ def ti_run( ) return TIRunContext( - dag_run=DagRun.model_validate(dr, from_attributes=True), + dag_run=dr, max_tries=max_tries, # TODO: Add variables and connections that are needed (and has perms) for the task variables=[], diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow/api_fastapi/execution_api/routes/xcoms.py index f330744536b33..9a8cef62e07a0 100644 --- a/airflow/api_fastapi/execution_api/routes/xcoms.py +++ 
b/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -21,7 +21,8 @@ import logging from typing import Annotated -from fastapi import Body, HTTPException, Query, status +from fastapi import Body, Depends, HTTPException, Query, Response, status +from sqlalchemy.sql.selectable import Select from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter @@ -30,6 +31,7 @@ from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse from airflow.models.taskmap import TaskMap from airflow.models.xcom import BaseXCom +from airflow.utils.db import get_query_count # TODO: Add dependency on JWT token router = AirflowRouter( @@ -42,20 +44,15 @@ log = logging.getLogger(__name__) -@router.get( - "/{dag_id}/{run_id}/{task_id}/{key}", - responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}}, -) -def get_xcom( +async def xcom_query( dag_id: str, run_id: str, task_id: str, key: str, - token: deps.TokenDep, session: SessionDep, - map_index: Annotated[int, Query()] = -1, -) -> XComResponse: - """Get an Airflow XCom from database - not other XCom Backends.""" + token: deps.TokenDep, + map_index: Annotated[int | None, Query()] = None, +) -> Select: if not has_xcom_access(dag_id, run_id, task_id, key, token): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -65,29 +62,87 @@ def get_xcom( }, ) - # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. - # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead - # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` - # (which automatically deserializes using the backend), we avoid potential - # performance hits from retrieving large data files into the API server. query = BaseXCom.get_many( run_id=run_id, key=key, task_ids=task_id, dag_ids=dag_id, map_indexes=map_index, - limit=1, session=session, ) + return query.with_entities(BaseXCom.value) + + +@router.head( + "/{dag_id}/{run_id}/{task_id}/{key}", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}, + status.HTTP_200_OK: { + "description": "Metadata about the number of matching XCom values", + "headers": { + "Content-Range": { + "pattern": r"^map_indexes \d+$", + "description": "The number of (mapped) XCom values found for this task.", + }, + }, + }, + }, + description="Return the count of the number of XCom values found via the Content-Range response header", +) +def head_xcom( + response: Response, + token: deps.TokenDep, + session: SessionDep, + xcom_query: Annotated[Select, Depends(xcom_query)], + map_index: Annotated[int | None, Query()] = None, +) -> None: + """Get the count of XComs from database - not other XCom Backends.""" + if map_index is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"}, + ) + + count = get_query_count(xcom_query, session=session) + # Tell the caller how many items in this query. 
We define a custom range unit (HTTP spec only defines + # "bytes" but we can add our own) + response.headers["Content-Range"] = f"map_indexes {count}" + + +@router.get( + "/{dag_id}/{run_id}/{task_id}/{key}", + responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}}, + description="Get a single XCom Value", +) +def get_xcom( + session: SessionDep, + dag_id: str, + run_id: str, + task_id: str, + key: str, + xcom_query: Annotated[Select, Depends(xcom_query)], + map_index: Annotated[int, Query()] = -1, +) -> XComResponse: + """Get an Airflow XCom from database - not other XCom Backends.""" + # The xcom_query allows no map_index to be passed. This endpoint should always return just a single item, + # so we override that query value + + xcom_query = xcom_query.filter_by(map_index=map_index) + + # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. + # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead + # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` + # (which automatically deserializes using the backend), we avoid potential + # performance hits from retrieving large data files into the API server. - result = query.with_entities(BaseXCom.value).first() + result = xcom_query.limit(1).first() if result is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "reason": "not_found", - "message": f"XCom with key '{key}' not found for task '{task_id}' in DAG '{dag_id}'", + "message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}", }, ) diff --git a/airflow/api_fastapi/logging/decorators.py b/airflow/api_fastapi/logging/decorators.py index b7e1002c180a1..0c709839f3152 100644 --- a/airflow/api_fastapi/logging/decorators.py +++ b/airflow/api_fastapi/logging/decorators.py @@ -29,7 +29,7 @@ from airflow.api_fastapi.core_api.security import get_user_with_exception_handling from airflow.auth.managers.models.base_user import BaseUser from airflow.models import Log -from airflow.utils.log import secrets_masker +from airflow.sdk.execution_time import secrets_masker logger = logging.getLogger(__name__) diff --git a/airflow/auth/managers/simple/ui/dev/index.html b/airflow/auth/managers/simple/ui/dev/index.html new file mode 100644 index 0000000000000..b27bdcd6795bf --- /dev/null +++ b/airflow/auth/managers/simple/ui/dev/index.html @@ -0,0 +1,23 @@ + + + + + + + + + + Airflow 3.0 + + +
+ + + diff --git a/airflow/auth/managers/simple/ui/package-lock.json b/airflow/auth/managers/simple/ui/package-lock.json index a0087cc944ce1..4898896322b97 100644 --- a/airflow/auth/managers/simple/ui/package-lock.json +++ b/airflow/auth/managers/simple/ui/package-lock.json @@ -25,7 +25,7 @@ "happy-dom": "^15.10.2", "vite": "^5.4.14", "vite-plugin-css-injected-by-js": "^3.5.2", - "vitest": "^2.1.1" + "vitest": "^2.1.9" } }, "node_modules/@7nohe/openapi-react-query-codegen": { @@ -1830,13 +1830,14 @@ } }, "node_modules/@vitest/expect": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-2.1.8.tgz", - "integrity": "sha512-8ytZ/fFHq2g4PJVAtDX57mayemKgDR6X3Oa2Foro+EygiOJHUXhCqBAAKQYYajZpFoIfvBCF1j6R6IYRSIUFuw==", + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-2.1.9.tgz", + "integrity": "sha512-UJCIkTBenHeKT1TTlKMJWy1laZewsRIzYighyYiJKZreqtdxSos/S1t+ktRMQWu2CKqaarrkeszJx1cgC5tGZw==", "dev": true, + "license": "MIT", "dependencies": { - "@vitest/spy": "2.1.8", - "@vitest/utils": "2.1.8", + "@vitest/spy": "2.1.9", + "@vitest/utils": "2.1.9", "chai": "^5.1.2", "tinyrainbow": "^1.2.0" }, @@ -1845,12 +1846,13 @@ } }, "node_modules/@vitest/mocker": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-2.1.8.tgz", - "integrity": "sha512-7guJ/47I6uqfttp33mgo6ga5Gr1VnL58rcqYKyShoRK9ebu8T5Rs6HN3s1NABiBeVTdWNrwUMcHH54uXZBN4zA==", + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-2.1.9.tgz", + "integrity": "sha512-tVL6uJgoUdi6icpxmdrn5YNo3g3Dxv+IHJBr0GXHaEdTcw3F+cPKnsXFhli6nO+f/6SDKPHEK1UN+k+TQv0Ehg==", "dev": true, + "license": "MIT", "dependencies": { - "@vitest/spy": "2.1.8", + "@vitest/spy": "2.1.9", "estree-walker": "^3.0.3", "magic-string": "^0.30.12" }, @@ -1871,10 +1873,11 @@ } }, "node_modules/@vitest/pretty-format": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-2.1.8.tgz", - "integrity": "sha512-9HiSZ9zpqNLKlbIDRWOnAWqgcA7xu+8YxXSekhr0Ykab7PAYFkhkwoqVArPOtJhPmYeE2YHgKZlj3CP36z2AJQ==", + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-2.1.9.tgz", + "integrity": "sha512-KhRIdGV2U9HOUzxfiHmY8IFHTdqtOhIzCpd8WRdJiE7D/HUcZVD0EgQCVjm+Q9gkUXWgBvMmTtZgIG48wq7sOQ==", "dev": true, + "license": "MIT", "dependencies": { "tinyrainbow": "^1.2.0" }, @@ -1883,12 +1886,13 @@ } }, "node_modules/@vitest/runner": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-2.1.8.tgz", - "integrity": "sha512-17ub8vQstRnRlIU5k50bG+QOMLHRhYPAna5tw8tYbj+jzjcspnwnwtPtiOlkuKC4+ixDPTuLZiqiWWQ2PSXHVg==", + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-2.1.9.tgz", + "integrity": "sha512-ZXSSqTFIrzduD63btIfEyOmNcBmQvgOVsPNPe0jYtESiXkhd8u2erDLnMxmGrDCwHCCHE7hxwRDCT3pt0esT4g==", "dev": true, + "license": "MIT", "dependencies": { - "@vitest/utils": "2.1.8", + "@vitest/utils": "2.1.9", "pathe": "^1.1.2" }, "funding": { @@ -1896,12 +1900,13 @@ } }, "node_modules/@vitest/snapshot": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-2.1.8.tgz", - "integrity": "sha512-20T7xRFbmnkfcmgVEz+z3AU/3b0cEzZOt/zmnvZEctg64/QZbSDJEVm9fLnnlSi74KibmRsO9/Qabi+t0vCRPg==", + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-2.1.9.tgz", + "integrity": 
"sha512-oBO82rEjsxLNJincVhLhaxxZdEtV0EFHMK5Kmx5sJ6H9L183dHECjiefOAdnqpIgT5eZwT04PoggUnW88vOBNQ==", "dev": true, + "license": "MIT", "dependencies": { - "@vitest/pretty-format": "2.1.8", + "@vitest/pretty-format": "2.1.9", "magic-string": "^0.30.12", "pathe": "^1.1.2" }, @@ -1910,10 +1915,11 @@ } }, "node_modules/@vitest/spy": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-2.1.8.tgz", - "integrity": "sha512-5swjf2q95gXeYPevtW0BLk6H8+bPlMb4Vw/9Em4hFxDcaOxS+e0LOX4yqNxoHzMR2akEB2xfpnWUzkZokmgWDg==", + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-2.1.9.tgz", + "integrity": "sha512-E1B35FwzXXTs9FHNK6bDszs7mtydNi5MIfUWpceJ8Xbfb1gBMscAnwLbEu+B44ed6W3XjL9/ehLPHR1fkf1KLQ==", "dev": true, + "license": "MIT", "dependencies": { "tinyspy": "^3.0.2" }, @@ -1922,12 +1928,13 @@ } }, "node_modules/@vitest/utils": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-2.1.8.tgz", - "integrity": "sha512-dwSoui6djdwbfFmIgbIjX2ZhIoG7Ex/+xpxyiEgIGzjliY8xGkcpITKTlp6B4MgtGkF2ilvm97cPM96XZaAgcA==", + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-2.1.9.tgz", + "integrity": "sha512-v0psaMSkNJ3A2NMrUEHFRzJtDPFn+/VWZ5WxImB21T9fjucJRmS7xCS3ppEnARb9y11OAzaD+P2Ps+b+BGX5iQ==", "dev": true, + "license": "MIT", "dependencies": { - "@vitest/pretty-format": "2.1.8", + "@vitest/pretty-format": "2.1.9", "loupe": "^3.1.2", "tinyrainbow": "^1.2.0" }, @@ -2758,6 +2765,7 @@ "resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-2.0.1.tgz", "integrity": "sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==", "dev": true, + "license": "MIT", "engines": { "node": ">=12" } @@ -2865,6 +2873,7 @@ "resolved": "https://registry.npmjs.org/cac/-/cac-6.7.14.tgz", "integrity": "sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==", "dev": true, + "license": "MIT", "engines": { "node": ">=8" } @@ -2894,6 +2903,7 @@ "resolved": "https://registry.npmjs.org/chai/-/chai-5.1.2.tgz", "integrity": "sha512-aGtmf24DW6MLHHG5gCx4zaI3uBq3KRtxeVs0DjFH6Z0rDNbsvTxFASFvdj79pxjxZ8/5u3PIiN3IwEIQkiiuPw==", "dev": true, + "license": "MIT", "dependencies": { "assertion-error": "^2.0.1", "check-error": "^2.1.1", @@ -2926,6 +2936,7 @@ "resolved": "https://registry.npmjs.org/check-error/-/check-error-2.1.1.tgz", "integrity": "sha512-OAlb+T7V4Op9OwdkjmguYRqncdlx5JiofwOAUkmTF+jNdHwzTaTs4sRAGpzLF3oOz5xAyDGrPgeIDFQmDOTiJw==", "dev": true, + "license": "MIT", "engines": { "node": ">= 16" } @@ -3118,6 +3129,7 @@ "resolved": "https://registry.npmjs.org/deep-eql/-/deep-eql-5.0.2.tgz", "integrity": "sha512-h5k/5U50IJJFpzfL6nO9jaaumfjO/f2NjK/oYB2Djzm4p9L+3T9qWpZqZ2hAbLPuuYq9wrU08WQyBTL5GbPk5Q==", "dev": true, + "license": "MIT", "engines": { "node": ">=6" } @@ -3215,7 +3227,8 @@ "version": "1.6.0", "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.6.0.tgz", "integrity": "sha512-qqnD1yMU6tk/jnaMosogGySTZP8YtUgAffA9nMN+E/rjxcfRQ6IEk7IiozUjgxKoFHBGjTLnrHB/YC45r/59EQ==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/esbuild": { "version": "0.21.5", @@ -3408,6 +3421,7 @@ "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", "dev": true, + "license": "MIT", "dependencies": { "@types/estree": "^1.0.0" } @@ -4132,10 +4146,11 @@ } }, "node_modules/loupe": { - 
"version": "3.1.2", - "resolved": "https://registry.npmjs.org/loupe/-/loupe-3.1.2.tgz", - "integrity": "sha512-23I4pFZHmAemUnz8WZXbYRSKYj801VDaNv9ETuMh7IrMc7VuVVSo+Z9iLE3ni30+U48iDWfi30d3twAXBYmnCg==", - "dev": true + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/loupe/-/loupe-3.1.3.tgz", + "integrity": "sha512-kkIp7XSkP78ZxJEsSxW3712C6teJVoeHHwgo9zJ380de7IYyJ2ISlxojcH2pC5OFLewESmnRi/+XCDIEEVyoug==", + "dev": true, + "license": "MIT" }, "node_modules/lru-cache": { "version": "10.4.3", @@ -4159,6 +4174,7 @@ "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.17.tgz", "integrity": "sha512-sNPKHvyjVf7gyjwS4xGTaW/mCnF8wnjtifKBEhxfZ7E/S8tQ0rssrwGNn6q8JH/ohItJfSQp9mBtQYuTlH5QnA==", "dev": true, + "license": "MIT", "dependencies": { "@jridgewell/sourcemap-codec": "^1.5.0" } @@ -4588,6 +4604,7 @@ "resolved": "https://registry.npmjs.org/pathval/-/pathval-2.0.0.tgz", "integrity": "sha512-vE7JKRyES09KiunauX7nd2Q9/L7lhok4smP9RZTDeD4MVs72Dp2qNFVz39Nz5a0FVEW0BJR6C0DYrq6unoziZA==", "dev": true, + "license": "MIT", "engines": { "node": ">= 14.16" } @@ -5264,6 +5281,7 @@ "resolved": "https://registry.npmjs.org/tinyrainbow/-/tinyrainbow-1.2.0.tgz", "integrity": "sha512-weEDEq7Z5eTHPDh4xjX789+fHfF+P8boiFB+0vbWzpbnbsEr/GRaohi/uMKxg8RZMXnl1ItAi/IUHWMsjDV7kQ==", "dev": true, + "license": "MIT", "engines": { "node": ">=14.0.0" } @@ -5273,6 +5291,7 @@ "resolved": "https://registry.npmjs.org/tinyspy/-/tinyspy-3.0.2.tgz", "integrity": "sha512-n1cw8k1k0x4pgA2+9XrOkFydTerNcJ1zWCO5Nn9scWHTD+5tp8dghT2x1uduQePZTZgd3Tupf+x9BxJjeJi77Q==", "dev": true, + "license": "MIT", "engines": { "node": ">=14.0.0" } @@ -5425,10 +5444,11 @@ } }, "node_modules/vite-node": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/vite-node/-/vite-node-2.1.8.tgz", - "integrity": "sha512-uPAwSr57kYjAUux+8E2j0q0Fxpn8M9VoyfGiRI8Kfktz9NcYMCenwY5RnZxnF1WTu3TGiYipirIzacLL3VVGFg==", + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/vite-node/-/vite-node-2.1.9.tgz", + "integrity": "sha512-AM9aQ/IPrW/6ENLQg3AGY4K1N2TGZdR5e4gu/MmmR2xR3Ll1+dib+nook92g4TV3PXVyeyxdWwtaCAiUL0hMxA==", "dev": true, + "license": "MIT", "dependencies": { "cac": "^6.7.14", "debug": "^4.3.7", @@ -5456,18 +5476,19 @@ } }, "node_modules/vitest": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/vitest/-/vitest-2.1.8.tgz", - "integrity": "sha512-1vBKTZskHw/aosXqQUlVWWlGUxSJR8YtiyZDJAFeW2kPAeX6S3Sool0mjspO+kXLuxVWlEDDowBAeqeAQefqLQ==", - "dev": true, - "dependencies": { - "@vitest/expect": "2.1.8", - "@vitest/mocker": "2.1.8", - "@vitest/pretty-format": "^2.1.8", - "@vitest/runner": "2.1.8", - "@vitest/snapshot": "2.1.8", - "@vitest/spy": "2.1.8", - "@vitest/utils": "2.1.8", + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-2.1.9.tgz", + "integrity": "sha512-MSmPM9REYqDGBI8439mA4mWhV5sKmDlBKWIYbA3lRb2PTHACE0mgKwA8yQ2xq9vxDTuk4iPrECBAEW2aoFXY0Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/expect": "2.1.9", + "@vitest/mocker": "2.1.9", + "@vitest/pretty-format": "^2.1.9", + "@vitest/runner": "2.1.9", + "@vitest/snapshot": "2.1.9", + "@vitest/spy": "2.1.9", + "@vitest/utils": "2.1.9", "chai": "^5.1.2", "debug": "^4.3.7", "expect-type": "^1.1.0", @@ -5479,7 +5500,7 @@ "tinypool": "^1.0.1", "tinyrainbow": "^1.2.0", "vite": "^5.0.0", - "vite-node": "2.1.8", + "vite-node": "2.1.9", "why-is-node-running": "^2.3.0" }, "bin": { @@ -5494,8 +5515,8 @@ "peerDependencies": { "@edge-runtime/vm": "*", "@types/node": "^18.0.0 || >=20.0.0", - "@vitest/browser": 
"2.1.8", - "@vitest/ui": "2.1.8", + "@vitest/browser": "2.1.9", + "@vitest/ui": "2.1.9", "happy-dom": "*", "jsdom": "*" }, diff --git a/airflow/auth/managers/simple/ui/package.json b/airflow/auth/managers/simple/ui/package.json index 2001ef7162626..552d25f84a7b1 100644 --- a/airflow/auth/managers/simple/ui/package.json +++ b/airflow/auth/managers/simple/ui/package.json @@ -4,7 +4,7 @@ "version": "0.0.0", "type": "module", "scripts": { - "dev": "vite", + "dev": "vite --port 5174", "build": "vite build", "preview": "vite preview", "codegen": "openapi-rq -i \"../openapi/v1-generated.yaml\" -c axios --format prettier -o openapi-gen --operationId", @@ -29,6 +29,6 @@ "happy-dom": "^15.10.2", "vite": "^5.4.14", "vite-plugin-css-injected-by-js": "^3.5.2", - "vitest": "^2.1.1" + "vitest": "^2.1.9" } } diff --git a/airflow/auth/managers/simple/ui/pnpm-lock.yaml b/airflow/auth/managers/simple/ui/pnpm-lock.yaml index 1676d87998832..dc6fe7218cc16 100644 --- a/airflow/auth/managers/simple/ui/pnpm-lock.yaml +++ b/airflow/auth/managers/simple/ui/pnpm-lock.yaml @@ -55,8 +55,8 @@ importers: specifier: ^3.5.2 version: 3.5.2(vite@5.4.14) vitest: - specifier: ^2.1.1 - version: 2.1.8(happy-dom@15.11.7) + specifier: ^2.1.9 + version: 2.1.9(happy-dom@15.11.7) packages: @@ -103,8 +103,8 @@ packages: resolution: {integrity: sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==} engines: {node: '>=6.9.0'} - '@babel/parser@7.26.5': - resolution: {integrity: sha512-SRJ4jYmXRqV1/Xc+TIVG84WjHBXKlxO9sHQnA2Pf12QQEAp1LOh6kDzNHXcUnbH1QI0FDoPPVOt+vyUDucxpaw==} + '@babel/parser@7.26.7': + resolution: {integrity: sha512-kEvgGGgEjRUutvdVvZhbn/BxVt+5VSpwXz1j3WYXQbXDo8KzFOPNG2GQbdAiNq8g6wn1yKk7C/qrke03a84V+w==} engines: {node: '>=6.0.0'} hasBin: true @@ -112,16 +112,20 @@ packages: resolution: {integrity: sha512-FDSOghenHTiToteC/QRlv2q3DhPZ/oOXTBoirfWNx1Cx3TMVcGWQtMMmQcSvb/JjpNeGzx8Pq/b4fKEJuWm1sw==} engines: {node: '>=6.9.0'} + '@babel/runtime@7.26.7': + resolution: {integrity: sha512-AOPI3D+a8dXnja+iwsUqGRjr1BbZIe771sXdapOtYI531gSqpi92vXivKcq2asu/DFpdl1ceFAKZyRzK2PCVcQ==} + engines: {node: '>=6.9.0'} + '@babel/template@7.25.9': resolution: {integrity: sha512-9DGttpmPvIxBb/2uwpVo3dqJ+O6RooAFOS+lB+xDqoE2PVCE8nfoHMdZLpfCQRLwvohzXISPZcgxt80xLfsuwg==} engines: {node: '>=6.9.0'} - '@babel/traverse@7.26.5': - resolution: {integrity: sha512-rkOSPOw+AXbgtwUga3U4u8RpoK9FEFWBNAlTpcnkLFjL5CT+oyHNuUUC/xx6XefEJ16r38r8Bc/lfp6rYuHeJQ==} + '@babel/traverse@7.26.7': + resolution: {integrity: sha512-1x1sgeyRLC3r5fQOM0/xtQKsYjyxmFjaOrLJNtZ81inNjyJHGIolTULPiSc/2qe1/qfpFLisLQYFnnZl7QoedA==} engines: {node: '>=6.9.0'} - '@babel/types@7.26.5': - resolution: {integrity: sha512-L6mZmwFDK6Cjh1nRCLXpa6no13ZIioJDz7mdkzHv399pThrTa/k0nUlNaenOeh2kWu/iaOQYElEpKPUswUa9Vg==} + '@babel/types@7.26.7': + resolution: {integrity: sha512-t8kDRGrKXyp6+tjUh7hw2RLyclsW4TRoRvRHtSyAX9Bb5ldlFh+90YAYY6awRXrlB4G5G2izNeGySpATlFzmOg==} engines: {node: '>=6.9.0'} '@chakra-ui/react@3.3.3': @@ -661,11 +665,11 @@ packages: peerDependencies: vite: ^4 || ^5 || ^6 - '@vitest/expect@2.1.8': - resolution: {integrity: sha512-8ytZ/fFHq2g4PJVAtDX57mayemKgDR6X3Oa2Foro+EygiOJHUXhCqBAAKQYYajZpFoIfvBCF1j6R6IYRSIUFuw==} + '@vitest/expect@2.1.9': + resolution: {integrity: sha512-UJCIkTBenHeKT1TTlKMJWy1laZewsRIzYighyYiJKZreqtdxSos/S1t+ktRMQWu2CKqaarrkeszJx1cgC5tGZw==} - '@vitest/mocker@2.1.8': - resolution: {integrity: sha512-7guJ/47I6uqfttp33mgo6ga5Gr1VnL58rcqYKyShoRK9ebu8T5Rs6HN3s1NABiBeVTdWNrwUMcHH54uXZBN4zA==} + '@vitest/mocker@2.1.9': + resolution: 
{integrity: sha512-tVL6uJgoUdi6icpxmdrn5YNo3g3Dxv+IHJBr0GXHaEdTcw3F+cPKnsXFhli6nO+f/6SDKPHEK1UN+k+TQv0Ehg==} peerDependencies: msw: ^2.4.9 vite: ^5.0.0 @@ -675,20 +679,20 @@ packages: vite: optional: true - '@vitest/pretty-format@2.1.8': - resolution: {integrity: sha512-9HiSZ9zpqNLKlbIDRWOnAWqgcA7xu+8YxXSekhr0Ykab7PAYFkhkwoqVArPOtJhPmYeE2YHgKZlj3CP36z2AJQ==} + '@vitest/pretty-format@2.1.9': + resolution: {integrity: sha512-KhRIdGV2U9HOUzxfiHmY8IFHTdqtOhIzCpd8WRdJiE7D/HUcZVD0EgQCVjm+Q9gkUXWgBvMmTtZgIG48wq7sOQ==} - '@vitest/runner@2.1.8': - resolution: {integrity: sha512-17ub8vQstRnRlIU5k50bG+QOMLHRhYPAna5tw8tYbj+jzjcspnwnwtPtiOlkuKC4+ixDPTuLZiqiWWQ2PSXHVg==} + '@vitest/runner@2.1.9': + resolution: {integrity: sha512-ZXSSqTFIrzduD63btIfEyOmNcBmQvgOVsPNPe0jYtESiXkhd8u2erDLnMxmGrDCwHCCHE7hxwRDCT3pt0esT4g==} - '@vitest/snapshot@2.1.8': - resolution: {integrity: sha512-20T7xRFbmnkfcmgVEz+z3AU/3b0cEzZOt/zmnvZEctg64/QZbSDJEVm9fLnnlSi74KibmRsO9/Qabi+t0vCRPg==} + '@vitest/snapshot@2.1.9': + resolution: {integrity: sha512-oBO82rEjsxLNJincVhLhaxxZdEtV0EFHMK5Kmx5sJ6H9L183dHECjiefOAdnqpIgT5eZwT04PoggUnW88vOBNQ==} - '@vitest/spy@2.1.8': - resolution: {integrity: sha512-5swjf2q95gXeYPevtW0BLk6H8+bPlMb4Vw/9Em4hFxDcaOxS+e0LOX4yqNxoHzMR2akEB2xfpnWUzkZokmgWDg==} + '@vitest/spy@2.1.9': + resolution: {integrity: sha512-E1B35FwzXXTs9FHNK6bDszs7mtydNi5MIfUWpceJ8Xbfb1gBMscAnwLbEu+B44ed6W3XjL9/ehLPHR1fkf1KLQ==} - '@vitest/utils@2.1.8': - resolution: {integrity: sha512-dwSoui6djdwbfFmIgbIjX2ZhIoG7Ex/+xpxyiEgIGzjliY8xGkcpITKTlp6B4MgtGkF2ilvm97cPM96XZaAgcA==} + '@vitest/utils@2.1.9': + resolution: {integrity: sha512-v0psaMSkNJ3A2NMrUEHFRzJtDPFn+/VWZ5WxImB21T9fjucJRmS7xCS3ppEnARb9y11OAzaD+P2Ps+b+BGX5iQ==} '@zag-js/accordion@0.81.1': resolution: {integrity: sha512-NMSx9DNz+FigY9E+FtT/3GCjpP4H0VTbBTmqUDxw3FYKgP3txPoIQGrV4Dig4hCtCiPdmlwSZatA29HrTi8+zw==} @@ -1198,8 +1202,8 @@ packages: fast-levenshtein@2.0.6: resolution: {integrity: sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==} - fastq@1.18.0: - resolution: {integrity: sha512-QKHXPW0hD8g4UET03SdOdunzSouc9N4AuHdsX8XNcTsuz+yYFILVNIX4l9yHABMhiEI9Db0JTTIpu0wB+Y1QQw==} + fastq@1.19.0: + resolution: {integrity: sha512-7SFSRCNjBQIZH/xZR3iy5iQYR8aGBE0h3VG6/cwlbrpdciNYBMotQav8c1XI3HjHH+NikUpP53nPdlZSdWmFzA==} file-entry-cache@8.0.0: resolution: {integrity: sha512-XXTUwCvisa5oacNGRP9SfNtYBNAMi+RPwBFmblZEF7N7swHYQS6/Zfk7SRwx4D5j3CH211YNRco1DEMNVfZCnQ==} @@ -1312,6 +1316,10 @@ packages: resolution: {integrity: sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==} engines: {node: '>=6'} + import-fresh@3.3.1: + resolution: {integrity: sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==} + engines: {node: '>=6'} + imurmurhash@0.1.4: resolution: {integrity: sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==} engines: {node: '>=0.8.19'} @@ -1409,8 +1417,8 @@ packages: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true - loupe@3.1.2: - resolution: {integrity: sha512-23I4pFZHmAemUnz8WZXbYRSKYj801VDaNv9ETuMh7IrMc7VuVVSo+Z9iLE3ni30+U48iDWfi30d3twAXBYmnCg==} + loupe@3.1.3: + resolution: {integrity: sha512-kkIp7XSkP78ZxJEsSxW3712C6teJVoeHHwgo9zJ380de7IYyJ2ISlxojcH2pC5OFLewESmnRi/+XCDIEEVyoug==} lru-cache@10.4.3: resolution: {integrity: sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==} @@ 
-1580,8 +1588,8 @@ packages: pathe@1.1.2: resolution: {integrity: sha512-whLdWMYL2TwI08hn8/ZqAbrVemu0LNaNNJZX73O6qaIdCTfXutsLhMkjdENX0qhsQ9uIimo4/aQOmXkoon2nDQ==} - pathe@2.0.1: - resolution: {integrity: sha512-6jpjMpOth5S9ITVu5clZ7NOgHNsv5vRQdheL9ztp2vZmM6fRbLvyua1tiBIL4lk8SAe3ARzeXEly6siXCjDHDw==} + pathe@2.0.2: + resolution: {integrity: sha512-15Ztpk+nov8DR524R4BF7uEuzESgzUEAV4Ah7CUMNGXdE5ELuvxElxGXndBl32vMSsWa1jpNf22Z+Er3sKwq+w==} pathval@2.0.0: resolution: {integrity: sha512-vE7JKRyES09KiunauX7nd2Q9/L7lhok4smP9RZTDeD4MVs72Dp2qNFVz39Nz5a0FVEW0BJR6C0DYrq6unoziZA==} @@ -1830,8 +1838,8 @@ packages: uri-js@4.4.1: resolution: {integrity: sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==} - vite-node@2.1.8: - resolution: {integrity: sha512-uPAwSr57kYjAUux+8E2j0q0Fxpn8M9VoyfGiRI8Kfktz9NcYMCenwY5RnZxnF1WTu3TGiYipirIzacLL3VVGFg==} + vite-node@2.1.9: + resolution: {integrity: sha512-AM9aQ/IPrW/6ENLQg3AGY4K1N2TGZdR5e4gu/MmmR2xR3Ll1+dib+nook92g4TV3PXVyeyxdWwtaCAiUL0hMxA==} engines: {node: ^18.0.0 || >=20.0.0} hasBin: true @@ -1871,15 +1879,15 @@ packages: terser: optional: true - vitest@2.1.8: - resolution: {integrity: sha512-1vBKTZskHw/aosXqQUlVWWlGUxSJR8YtiyZDJAFeW2kPAeX6S3Sool0mjspO+kXLuxVWlEDDowBAeqeAQefqLQ==} + vitest@2.1.9: + resolution: {integrity: sha512-MSmPM9REYqDGBI8439mA4mWhV5sKmDlBKWIYbA3lRb2PTHACE0mgKwA8yQ2xq9vxDTuk4iPrECBAEW2aoFXY0Q==} engines: {node: ^18.0.0 || >=20.0.0} hasBin: true peerDependencies: '@edge-runtime/vm': '*' '@types/node': ^18.0.0 || >=20.0.0 - '@vitest/browser': 2.1.8 - '@vitest/ui': 2.1.8 + '@vitest/browser': 2.1.9 + '@vitest/ui': 2.1.9 happy-dom: '*' jsdom: '*' peerDependenciesMeta: @@ -2025,16 +2033,16 @@ snapshots: '@babel/generator@7.26.5': dependencies: - '@babel/parser': 7.26.5 - '@babel/types': 7.26.5 + '@babel/parser': 7.26.7 + '@babel/types': 7.26.7 '@jridgewell/gen-mapping': 0.3.8 '@jridgewell/trace-mapping': 0.3.25 jsesc: 3.1.0 '@babel/helper-module-imports@7.25.9': dependencies: - '@babel/traverse': 7.26.5 - '@babel/types': 7.26.5 + '@babel/traverse': 7.26.7 + '@babel/types': 7.26.7 transitivePeerDependencies: - supports-color @@ -2042,33 +2050,37 @@ snapshots: '@babel/helper-validator-identifier@7.25.9': {} - '@babel/parser@7.26.5': + '@babel/parser@7.26.7': dependencies: - '@babel/types': 7.26.5 + '@babel/types': 7.26.7 '@babel/runtime@7.26.0': dependencies: regenerator-runtime: 0.14.1 + '@babel/runtime@7.26.7': + dependencies: + regenerator-runtime: 0.14.1 + '@babel/template@7.25.9': dependencies: '@babel/code-frame': 7.26.2 - '@babel/parser': 7.26.5 - '@babel/types': 7.26.5 + '@babel/parser': 7.26.7 + '@babel/types': 7.26.7 - '@babel/traverse@7.26.5': + '@babel/traverse@7.26.7': dependencies: '@babel/code-frame': 7.26.2 '@babel/generator': 7.26.5 - '@babel/parser': 7.26.5 + '@babel/parser': 7.26.7 '@babel/template': 7.25.9 - '@babel/types': 7.26.5 + '@babel/types': 7.26.7 debug: 4.4.0 globals: 11.12.0 transitivePeerDependencies: - supports-color - '@babel/types@7.26.5': + '@babel/types@7.26.7': dependencies: '@babel/helper-string-parser': 7.25.9 '@babel/helper-validator-identifier': 7.25.9 @@ -2089,7 +2101,7 @@ snapshots: '@emotion/babel-plugin@11.13.5': dependencies: '@babel/helper-module-imports': 7.25.9 - '@babel/runtime': 7.26.0 + '@babel/runtime': 7.26.7 '@emotion/hash': 0.9.2 '@emotion/memoize': 0.9.0 '@emotion/serialize': 1.3.3 @@ -2120,7 +2132,7 @@ snapshots: '@emotion/react@11.14.0(react@18.3.1)': dependencies: - '@babel/runtime': 7.26.0 + '@babel/runtime': 7.26.7 '@emotion/babel-plugin': 
11.13.5 '@emotion/cache': 11.14.0 '@emotion/serialize': 1.3.3 @@ -2344,7 +2356,7 @@ snapshots: '@nodelib/fs.walk@1.2.8': dependencies: '@nodelib/fs.scandir': 2.1.5 - fastq: 1.18.0 + fastq: 1.19.0 '@pandacss/is-valid-prop@0.41.0': {} @@ -2477,7 +2489,7 @@ snapshots: '@testing-library/dom@10.4.0': dependencies: '@babel/code-frame': 7.26.2 - '@babel/runtime': 7.26.0 + '@babel/runtime': 7.26.7 '@types/aria-query': 5.0.4 aria-query: 5.3.0 chalk: 4.1.2 @@ -2524,44 +2536,44 @@ snapshots: transitivePeerDependencies: - '@swc/helpers' - '@vitest/expect@2.1.8': + '@vitest/expect@2.1.9': dependencies: - '@vitest/spy': 2.1.8 - '@vitest/utils': 2.1.8 + '@vitest/spy': 2.1.9 + '@vitest/utils': 2.1.9 chai: 5.1.2 tinyrainbow: 1.2.0 - '@vitest/mocker@2.1.8(vite@5.4.14)': + '@vitest/mocker@2.1.9(vite@5.4.14)': dependencies: - '@vitest/spy': 2.1.8 + '@vitest/spy': 2.1.9 estree-walker: 3.0.3 magic-string: 0.30.17 optionalDependencies: vite: 5.4.14 - '@vitest/pretty-format@2.1.8': + '@vitest/pretty-format@2.1.9': dependencies: tinyrainbow: 1.2.0 - '@vitest/runner@2.1.8': + '@vitest/runner@2.1.9': dependencies: - '@vitest/utils': 2.1.8 + '@vitest/utils': 2.1.9 pathe: 1.1.2 - '@vitest/snapshot@2.1.8': + '@vitest/snapshot@2.1.9': dependencies: - '@vitest/pretty-format': 2.1.8 + '@vitest/pretty-format': 2.1.9 magic-string: 0.30.17 pathe: 1.1.2 - '@vitest/spy@2.1.8': + '@vitest/spy@2.1.9': dependencies: tinyspy: 3.0.2 - '@vitest/utils@2.1.8': + '@vitest/utils@2.1.9': dependencies: - '@vitest/pretty-format': 2.1.8 - loupe: 3.1.2 + '@vitest/pretty-format': 2.1.9 + loupe: 3.1.3 tinyrainbow: 1.2.0 '@zag-js/accordion@0.81.1': @@ -3087,7 +3099,7 @@ snapshots: babel-plugin-macros@3.1.0: dependencies: - '@babel/runtime': 7.26.0 + '@babel/runtime': 7.26.7 cosmiconfig: 7.1.0 resolve: 1.22.10 @@ -3136,7 +3148,7 @@ snapshots: assertion-error: 2.0.1 check-error: 2.1.1 deep-eql: 5.0.2 - loupe: 3.1.2 + loupe: 3.1.3 pathval: 2.0.0 chalk@3.0.0: @@ -3194,7 +3206,7 @@ snapshots: cosmiconfig@7.1.0: dependencies: '@types/parse-json': 4.0.2 - import-fresh: 3.3.0 + import-fresh: 3.3.1 parse-json: 5.2.0 path-type: 4.0.0 yaml: 1.10.2 @@ -3373,7 +3385,7 @@ snapshots: fast-levenshtein@2.0.6: {} - fastq@1.18.0: + fastq@1.19.0: dependencies: reusify: 1.0.4 @@ -3489,6 +3501,11 @@ snapshots: parent-module: 1.0.1 resolve-from: 4.0.0 + import-fresh@3.3.1: + dependencies: + parent-module: 1.0.1 + resolve-from: 4.0.0 + imurmurhash@0.1.4: {} indent-string@4.0.0: {} @@ -3564,7 +3581,7 @@ snapshots: dependencies: js-tokens: 4.0.0 - loupe@3.1.2: {} + loupe@3.1.3: {} lru-cache@10.4.3: {} @@ -3576,8 +3593,8 @@ snapshots: magicast@0.3.5: dependencies: - '@babel/parser': 7.26.5 - '@babel/types': 7.26.5 + '@babel/parser': 7.26.7 + '@babel/types': 7.26.7 source-map-js: 1.2.1 optional: true @@ -3630,7 +3647,7 @@ snapshots: mlly@1.7.4: dependencies: acorn: 8.14.0 - pathe: 2.0.1 + pathe: 2.0.2 pkg-types: 1.3.1 ufo: 1.5.4 @@ -3714,7 +3731,7 @@ snapshots: pathe@1.1.2: {} - pathe@2.0.1: {} + pathe@2.0.2: {} pathval@2.0.0: {} @@ -3730,7 +3747,7 @@ snapshots: dependencies: confbox: 0.1.8 mlly: 1.7.4 - pathe: 2.0.1 + pathe: 2.0.2 postcss@8.5.1: dependencies: @@ -3950,7 +3967,7 @@ snapshots: dependencies: punycode: 2.3.1 - vite-node@2.1.8: + vite-node@2.1.9: dependencies: cac: 6.7.14 debug: 4.4.0 @@ -3980,15 +3997,15 @@ snapshots: optionalDependencies: fsevents: 2.3.3 - vitest@2.1.8(happy-dom@15.11.7): + vitest@2.1.9(happy-dom@15.11.7): dependencies: - '@vitest/expect': 2.1.8 - '@vitest/mocker': 2.1.8(vite@5.4.14) - '@vitest/pretty-format': 2.1.8 - '@vitest/runner': 
2.1.8 - '@vitest/snapshot': 2.1.8 - '@vitest/spy': 2.1.8 - '@vitest/utils': 2.1.8 + '@vitest/expect': 2.1.9 + '@vitest/mocker': 2.1.9(vite@5.4.14) + '@vitest/pretty-format': 2.1.9 + '@vitest/runner': 2.1.9 + '@vitest/snapshot': 2.1.9 + '@vitest/spy': 2.1.9 + '@vitest/utils': 2.1.9 chai: 5.1.2 debug: 4.4.0 expect-type: 1.1.0 @@ -4000,7 +4017,7 @@ snapshots: tinypool: 1.0.2 tinyrainbow: 1.2.0 vite: 5.4.14 - vite-node: 2.1.8 + vite-node: 2.1.9 why-is-node-running: 2.3.0 optionalDependencies: happy-dom: 15.11.7 diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py index 0186add2bb3bb..78cfbf394b839 100644 --- a/airflow/cli/cli_config.py +++ b/airflow/cli/cli_config.py @@ -1073,7 +1073,7 @@ class GroupCommand(NamedTuple): name="state", help="Get the status of a dag run", func=lazy_load_command("airflow.cli.commands.remote_commands.dag_command.dag_state"), - args=(ARG_DAG_ID, ARG_LOGICAL_DATE, ARG_SUBDIR, ARG_VERBOSE), + args=(ARG_DAG_ID, ARG_LOGICAL_DATE_OR_RUN_ID, ARG_SUBDIR, ARG_VERBOSE), ), ActionCommand( name="next-execution", diff --git a/airflow/cli/commands/remote_commands/config_command.py b/airflow/cli/commands/remote_commands/config_command.py index dee1e89df5682..5a52edbaaa3cc 100644 --- a/airflow/cli/commands/remote_commands/config_command.py +++ b/airflow/cli/commands/remote_commands/config_command.py @@ -84,11 +84,13 @@ class ConfigChange: :param config: The configuration parameter being changed. :param suggestion: A suggestion for replacing or handling the removed configuration. :param renamed_to: The new section and option if the configuration is renamed. + :param was_deprecated: If the config is removed, whether the old config was deprecated. """ config: ConfigParameter suggestion: str = "" renamed_to: ConfigParameter | None = None + was_deprecated: bool = True @property def message(self) -> str: @@ -96,15 +98,16 @@ def message(self) -> str: if self.renamed_to: if self.config.section != self.renamed_to.section: return ( - f"`{self.config.option}` configuration parameter moved from `{self.config.section}` section to `" - f"{self.renamed_to.section}` section as `{self.renamed_to.option}`." + f"`{self.config.option}` configuration parameter moved from `{self.config.section}` section to " + f"`{self.renamed_to.section}` section as `{self.renamed_to.option}`." ) return ( f"`{self.config.option}` configuration parameter renamed to `{self.renamed_to.option}` " f"in the `{self.config.section}` section." ) return ( - f"Removed deprecated `{self.config.option}` configuration parameter from `{self.config.section}` section. " + f"Removed{' deprecated' if self.was_deprecated else ''} `{self.config.option}` configuration parameter " + f"from `{self.config.section}` section. 
" f"{self.suggestion}" ) @@ -203,6 +206,12 @@ def message(self) -> str: config=ConfigParameter("core", "dag_file_processor_timeout"), renamed_to=ConfigParameter("dag_processor", "dag_file_processor_timeout"), ), + ConfigChange( + config=ConfigParameter("core", "dag_processor_manager_log_location"), + ), + ConfigChange( + config=ConfigParameter("core", "log_processor_filename_template"), + ), # api ConfigChange( config=ConfigParameter("api", "access_control_allow_origin"), @@ -218,6 +227,18 @@ def message(self) -> str: suggestion="Remove TaskContextLogger: Replaced by the Log table for better handling of task log " "messages outside the execution context.", ), + ConfigChange( + config=ConfigParameter("logging", "dag_processor_manager_log_location"), + was_deprecated=False, + ), + ConfigChange( + config=ConfigParameter("logging", "dag_processor_manager_log_stdout"), + was_deprecated=False, + ), + ConfigChange( + config=ConfigParameter("logging", "log_processor_filename_template"), + was_deprecated=False, + ), # metrics ConfigChange( config=ConfigParameter("metrics", "metrics_use_pattern_match"), diff --git a/airflow/cli/commands/remote_commands/dag_command.py b/airflow/cli/commands/remote_commands/dag_command.py index 6c95451271fbc..bf0841f1250be 100644 --- a/airflow/cli/commands/remote_commands/dag_command.py +++ b/airflow/cli/commands/remote_commands/dag_command.py @@ -33,6 +33,7 @@ from airflow.api.client import get_current_api_client from airflow.api_connexion.schemas.dag_schema import dag_schema from airflow.cli.simple_table import AirflowConsole +from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string from airflow.exceptions import AirflowException from airflow.jobs.job import Job from airflow.models import DagBag, DagModel, DagRun, TaskInstance @@ -264,12 +265,17 @@ def dag_state(args, session: Session = NEW_SESSION) -> None: if not dag: raise SystemExit(f"DAG: {args.dag_id} does not exist in 'dag' table") - dr = session.scalar(select(DagRun).filter_by(dag_id=args.dag_id, logical_date=args.logical_date)) - out = dr.state if dr else None - conf_out = "" - if out and dr.conf: - conf_out = ", " + json.dumps(dr.conf) - print(str(out) + conf_out) + dr, _ = fetch_dag_run_from_run_id_or_logical_date_string( + dag_id=dag.dag_id, + value=args.logical_date_or_run_id, + session=session, + ) + if not dr: + print(None) + elif dr.conf: + print(f"{dr.state}, {json.dumps(dr.conf)}") + else: + print(dr.state) @cli_utils.action_cli @@ -465,20 +471,20 @@ def dag_list_dag_runs(args, dag: DAG | None = None, session: Session = NEW_SESSI logical_end_date=args.end_date, session=session, ) + dag_runs.sort(key=operator.attrgetter("run_after"), reverse=True) - dag_runs.sort(key=lambda x: x.logical_date, reverse=True) - AirflowConsole().print_as( - data=dag_runs, - output=args.output, - mapper=lambda dr: { + def _render_dagrun(dr: DagRun) -> dict[str, str]: + return { "dag_id": dr.dag_id, "run_id": dr.run_id, "state": dr.state, - "logical_date": dr.logical_date.isoformat(), + "run_after": dr.run_after.isoformat(), + "logical_date": dr.logical_date.isoformat() if dr.logical_date else "", "start_date": dr.start_date.isoformat() if dr.start_date else "", "end_date": dr.end_date.isoformat() if dr.end_date else "", - }, - ) + } + + AirflowConsole().print_as(data=dag_runs, output=args.output, mapper=_render_dagrun) @cli_utils.action_cli @@ -515,7 +521,7 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No tis = session.scalars( select(TaskInstance).where( 
TaskInstance.dag_id == args.dag_id, - TaskInstance.logical_date == logical_date, + TaskInstance.run_id == dr.run_id, ) ).all() diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py index ccae29ededd09..debe1b814397b 100644 --- a/airflow/cli/commands/remote_commands/task_command.py +++ b/airflow/cli/commands/remote_commands/task_command.py @@ -32,11 +32,10 @@ from typing import TYPE_CHECKING, Protocol, cast import pendulum -from pendulum.parsing.exceptions import ParserError -from sqlalchemy import select from airflow import settings from airflow.cli.simple_table import AirflowConsole +from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string from airflow.configuration import conf from airflow.exceptions import AirflowException, DagRunNotFound, TaskDeferred, TaskInstanceNotFound from airflow.executors.executor_loader import ExecutorLoader @@ -48,6 +47,7 @@ from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskReturnCode from airflow.sdk.definitions.param import ParamsDict +from airflow.sdk.execution_time.secrets_masker import RedactedIO from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS @@ -61,7 +61,6 @@ ) from airflow.utils.log.file_task_handler import _set_task_deferred_context_var from airflow.utils.log.logging_mixin import StreamLogWriter -from airflow.utils.log.secrets_masker import RedactedIO from airflow.utils.net import get_hostname from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -91,41 +90,6 @@ def _generate_temporary_run_id() -> str: return f"__airflow_temporary_run_{timezone.utcnow().isoformat()}__" -def _fetch_dag_run_from_run_id_or_logical_date_string( - *, - dag_id: str, - value: str, - session: Session, -) -> tuple[DagRun, pendulum.DateTime | None]: - """ - Try to find a DAG run with a given string value. - - The string value may be a run ID, or a logical date in string form. We first - try to use it as a run_id; if a run is found, it is returned as-is. - - Otherwise, the string value is parsed into a datetime. If that works, it is - used to find a DAG run. - - The return value is a two-tuple. The first item is the found DAG run (or - *None* if one cannot be found). The second is the parsed logical date. This - second value can be used to create a new run by the calling function when - one cannot be found here. 
- """ - if dag_run := DAG.fetch_dagrun(dag_id=dag_id, run_id=value, session=session): - return dag_run, dag_run.logical_date # type: ignore[return-value] - try: - logical_date = timezone.parse(value) - except (ParserError, TypeError): - return dag_run, None - dag_run = session.scalar( - select(DagRun) - .where(DagRun.dag_id == dag_id, DagRun.logical_date == logical_date) - .order_by(DagRun.id.desc()) - .limit(1) - ) - return dag_run, logical_date - - def _get_dag_run( *, dag: DAG, @@ -152,7 +116,7 @@ def _get_dag_run( logical_date = None if logical_date_or_run_id: - dag_run, logical_date = _fetch_dag_run_from_run_id_or_logical_date_string( + dag_run, logical_date = fetch_dag_run_from_run_id_or_logical_date_string( dag_id=dag.dag_id, value=logical_date_or_run_id, session=session, @@ -165,11 +129,7 @@ def _get_dag_run( f"of {logical_date_or_run_id!r} not found" ) - if logical_date is not None: - dag_run_logical_date = logical_date - else: - dag_run_logical_date = pendulum.instance(timezone.utcnow()) - + dag_run_logical_date = pendulum.instance(logical_date or timezone.utcnow()) if create_if_necessary == "memory": data_interval = dag.timetable.infer_manual_data_interval(run_after=dag_run_logical_date) dag_run = DagRun( @@ -592,17 +552,11 @@ def _guess_debugger() -> _SupportedDebugger: @provide_session def task_states_for_dag_run(args, session: Session = NEW_SESSION) -> None: """Get the status of all task instances in a DagRun.""" - dag_run = session.scalar( - select(DagRun).where(DagRun.run_id == args.logical_date_or_run_id, DagRun.dag_id == args.dag_id) + dag_run, _ = fetch_dag_run_from_run_id_or_logical_date_string( + dag_id=args.dag_id, + value=args.logical_date_or_run_id, + session=session, ) - if not dag_run: - try: - logical_date = timezone.parse(args.logical_date_or_run_id) - dag_run = session.scalar( - select(DagRun).where(DagRun.logical_date == logical_date, DagRun.dag_id == args.dag_id) - ) - except (ParserError, TypeError) as err: - raise AirflowException(f"Error parsing the supplied logical_date. 
Error: {err}") if dag_run is None: raise DagRunNotFound( @@ -615,7 +569,7 @@ def task_states_for_dag_run(args, session: Session = NEW_SESSION) -> None: def format_task_instance(ti: TaskInstance) -> dict[str, str]: data = { "dag_id": ti.dag_id, - "logical_date": dag_run.logical_date.isoformat(), + "logical_date": dag_run.logical_date.isoformat() if dag_run.logical_date else "", "task_id": ti.task_id, "state": ti.state, "start_date": ti.start_date.isoformat() if ti.start_date else "", diff --git a/airflow/cli/utils.py b/airflow/cli/utils.py index d132deeb373a9..605244ee71e56 100644 --- a/airflow/cli/utils.py +++ b/airflow/cli/utils.py @@ -17,13 +17,17 @@ from __future__ import annotations -import io import sys -from collections.abc import Collection from typing import TYPE_CHECKING if TYPE_CHECKING: - from io import IOBase + import datetime + from collections.abc import Collection + from io import IOBase, TextIOWrapper + + from sqlalchemy.orm import Session + + from airflow.models.dagrun import DagRun class CliConflictError(Exception): @@ -45,8 +49,50 @@ def is_stdout(fileio: IOBase) -> bool: return fileio.fileno() == sys.stdout.fileno() -def print_export_output(command_type: str, exported_items: Collection, file: io.TextIOWrapper): +def print_export_output(command_type: str, exported_items: Collection, file: TextIOWrapper): if not file.closed and is_stdout(file): print(f"\n{len(exported_items)} {command_type} successfully exported.", file=sys.stderr) else: print(f"{len(exported_items)} {command_type} successfully exported to {file.name}.") + + +def fetch_dag_run_from_run_id_or_logical_date_string( + *, + dag_id: str, + value: str, + session: Session, +) -> tuple[DagRun | None, datetime.datetime | None]: + """ + Try to find a DAG run with a given string value. + + The string value may be a run ID, or a logical date in string form. We first + try to use it as a run_id; if a run is found, it is returned as-is. + + Otherwise, the string value is parsed into a datetime. If that works, it is + used to find a DAG run. + + The return value is a two-tuple. The first item is the found DAG run (or + *None* if one cannot be found). The second is the parsed logical date. This + second value can be used to create a new run by the calling function when + one cannot be found here. 
+ """ + from pendulum.parsing.exceptions import ParserError + from sqlalchemy import select + + from airflow.models.dag import DAG + from airflow.models.dagrun import DagRun + from airflow.utils import timezone + + if dag_run := DAG.fetch_dagrun(dag_id=dag_id, run_id=value, session=session): + return dag_run, dag_run.logical_date + try: + logical_date = timezone.parse(value) + except (ParserError, TypeError): + return None, None + dag_run = session.scalar( + select(DagRun) + .where(DagRun.dag_id == dag_id, DagRun.logical_date == logical_date) + .order_by(DagRun.id.desc()) + .limit(1) + ) + return dag_run, logical_date diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py index f440261dafc86..5ac7f513d1f87 100644 --- a/airflow/config_templates/airflow_local_settings.py +++ b/airflow/config_templates/airflow_local_settings.py @@ -20,7 +20,6 @@ from __future__ import annotations import os -from pathlib import Path from typing import Any from urllib.parse import urlsplit @@ -53,17 +52,6 @@ PROCESSOR_LOG_FOLDER: str = conf.get_mandatory_value("scheduler", "CHILD_PROCESS_LOG_DIRECTORY") -DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get_mandatory_value( - "logging", "DAG_PROCESSOR_MANAGER_LOG_LOCATION" -) - -DAG_PROCESSOR_MANAGER_LOG_STDOUT: str = conf.get_mandatory_value( - "logging", "DAG_PROCESSOR_MANAGER_LOG_STDOUT" -) - - -PROCESSOR_FILENAME_TEMPLATE: str = conf.get_mandatory_value("logging", "LOG_PROCESSOR_FILENAME_TEMPLATE") - DEFAULT_LOGGING_CONFIG: dict[str, Any] = { "version": 1, "disable_existing_loggers": False, @@ -83,7 +71,7 @@ }, "filters": { "mask_secrets": { - "()": "airflow.utils.log.secrets_masker.SecretsMasker", + "()": "airflow.sdk.execution_time.secrets_masker.SecretsMasker", }, }, "handlers": { @@ -99,27 +87,8 @@ "base_log_folder": os.path.expanduser(BASE_LOG_FOLDER), "filters": ["mask_secrets"], }, - "processor": { - "class": "airflow.utils.log.file_processor_handler.FileProcessorHandler", - "formatter": "airflow", - "base_log_folder": os.path.expanduser(PROCESSOR_LOG_FOLDER), - "filename_template": PROCESSOR_FILENAME_TEMPLATE, - "filters": ["mask_secrets"], - }, - "processor_to_stdout": { - "class": "airflow.utils.log.logging_mixin.RedirectStdHandler", - "formatter": "source_processor", - "stream": "sys.stdout", - "filters": ["mask_secrets"], - }, }, "loggers": { - "airflow.processor": { - "handlers": ["processor_to_stdout" if DAG_PROCESSOR_LOG_TARGET == "stdout" else "processor"], - "level": LOG_LEVEL, - # Set to true here (and reset via set_context) so that if no file is configured we still get logs! 
- "propagate": True, - }, "airflow.task": { "handlers": ["task"], "level": LOG_LEVEL, @@ -152,54 +121,6 @@ } DEFAULT_LOGGING_CONFIG["loggers"].update(new_loggers) -DEFAULT_DAG_PARSING_LOGGING_CONFIG: dict[str, dict[str, dict[str, Any]]] = { - "handlers": { - "processor_manager": { - "class": "airflow.utils.log.non_caching_file_handler.NonCachingRotatingFileHandler", - "formatter": "airflow", - "filename": DAG_PROCESSOR_MANAGER_LOG_LOCATION, - "mode": "a", - "maxBytes": 104857600, # 100MB - "backupCount": 5, - } - }, - "loggers": { - "airflow.processor_manager": { - "handlers": ["processor_manager"], - "level": LOG_LEVEL, - "propagate": False, - } - }, -} - -if DAG_PROCESSOR_MANAGER_LOG_STDOUT == "True": - DEFAULT_DAG_PARSING_LOGGING_CONFIG["handlers"].update( - { - "console": { - "class": "airflow.utils.log.logging_mixin.RedirectStdHandler", - "formatter": "airflow", - "stream": "sys.stdout", - "filters": ["mask_secrets"], - } - } - ) - DEFAULT_DAG_PARSING_LOGGING_CONFIG["loggers"]["airflow.processor_manager"]["handlers"].append("console") - -# Only update the handlers and loggers when CONFIG_PROCESSOR_MANAGER_LOGGER is set. -# This is to avoid exceptions when initializing RotatingFileHandler multiple times -# in multiple processes. -if os.environ.get("CONFIG_PROCESSOR_MANAGER_LOGGER") == "True": - DEFAULT_LOGGING_CONFIG["handlers"].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG["handlers"]) - DEFAULT_LOGGING_CONFIG["loggers"].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG["loggers"]) - - # Manually create log directory for processor_manager handler as RotatingFileHandler - # will only create file but not the directory. - processor_manager_handler_config: dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG["handlers"][ - "processor_manager" - ] - directory: str = os.path.dirname(processor_manager_handler_config["filename"]) - Path(directory).mkdir(parents=True, exist_ok=True, mode=0o755) - ################## # Remote logging # ################## diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 8e8726bd07e02..201d4e3e6bc83 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -886,10 +886,10 @@ logging: secret_mask_adapter: description: | An import path to a function to add adaptations of each secret added with - ``airflow.utils.log.secrets_masker.mask_secret`` to be masked in log messages. The given function - is expected to require a single parameter: the secret to be adapted. It may return a - single adaptation of the secret or an iterable of adaptations to each be masked as secrets. - The original secret will be masked as well as any adaptations returned. + ``airflow.sdk.execution_time.secrets_masker.mask_secret`` to be masked in log messages. + The given function is expected to require a single parameter: the secret to be adapted. + It may return a single adaptation of the secret or an iterable of adaptations to each be + masked as secrets. The original secret will be masked as well as any adaptations returned. 
       version_added: 2.6.0
       type: string
       default: ""
@@ -912,28 +912,6 @@ logging:
       default: "dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ ti.task_id }}/\
         {%% if ti.map_index >= 0 %%}map_index={{ ti.map_index }}/{%% endif %%}\
         attempt={{ try_number|default(ti.try_number) }}.log"
-    log_processor_filename_template:
-      description: |
-        Formatting for how airflow generates file names for log
-      version_added: 2.0.0
-      type: string
-      example: ~
-      is_template: true
-      default: "{{ filename }}.log"
-    dag_processor_manager_log_location:
-      description: |
-        Full path of dag_processor_manager logfile.
-      version_added: 2.0.0
-      type: string
-      example: ~
-      default: "{AIRFLOW_HOME}/logs/dag_processor_manager/dag_processor_manager.log"
-    dag_processor_manager_log_stdout:
-      description: |
-        Whether DAG processor manager will write logs to stdout
-      version_added: 2.9.0
-      type: boolean
-      example: ~
-      default: "False"
     task_log_reader:
       description: |
         Name of handler to read task instance logs.
@@ -1595,9 +1573,9 @@ webserver:
       default: "120"
     worker_refresh_batch_size:
       description: |
-        Number of workers to refresh at a time. When set to 0, worker refresh is
-        disabled. When nonzero, airflow periodically refreshes webserver workers by
-        bringing up new ones and killing old ones.
+        Number of workers to refresh at a time through Gunicorn's built-in worker management.
+        When set to 0, worker refresh is disabled. When nonzero, airflow periodically refreshes
+        webserver workers by bringing up new ones and killing old ones.
       version_added: ~
       type: string
       example: ~
@@ -2698,3 +2676,15 @@ dag_processor:
     type: integer
     example: ~
     default: "30"
+fastapi:
+  description: Configuration for the FastAPI webserver.
+  options:
+    base_url:
+      description: |
+        The base URL of the FastAPI endpoint. Airflow cannot guess what domain or CNAME you are using.
+        If the Airflow console (the front-end) and the FastAPI APIs are on different domains, this config
+        should contain the FastAPI API endpoint.
+      version_added: ~
+      type: string
+      example: ~
+      default: "http://localhost:29091"
diff --git a/airflow/configuration.py b/airflow/configuration.py
index 5257da834035c..6ba3fe6006c66 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -782,7 +782,7 @@ def _create_future_warning(name: str, section: str, current_value: Any, new_valu
         )
 
     def mask_secrets(self):
-        from airflow.utils.log.secrets_masker import mask_secret
+        from airflow.sdk.execution_time.secrets_masker import mask_secret
 
         for section, key in self.sensitive_config_values:
             try:
diff --git a/airflow/dag_processing/bundles/base.py b/airflow/dag_processing/bundles/base.py
index 81bf38d373677..3d86f398c73cb 100644
--- a/airflow/dag_processing/bundles/base.py
+++ b/airflow/dag_processing/bundles/base.py
@@ -17,8 +17,10 @@
 
 from __future__ import annotations
 
+import fcntl
 import tempfile
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from pathlib import Path
 
 from airflow.configuration import conf
@@ -46,6 +48,7 @@ class BaseDagBundle(ABC):
     """
 
     supports_versioning: bool = False
+    _locked: bool = False
 
     def __init__(
         self,
@@ -67,6 +70,10 @@ def initialize(self) -> None:
         and allows for deferring expensive operations until that point in time. This will only be called
         when Airflow needs the bundle files on disk - some uses only need to call the `view_url` method,
         which can run without initializing the bundle.
+ If it isn't naturally safe, you'll need to make it so with some form of locking. + There is a `lock` context manager on this class available for this purpose. """ self.is_initialized = True @@ -101,7 +108,13 @@ def get_current_version(self) -> str | None: @abstractmethod def refresh(self) -> None: - """Retrieve the latest version of the files in the bundle.""" + """ + Retrieve the latest version of the files in the bundle. + + This method must ultimately be safe to call concurrently from different threads or processes. + If it isn't naturally safe, you'll need to make it so with some form of locking. + There is a `lock` context manager on this class available for this purpose. + """ def view_url(self, version: str | None = None) -> str | None: """ @@ -112,3 +125,27 @@ def view_url(self, version: str | None = None) -> str | None: :param version: Version to view :return: URL to view the bundle """ + + @contextmanager + def lock(self): + """ + Ensure only a single bundle can enter this context at a time, by taking an exclusive lock on a lockfile. + + This is useful when a bundle needs to perform operations that are not safe to run concurrently. + """ + if self._locked: + yield + return + + lock_dir_path = self._dag_bundle_root_storage_path / "_locks" + lock_dir_path.mkdir(parents=True, exist_ok=True) + lock_file_path = lock_dir_path / f"{self.name}.lock" + with open(lock_file_path, "w") as lock_file: + # Exclusive lock - blocks until it is available + fcntl.flock(lock_file, fcntl.LOCK_EX) + try: + self._locked = True + yield + finally: + fcntl.flock(lock_file, fcntl.LOCK_UN) + self._locked = False diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index abebbb4a33820..ea6833e66bae1 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -17,8 +17,10 @@ from __future__ import annotations +import contextlib import json import os +import tempfile from typing import TYPE_CHECKING, Any from urllib.parse import urlparse @@ -60,6 +62,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: "extra": json.dumps( { "key_file": "optional/path/to/keyfile", + "private_key": "optional inline private key", } ) }, @@ -70,15 +73,22 @@ def __init__(self, git_conn_id="git_default", *args, **kwargs): connection = self.get_connection(git_conn_id) self.repo_url = connection.host self.auth_token = connection.password + self.private_key = connection.extra_dejson.get("private_key") self.key_file = connection.extra_dejson.get("key_file") - strict_host_key_checking = connection.extra_dejson.get("strict_host_key_checking", "no") + self.strict_host_key_checking = connection.extra_dejson.get("strict_host_key_checking", "no") self.env: dict[str, str] = {} - if self.key_file: - self.env["GIT_SSH_COMMAND"] = ( - f"ssh -i {self.key_file} -o IdentitiesOnly=yes -o StrictHostKeyChecking={strict_host_key_checking}" - ) + + if self.key_file and self.private_key: + raise AirflowException("Both 'key_file' and 'private_key' cannot be provided at the same time") self._process_git_auth_url() + def _build_ssh_command(self, key_path: str) -> str: + return ( + f"ssh -i {key_path} " + f"-o IdentitiesOnly=yes " + f"-o StrictHostKeyChecking={self.strict_host_key_checking}" + ) + def _process_git_auth_url(self): if not isinstance(self.repo_url, str): return @@ -87,6 +97,22 @@ def _process_git_auth_url(self): elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"): self.repo_url = os.path.expanduser(self.repo_url) + def 
set_git_env(self, key: str) -> None: + self.env["GIT_SSH_COMMAND"] = self._build_ssh_command(key) + + @contextlib.contextmanager + def configure_hook_env(self): + if self.private_key: + with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp_keyfile: + tmp_keyfile.write(self.private_key) + tmp_keyfile.flush() + os.chmod(tmp_keyfile.name, 0o600) + self.set_git_env(tmp_keyfile.name) + yield + else: + self.set_git_env(self.key_file) + yield + class GitDagBundle(BaseDagBundle, LoggingMixin): """ @@ -128,17 +154,20 @@ def __init__( self.log.warning("Could not create GitHook for connection %s : %s", self.git_conn_id, e) def _initialize(self): - self._clone_bare_repo_if_required() - self._ensure_version_in_bare_repo() - self._clone_repo_if_required() - self.repo.git.checkout(self.tracking_ref) - if self.version: - if not self._has_version(self.repo, self.version): - self.repo.remotes.origin.fetch() - self.repo.head.set_reference(self.repo.commit(self.version)) - self.repo.head.reset(index=True, working_tree=True) - else: - self.refresh() + with self.lock(): + with self.hook.configure_hook_env(): + self._clone_bare_repo_if_required() + self._ensure_version_in_bare_repo() + + self._clone_repo_if_required() + self.repo.git.checkout(self.tracking_ref) + if self.version: + if not self._has_version(self.repo, self.version): + self.repo.remotes.origin.fetch() + self.repo.head.set_reference(self.repo.commit(self.version)) + self.repo.head.reset(index=True, working_tree=True) + else: + self.refresh() def initialize(self) -> None: if not self.repo_url: @@ -230,8 +259,11 @@ def _fetch_bare_repo(self): def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") - self._fetch_bare_repo() - self.repo.remotes.origin.pull() + + with self.lock(): + with self.hook.configure_hook_env(): + self._fetch_bare_repo() + self.repo.remotes.origin.pull() @staticmethod def _convert_git_ssh_url_to_https(url: str) -> str: diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 815a45b0fd99e..3fb5ac3106d5d 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -63,6 +63,7 @@ from airflow.traces.tracer import Trace from airflow.utils import timezone from airflow.utils.file import list_py_file_paths, might_contain_dag +from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.process_utils import ( kill_child_processes_by_pids, @@ -97,9 +98,6 @@ class DagFileStat: last_num_of_db_queries: int = 0 -log = logging.getLogger("airflow.processor_manager") - - @dataclass(frozen=True) class DagFileInfo: """Information about a DAG file.""" @@ -135,7 +133,7 @@ def _resolve_path(instance: Any, attribute: attrs.Attribute, val: str | os.PathL @attrs.define -class DagFileProcessorManager: +class DagFileProcessorManager(LoggingMixin): """ Manage processes responsible for parsing DAGs. 
@@ -167,8 +165,6 @@ class DagFileProcessorManager: factory=_config_int_factory("dag_processor", "stale_dag_threshold") ) - log: logging.Logger = attrs.field(default=log, init=False) - _last_deactivate_stale_dags_time: float = attrs.field(default=0, init=False) print_stats_interval: float = attrs.field( factory=_config_int_factory("dag_processor", "print_stats_interval") diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 9db6d058adeb5..b79819bedcc5e 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -55,8 +55,6 @@ from airflow.utils.types import NOTSET if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.models.expandinput import ( ExpandInput, OperatorExpandArgument, @@ -184,7 +182,9 @@ def __init__( kwargs_to_upstream: dict[str, Any] | None = None, **kwargs, ) -> None: - task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group")) + if not getattr(self, "_BaseOperator__from_mapped", False): + # If we are being created from calling unmap(), then don't mangle the task id + task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group")) self.python_callable = python_callable kwargs_to_upstream = kwargs_to_upstream or {} op_args = op_args or [] @@ -218,10 +218,10 @@ def __init__( The function signature broke while assigning defaults to context key parameters. The decorator is replacing the signature - > {python_callable.__name__}({', '.join(str(param) for param in signature.parameters.values())}) + > {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())}) with - > {python_callable.__name__}({', '.join(str(param) for param in parameters)}) + > {python_callable.__name__}({", ".join(str(param) for param in parameters)}) which isn't valid: {err} """ @@ -568,13 +568,11 @@ def __attrs_post_init__(self): super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self) XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value) - def _expand_mapped_kwargs( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool - ) -> tuple[Mapping[str, Any], set[int]]: + def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: # We only use op_kwargs_expand_input so this must always be empty. if self.expand_input is not EXPAND_INPUT_EMPTY: raise AssertionError(f"unexpected expand_input: {self.expand_input}") - op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom) + op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context) return {"op_kwargs": op_kwargs}, resolved_oids def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: diff --git a/airflow/migrations/versions/0032_3_0_0_drop_execution_date_unique.py b/airflow/migrations/versions/0032_3_0_0_rename_execution_date_to_logical_date_and_nullable.py similarity index 84% rename from airflow/migrations/versions/0032_3_0_0_drop_execution_date_unique.py rename to airflow/migrations/versions/0032_3_0_0_rename_execution_date_to_logical_date_and_nullable.py index 399cc8aff91f3..8a63d8112ac28 100644 --- a/airflow/migrations/versions/0032_3_0_0_drop_execution_date_unique.py +++ b/airflow/migrations/versions/0032_3_0_0_rename_execution_date_to_logical_date_and_nullable.py @@ -17,9 +17,9 @@ # under the License. """ -Drop ``execution_date`` unique constraint on DagRun. +Make logical_date nullable. 
-The column has also been renamed to logical_date, although the Python model is +The column has been renamed to logical_date, although the Python model is not changed. This allows us to not need to fix all the Python code at once, but still do the two changes in one migration instead of two. @@ -49,10 +49,15 @@ def upgrade(): "execution_date", new_column_name="logical_date", existing_type=TIMESTAMP(timezone=True), - existing_nullable=False, + nullable=True, ) + with op.batch_alter_table("dag_run", schema=None) as batch_op: batch_op.drop_constraint("dag_run_dag_id_execution_date_key", type_="unique") + batch_op.create_unique_constraint( + "dag_run_dag_id_logical_date_key", + columns=["dag_id", "logical_date"], + ) def downgrade(): @@ -61,9 +66,11 @@ def downgrade(): "logical_date", new_column_name="execution_date", existing_type=TIMESTAMP(timezone=True), - existing_nullable=False, + nullable=False, ) + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.drop_constraint("dag_run_dag_id_logical_date_key", type_="unique") batch_op.create_unique_constraint( "dag_run_dag_id_execution_date_key", columns=["dag_id", "execution_date"], diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index b8fb54f6966fd..98fd977c59128 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -27,7 +27,6 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.models.expandinput import NotFullyPopulated from airflow.sdk.definitions._internal.abstractoperator import ( AbstractOperator as TaskSDKAbstractOperator, NotMapped as NotMapped, # Re-export this for compat @@ -237,6 +236,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence ) from airflow.models.baseoperator import BaseOperator as DBBaseOperator + from airflow.models.expandinput import NotFullyPopulated try: total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 425cbaca68880..3361fe33df6f7 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -848,11 +848,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> i @get_mapped_ti_count.register(MappedOperator) @classmethod def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int: - from airflow.serialization.serialized_objects import _ExpandInputRef + from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef exp_input = task._get_specified_expand_input() if isinstance(exp_input, _ExpandInputRef): exp_input = exp_input.deref(task.dag) + # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the + # task sdk runner. + if not hasattr(exp_input, "get_total_map_length"): + exp_input = _ExpandInputRef( + type(exp_input).EXPAND_INPUT_TYPE, + BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)), + ) + exp_input = exp_input.deref(task.dag) + current_count = exp_input.get_total_map_length(run_id, session=session) group = task.get_closest_mapped_task_group() @@ -878,18 +887,24 @@ def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int: :raise NotFullyPopulated: If upstream tasks are not all complete yet. :return: Total number of mapped TIs this task should have. 
""" + from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef - def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]: + def iter_mapped_task_group_lengths(group) -> Iterator[int]: while group is not None: if isinstance(group, MappedTaskGroup): - yield group + exp_input = group._expand_input + # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the + # task sdk runner. + if not hasattr(exp_input, "get_total_map_length"): + exp_input = _ExpandInputRef( + type(exp_input).EXPAND_INPUT_TYPE, + BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)), + ) + exp_input = exp_input.deref(group.dag) + yield exp_input.get_total_map_length(run_id, session=session) group = group.parent_group - groups = iter_mapped_task_groups(group) - return functools.reduce( - operator.mul, - (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), - ) + return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group)) def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: diff --git a/airflow/models/connection.py b/airflow/models/connection.py index 01df1626657da..a8b9bb87985d8 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -32,10 +32,10 @@ from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet +from airflow.sdk.execution_time.secrets_masker import mask_secret from airflow.secrets.cache import SecretCache from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.log.secrets_masker import mask_secret from airflow.utils.module_loading import import_string log = logging.getLogger(__name__) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 4c54f3dba54b6..e571c016e1914 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1756,7 +1756,7 @@ def create_dagrun( self, *, run_id: str, - logical_date: datetime, + logical_date: datetime | None, data_interval: tuple[datetime, datetime], run_after: datetime, conf: dict | None = None, diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 7d46c29285fe0..d9bf14c3f983b 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -63,7 +63,6 @@ from airflow.models.backfill import Backfill from airflow.models.base import Base, StringID from airflow.models.dag_version import DagVersion -from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskinstance import TaskInstance as TI from airflow.models.tasklog import LogTemplate from airflow.models.taskmap import TaskMap @@ -135,7 +134,7 @@ class DagRun(Base, LoggingMixin): id = Column(Integer, primary_key=True) dag_id = Column(StringID(), nullable=False) queued_at = Column(UtcDateTime) - logical_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + logical_date = Column(UtcDateTime, nullable=True) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) _state = Column("state", String(50), default=DagRunState.QUEUED) @@ -186,6 +185,7 @@ class DagRun(Base, LoggingMixin): __table_args__ = ( Index("dag_id_state", dag_id, _state), UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"), + UniqueConstraint("dag_id", "logical_date", name="dag_run_dag_id_logical_date_key"), Index("idx_dag_run_dag_id", dag_id), Index("idx_dag_run_run_after", run_after), Index( @@ -792,6 +792,8 @@ 
def get_previous_dagrun( :param session: SQLAlchemy ORM Session :param state: the dag run state """ + if dag_run.logical_date is None: + return None filters = [ DagRun.dag_id == dag_run.dag_id, DagRun.logical_date < dag_run.logical_date, @@ -1321,8 +1323,12 @@ def verify_integrity(self, *, session: Session = NEW_SESSION) -> None: def task_filter(task: Operator) -> bool: return task.task_id not in task_ids and ( self.run_type == DagRunType.BACKFILL_JOB - or (task.start_date is None or task.start_date <= self.logical_date) - and (task.end_date is None or self.logical_date <= task.end_date) + or ( + task.start_date is None + or self.logical_date is None + or task.start_date <= self.logical_date + ) + and (task.end_date is None or self.logical_date is None or self.logical_date <= task.end_date) ) created_counts: dict[str, int] = defaultdict(int) @@ -1347,6 +1353,7 @@ def _check_for_removed_or_restored_tasks( """ from airflow.models.baseoperator import BaseOperator + from airflow.models.expandinput import NotFullyPopulated tis = self.get_task_instances(session=session) @@ -1484,6 +1491,7 @@ def _create_tasks( :param task_creator: Function to create task instances """ from airflow.models.baseoperator import BaseOperator + from airflow.models.expandinput import NotFullyPopulated map_indexes: Iterable[int] for task in tasks: @@ -1555,6 +1563,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> for more details. """ from airflow.models.baseoperator import BaseOperator + from airflow.models.expandinput import NotFullyPopulated from airflow.settings import task_instance_mutation_hook try: diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 8fb35f7032965..72f7b0eca22ba 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -17,108 +17,54 @@ # under the License. from __future__ import annotations -import collections.abc import functools import operator -from collections.abc import Iterable, Mapping, Sequence, Sized -from typing import TYPE_CHECKING, Any, NamedTuple, Union +from collections.abc import Iterable, Sized +from typing import TYPE_CHECKING, Any -import attr - -from airflow.sdk.definitions._internal.mixins import ResolveMixin -from airflow.utils.session import NEW_SESSION, provide_session +import attrs if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.xcom_arg import XComArg - from airflow.sdk.types import Operator - from airflow.serialization.serialized_objects import _ExpandInputRef + from airflow.models.xcom_arg import SchedulerXComArg from airflow.typing_compat import TypeGuard -ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] - -# Each keyword argument to expand() can be an XComArg, sequence, or dict (not -# any mapping since we need the value to be ordered). -OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] - -# The single argument of expand_kwargs() can be an XComArg, or a list with each -# element being either an XComArg or a dict. -OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] - - -@attr.define(kw_only=True) -class MappedArgument(ResolveMixin): - """ - Stand-in stub for task-group-mapping arguments. - - This is very similar to an XComArg, but resolved differently. Declared here - (instead of in the task group module) to avoid import cycles. 
- """ - - _input: ExpandInput - _key: str - - def iter_references(self) -> Iterable[tuple[Operator, str]]: - yield from self._input.iter_references() - - @provide_session - def resolve( - self, context: Mapping[str, Any], *, include_xcom: bool = True, session: Session = NEW_SESSION - ) -> Any: - data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom) - return data[self._key] - - -# To replace tedious isinstance() checks. -def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: - from airflow.models.xcom_arg import XComArg - - return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) - - -# To replace tedious isinstance() checks. -def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: - from airflow.models.xcom_arg import XComArg - - return not isinstance(v, (MappedArgument, XComArg)) - - -# To replace tedious isinstance() checks. -def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: - from airflow.models.xcom_arg import XComArg - - return isinstance(v, (MappedArgument, XComArg)) - - -class NotFullyPopulated(RuntimeError): - """ - Raise when ``get_map_lengths`` cannot populate all mapping metadata. - - This is generally due to not all upstream tasks have finished when the - function is called. - """ +from airflow.sdk.definitions._internal.expandinput import ( + DictOfListsExpandInput, + ExpandInput, + ListOfDictsExpandInput, + MappedArgument, + NotFullyPopulated, + OperatorExpandArgument, + OperatorExpandKwargsArgument, + is_mappable, +) - def __init__(self, missing: set[str]) -> None: - self.missing = missing +__all__ = [ + "DictOfListsExpandInput", + "ListOfDictsExpandInput", + "MappedArgument", + "NotFullyPopulated", + "OperatorExpandArgument", + "OperatorExpandKwargsArgument", + "is_mappable", +] - def __str__(self) -> str: - keys = ", ".join(repr(k) for k in sorted(self.missing)) - return f"Failed to populate all mapping metadata; missing: {keys}" +def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | SchedulerXComArg]: + from airflow.models.xcom_arg import SchedulerXComArg -class DictOfListsExpandInput(NamedTuple): - """ - Storage type of a mapped operator's mapped kwargs. + return isinstance(v, (MappedArgument, SchedulerXComArg)) - This is created from ``expand(**kwargs)``. - """ - value: dict[str, OperatorExpandArgument] +@attrs.define +class SchedulerDictOfListsExpandInput: + value: dict def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: """Generate kwargs with values available on parse-time.""" - return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v)) + return ((k, v) for k, v in self.value.items() if not _needs_run_time_resolution(v)) def get_parse_time_mapped_ti_count(self) -> int: if not self.value: @@ -136,13 +82,12 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]: If any arguments are not known right now (upstream task not finished), they will not be present in the dict. """ + from airflow.models.xcom_arg import SchedulerXComArg, get_task_map_length # TODO: This initiates one database call for each XComArg. Would it be # more efficient to do one single db call and unpack the value here? 
def _get_length(v: OperatorExpandArgument) -> int | None: - from airflow.models.xcom_arg import get_task_map_length - - if _needs_run_time_resolution(v): + if isinstance(v, SchedulerXComArg): return get_task_map_length(v, run_id, session=session) # Unfortunately a user-defined TypeGuard cannot apply negative type @@ -164,150 +109,34 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: lengths = self._get_map_lengths(run_id, session=session) return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1) - def _expand_mapped_field( - self, key: str, value: Any, context: Mapping[str, Any], *, session: Session, include_xcom: bool - ) -> Any: - if _needs_run_time_resolution(value): - value = ( - value.resolve(context, session=session, include_xcom=include_xcom) - if include_xcom - else str(value) - ) - map_index = context["ti"].map_index - if map_index < 0: - raise RuntimeError("can't resolve task-mapping argument without expanding") - all_lengths = self._get_map_lengths(context["run_id"], session=session) - - def _find_index_for_this_field(index: int) -> int: - # Need to use the original user input to retain argument order. - for mapped_key in reversed(self.value): - mapped_length = all_lengths[mapped_key] - if mapped_length < 1: - raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}") - if mapped_key == key: - return index % mapped_length - index //= mapped_length - return -1 - - found_index = _find_index_for_this_field(map_index) - if found_index < 0: - return value - if isinstance(value, collections.abc.Sequence): - return value[found_index] - if not isinstance(value, dict): - raise TypeError(f"can't map over value of type {type(value)}") - for i, (k, v) in enumerate(value.items()): - if i == found_index: - return k, v - raise IndexError(f"index {map_index} is over mapped length") - - def iter_references(self) -> Iterable[tuple[Operator, str]]: - from airflow.models.xcom_arg import XComArg - - for x in self.value.values(): - if isinstance(x, XComArg): - yield from x.iter_references() - - def resolve( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True - ) -> tuple[Mapping[str, Any], set[int]]: - data = { - k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom) - for k, v in self.value.items() - } - literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()} - resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys} - return data, resolved_oids - - -def _describe_type(value: Any) -> str: - if value is None: - return "None" - return type(value).__name__ - - -class ListOfDictsExpandInput(NamedTuple): - """ - Storage type of a mapped operator's mapped kwargs. - - This is created from ``expand_kwargs(xcom_arg)``. 
- """ - value: OperatorExpandKwargsArgument +@attrs.define +class SchedulerListOfDictsExpandInput: + value: list def get_parse_time_mapped_ti_count(self) -> int: - if isinstance(self.value, collections.abc.Sized): + if isinstance(self.value, Sized): return len(self.value) raise NotFullyPopulated({"expand_kwargs() argument"}) def get_total_map_length(self, run_id: str, *, session: Session) -> int: from airflow.models.xcom_arg import get_task_map_length - if isinstance(self.value, collections.abc.Sized): + if isinstance(self.value, Sized): return len(self.value) length = get_task_map_length(self.value, run_id, session=session) if length is None: raise NotFullyPopulated({"expand_kwargs() argument"}) return length - def iter_references(self) -> Iterable[tuple[Operator, str]]: - from airflow.models.xcom_arg import XComArg - - if isinstance(self.value, XComArg): - yield from self.value.iter_references() - else: - for x in self.value: - if isinstance(x, XComArg): - yield from x.iter_references() - - def resolve( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True - ) -> tuple[Mapping[str, Any], set[int]]: - map_index = context["ti"].map_index - if map_index < 0: - raise RuntimeError("can't resolve task-mapping argument without expanding") - - mapping: Any - if isinstance(self.value, collections.abc.Sized): - mapping = self.value[map_index] - if not isinstance(mapping, collections.abc.Mapping): - mapping = mapping.resolve(context, session, include_xcom=include_xcom) - elif include_xcom: - mappings = self.value.resolve(context, session, include_xcom=include_xcom) - if not isinstance(mappings, collections.abc.Sequence): - raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") - mapping = mappings[map_index] - - if not isinstance(mapping, collections.abc.Mapping): - raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]") - - for key in mapping: - if not isinstance(key, str): - raise ValueError( - f"expand_kwargs() input dict keys must all be str, " - f"but {key!r} is of type {_describe_type(key)}" - ) - # filter out parse time resolved values from the resolved_oids - resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)} - - return mapping, resolved_oids - EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value. _EXPAND_INPUT_TYPES = { - "dict-of-lists": DictOfListsExpandInput, - "list-of-dicts": ListOfDictsExpandInput, + "dict-of-lists": SchedulerDictOfListsExpandInput, + "list-of-dicts": SchedulerListOfDictsExpandInput, } -def get_map_type_key(expand_input: ExpandInput | _ExpandInputRef) -> str: - from airflow.serialization.serialized_objects import _ExpandInputRef - - if isinstance(expand_input, _ExpandInputRef): - return expand_input.key - return next(k for k, v in _EXPAND_INPUT_TYPES.items() if isinstance(expand_input, v)) - - def create_expand_input(kind: str, value: Any) -> ExpandInput: return _EXPAND_INPUT_TYPES[kind](value) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index e7352ad1323d3..9b0b90b5814f5 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -17,10 +17,7 @@ # under the License. 
from __future__ import annotations -import contextlib -import copy -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import attrs @@ -51,49 +48,6 @@ class MappedOperator(TaskSDKMappedOperator, AbstractOperator): # type: ignore[misc] # It complains about weight_rule being different """Object representing a mapped operator in a DAG.""" - def _expand_mapped_kwargs( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool - ) -> tuple[Mapping[str, Any], set[int]]: - """ - Get the kwargs to create the unmapped operator. - - This exists because taskflow operators expand against op_kwargs, not the - entire operator kwargs dict. - """ - return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom) - - def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: - """ - Get init kwargs to unmap the underlying operator class. - - :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``. - """ - if strict: - prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) - - # If params appears in the mapped kwargs, we need to merge it into the - # partial params, overriding existing keys. - params = copy.copy(self.params) - with contextlib.suppress(KeyError): - params.update(mapped_kwargs["params"]) - - # Ordering is significant; mapped kwargs should override partial ones, - # and the specially handled params should be respected. - return { - "task_id": self.task_id, - "dag": self.dag, - "task_group": self.task_group, - "start_date": self.start_date, - "end_date": self.end_date, - **self.partial_kwargs, - **mapped_kwargs, - "params": params, - } - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: """ Get the start_from_trigger value of the current abstract operator. 
@@ -107,7 +61,7 @@ def expand_start_from_trigger(self, *, context: Context, session: Session) -> bo if not self.start_trigger_args: return False - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) + mapped_kwargs, _ = self._expand_mapped_kwargs(context) if self._disallow_kwargs_override: prevent_duplicates( self.partial_kwargs, @@ -129,7 +83,7 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St if not self.start_trigger_args: return None - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) + mapped_kwargs, _ = self._expand_mapped_kwargs(context) if self._disallow_kwargs_override: prevent_duplicates( self.partial_kwargs, diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py index f2d7d83920fff..ea1c3af2b65b7 100644 --- a/airflow/models/renderedtifields.py +++ b/airflow/models/renderedtifields.py @@ -145,7 +145,7 @@ def __repr__(self): return prefix + ">" def _redact(self): - from airflow.utils.log.secrets_masker import redact + from airflow.sdk.execution_time.secrets_masker import redact if self.k8s_pod_yaml: self.k8s_pod_yaml = redact(self.k8s_pod_yaml) diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 5da2f24957ff0..2787ce1b82993 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -291,7 +291,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA @property def data(self) -> dict | None: # use __data_cache to avoid decompress and loads - if not hasattr(self, "__data_cache") or self.__data_cache is None: + if not hasattr(self, "_SerializedDagModel__data_cache") or self.__data_cache is None: if self._data_compressed: self.__data_cache = json.loads(zlib.decompress(self._data_compressed)) else: diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 3b7d21d7b386b..78e1a631e080d 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -97,7 +97,7 @@ def skip( dag_id: str, run_id: str, tasks: Iterable[DAGNode], - map_index: int = -1, + map_index: int | None = -1, session: Session = NEW_SESSION, ): """ @@ -126,6 +126,9 @@ def skip( if task_id is not None: from airflow.models.xcom import XCom + if map_index is None: + map_index = -1 + XCom.set( key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list}, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 7106366b67ef8..ceb0acfcf9946 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -573,39 +573,6 @@ def _xcom_pull( default: Any = None, run_id: str | None = None, ) -> Any: - """ - Pull XComs that optionally meet certain criteria. - - :param key: A key for the XCom. If provided, only XComs with matching - keys will be returned. The default key is ``'return_value'``, also - available as constant ``XCOM_RETURN_KEY``. This key is automatically - given to XComs returned by tasks (as opposed to being pushed - manually). To remove the filter, pass *None*. - :param task_ids: Only XComs from tasks with matching ids will be - pulled. Pass *None* to remove the filter. - :param dag_id: If provided, only pulls XComs from this DAG. If *None* - (default), the DAG of the calling task is used. - :param map_indexes: If provided, only pull XComs with matching indexes. - If *None* (default), this is inferred from the task(s) being pulled - (see below for details). 
- :param include_prior_dates: If False, only XComs from the current - logical_date are returned. If *True*, XComs from previous dates - are returned as well. - :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. - If *None* (default), the run_id of the calling task is used. - - When pulling one single task (``task_id`` is *None* or a str) without - specifying ``map_indexes``, the return value is inferred from whether - the specified task is mapped. If not, value from the one single task - instance is returned. If the task to pull is mapped, an iterator (not a - list) yielding XComs from mapped task instances is returned. In either - case, ``default`` (*None* if not specified) is returned if no matching - XComs are found. - - When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is - a non-str iterable), a list of matching XComs is returned. Elements in - the list is ordered by item ordering in ``task_id`` and ``map_index``. - """ if dag_id is None: dag_id = ti.dag_id if run_id is None: @@ -634,11 +601,10 @@ def _xcom_pull( return default if map_indexes is not None or first.map_index < 0: return XCom.deserialize_value(first) - return LazyXComSelectSequence.from_select( - query.with_entities(XCom.value).order_by(None).statement, - order_by=[XCom.map_index], - session=session, - ) + + # raise RuntimeError("Nothing should hit this anymore") + + # TODO: TaskSDK: We should remove this, but many tests still currently call `ti.run()`. See #45549 # At this point either task_ids or map_indexes is explicitly multi-value. # Order return values to match task_ids and map_indexes ordering. @@ -1035,14 +1001,18 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: ) context["expanded_ti_count"] = expanded_ti_count if expanded_ti_count: - context["_upstream_map_indexes"] = { # type: ignore[typeddict-unknown-key] - upstream.task_id: task_instance.get_relevant_upstream_map_indexes( - upstream, - expanded_ti_count, - session=session, - ) - for upstream in task.upstream_list - } + setattr( + task_instance, + "_upstream_map_indexes", + { + upstream.task_id: task_instance.get_relevant_upstream_map_indexes( + upstream, + expanded_ti_count, + session=session, + ) + for upstream in task.upstream_list + }, + ) except NotMapped: pass @@ -3267,7 +3237,7 @@ def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None: try: # If we get here, either the task hasn't run or the RTIF record was purged. - from airflow.utils.log.secrets_masker import redact + from airflow.sdk.execution_time.secrets_masker import redact self.render_templates() for field_name in self.task.template_fields: @@ -3380,39 +3350,8 @@ def xcom_pull( default: Any = None, run_id: str | None = None, ) -> Any: - """ - Pull XComs that optionally meet certain criteria. - - :param key: A key for the XCom. If provided, only XComs with matching - keys will be returned. The default key is ``'return_value'``, also - available as constant ``XCOM_RETURN_KEY``. This key is automatically - given to XComs returned by tasks (as opposed to being pushed - manually). To remove the filter, pass *None*. - :param task_ids: Only XComs from tasks with matching ids will be - pulled. Pass *None* to remove the filter. - :param dag_id: If provided, only pulls XComs from this DAG. If *None* - (default), the DAG of the calling task is used. - :param map_indexes: If provided, only pull XComs with matching indexes. - If *None* (default), this is inferred from the task(s) being pulled - (see below for details). 
- :param include_prior_dates: If False, only XComs from the current - logical_date are returned. If *True*, XComs from previous dates - are returned as well. - :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. - If *None* (default), the run_id of the calling task is used. - - When pulling one single task (``task_id`` is *None* or a str) without - specifying ``map_indexes``, the return value is inferred from whether - the specified task is mapped. If not, value from the one single task - instance is returned. If the task to pull is mapped, an iterator (not a - list) yielding XComs from mapped task instances is returned. In either - case, ``default`` (*None* if not specified) is returned if no matching - XComs are found. - - When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is - a non-str iterable), a list of matching XComs is returned. Elements in - the list is ordered by item ordering in ``task_id`` and ``map_index``. - """ + """:meta private:""" # noqa: D400 + # This is only kept for compatibility in tests for now while AIP-72 is in progress. return _xcom_pull( ti=self, task_ids=task_ids, diff --git a/airflow/models/taskinstancehistory.py b/airflow/models/taskinstancehistory.py index e97e6de22ec9a..d99cd34f3b88c 100644 --- a/airflow/models/taskinstancehistory.py +++ b/airflow/models/taskinstancehistory.py @@ -33,6 +33,7 @@ text, ) from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import relationship from sqlalchemy_utils import UUIDType from airflow.models.base import Base, StringID @@ -94,6 +95,13 @@ class TaskInstanceHistory(Base): task_display_name = Column("task_display_name", String(2000), nullable=True) dag_version_id = Column(UUIDType(binary=False)) + dag_version = relationship( + "DagVersion", + primaryjoin="TaskInstanceHistory.dag_version_id == DagVersion.id", + viewonly=True, + foreign_keys=[dag_version_id], + ) + def __init__( self, ti: TaskInstance, diff --git a/airflow/models/variable.py b/airflow/models/variable.py index b4cf5560e8421..b4568ad09c489 100644 --- a/airflow/models/variable.py +++ b/airflow/models/variable.py @@ -28,10 +28,10 @@ from airflow.configuration import ensure_secrets_loaded from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet +from airflow.sdk.execution_time.secrets_masker import mask_secret from airflow.secrets.cache import SecretCache from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.log.secrets_masker import mask_secret from airflow.utils.session import provide_session if TYPE_CHECKING: diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 078a9e6ff5223..1d885fb5bd1b0 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -17,46 +17,108 @@ from __future__ import annotations +from collections.abc import Sequence from functools import singledispatch -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +import attrs from sqlalchemy import func, or_, select from sqlalchemy.orm import Session from airflow.sdk.definitions._internal.types import ArgNotSet from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.xcom_arg import ( - ConcatXComArg, - MapXComArg, - PlainXComArg, XComArg, - ZipXComArg, ) from airflow.utils.db import exists_query from airflow.utils.state import State +from airflow.utils.types import NOTSET from airflow.utils.xcom import XCOM_RETURN_KEY __all__ = 
["XComArg", "get_task_map_length"] if TYPE_CHECKING: - from airflow.models.expandinput import OperatorExpandArgument + from airflow.models.dag import DAG as SchedulerDAG + from airflow.models.operator import Operator + from airflow.typing_compat import Self + + +@attrs.define +class SchedulerXComArg: + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + """ + Deserialize an XComArg. + + The implementation should be the inverse function to ``serialize``, + implementing given a data dict converted from this XComArg derivative, + how the original XComArg should be created. DAG serialization relies on + additional information added in ``serialize_xcom_arg`` to dispatch data + dicts to the correct ``_deserialize`` information, so this function does + not need to validate whether the incoming data contains correct keys. + """ + raise NotImplementedError() + + +@attrs.define +class SchedulerPlainXComArg(SchedulerXComArg): + operator: Operator + key: str + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + return cls(dag.get_task(data["task_id"]), data["key"]) + + +@attrs.define +class SchedulerMapXComArg(SchedulerXComArg): + arg: SchedulerXComArg + callables: Sequence[str] + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + # We are deliberately NOT deserializing the callables. These are shown + # in the UI, and displaying a function object is useless. + return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + + +@attrs.define +class SchedulerConcatXComArg(SchedulerXComArg): + args: Sequence[SchedulerXComArg] + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) + + +@attrs.define +class SchedulerZipXComArg(SchedulerXComArg): + args: Sequence[SchedulerXComArg] + fillvalue: Any + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + return cls( + [deserialize_xcom_arg(arg, dag) for arg in data["args"]], + fillvalue=data.get("fillvalue", NOTSET), + ) @singledispatch -def get_task_map_length(xcom_arg: OperatorExpandArgument, run_id: str, *, session: Session) -> int | None: +def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Session) -> int | None: # The base implementation -- specific XComArg subclasses have specialised implementations - raise NotImplementedError() + raise NotImplementedError(f"get_task_map_length not implemented for {type(xcom_arg)}") @get_task_map_length.register -def _(xcom_arg: PlainXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session): from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom import XCom dag_id = xcom_arg.operator.dag_id task_id = xcom_arg.operator.task_id - is_mapped = isinstance(xcom_arg.operator, MappedOperator) + is_mapped = xcom_arg.operator.is_mapped or isinstance(xcom_arg.operator, MappedOperator) if is_mapped: unfinished_ti_exists = exists_query( @@ -92,12 +154,12 @@ def _(xcom_arg: PlainXComArg, run_id: str, *, session: Session): @get_task_map_length.register -def _(xcom_arg: MapXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session): return get_task_map_length(xcom_arg.arg, run_id, session=session) @get_task_map_length.register -def _(xcom_arg: ZipXComArg, run_id: str, 
*, session: Session): +def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session): all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(xcom_arg.args): @@ -108,9 +170,23 @@ def _(xcom_arg: ZipXComArg, run_id: str, *, session: Session): @get_task_map_length.register -def _(xcom_arg: ConcatXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session): all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(xcom_arg.args): return None # If any of the referenced XComs is not ready, we are not ready either. return sum(ready_lengths) + + +def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG): + """DAG serialization interface.""" + klass = _XCOM_ARG_TYPES[data.get("type", "")] + return klass._deserialize(data, dag) + + +_XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = { + "": SchedulerPlainXComArg, + "concat": SchedulerConcatXComArg, + "map": SchedulerMapXComArg, + "zip": SchedulerZipXComArg, +} diff --git a/airflow/new_provider.yaml.schema.json b/airflow/new_provider.yaml.schema.json index 83f98abb038cb..4b5e16cedc0eb 100644 --- a/airflow/new_provider.yaml.schema.json +++ b/airflow/new_provider.yaml.schema.json @@ -31,6 +31,13 @@ "removed" ] }, + "excluded-python-versions": { + "description": "List of python versions excluded for that provider", + "type": "array", + "items": { + "type": "string" + } + }, "integrations": { "description": "List of integrations supported by the provider.", "type": "array", diff --git a/airflow/serialization/helpers.py b/airflow/serialization/helpers.py index 2e0e8cba41102..6cea9e1341536 100644 --- a/airflow/serialization/helpers.py +++ b/airflow/serialization/helpers.py @@ -21,8 +21,8 @@ from typing import Any from airflow.configuration import conf +from airflow.sdk.execution_time.secrets_masker import redact from airflow.settings import json -from airflow.utils.log.secrets_masker import redact def serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float: diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 89ea668f02ffe..08d032e873c6a 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -45,10 +45,10 @@ from airflow.models.expandinput import ( EXPAND_INPUT_EMPTY, create_expand_input, - get_map_type_key, ) from airflow.models.taskinstance import SimpleTaskInstance from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.providers_manager import ProvidersManager from airflow.sdk.definitions.asset import ( Asset, @@ -66,7 +66,7 @@ from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import Param, ParamsDict from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup -from airflow.sdk.definitions.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg +from airflow.sdk.definitions.xcom_arg import XComArg, serialize_xcom_arg from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums 
import DagAttributeTypes as DAT, Encoding @@ -493,7 +493,7 @@ class _XComRef(NamedTuple): data: dict - def deref(self, dag: DAG) -> XComArg: + def deref(self, dag: DAG) -> SchedulerXComArg: return deserialize_xcom_arg(self.data, dag) @@ -1195,7 +1195,7 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]: if TYPE_CHECKING: # Let Mypy check the input type for us! _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value) serialized_op[op._expand_input_attr] = { - "type": get_map_type_key(expansion_kwargs), + "type": type(expansion_kwargs).EXPAND_INPUT_TYPE, "value": cls.serialize(expansion_kwargs.value), } @@ -1792,7 +1792,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None: if isinstance(task_group, MappedTaskGroup): expand_input = task_group._expand_input encoded["expand_input"] = { - "type": get_map_type_key(expand_input), + "type": expand_input.EXPAND_INPUT_TYPE, "value": cls.serialize(expand_input.value), } encoded["is_mapped"] = True diff --git a/airflow/settings.py b/airflow/settings.py index aae6529c8b129..d8e796db7ce7e 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -328,7 +328,7 @@ def _is_sqlite_db_path_relative(sqla_conn_str: str) -> bool: def configure_orm(disable_connection_pool=False, pool_class=None): """Configure ORM using SQLAlchemy.""" - from airflow.utils.log.secrets_masker import mask_secret + from airflow.sdk.execution_time.secrets_masker import mask_secret if _is_sqlite_db_path_relative(SQL_ALCHEMY_CONN): from airflow.exceptions import AirflowConfigException diff --git a/airflow/timetables/trigger.py b/airflow/timetables/trigger.py index 4488a7fdaf619..bea972f4fb81f 100644 --- a/airflow/timetables/trigger.py +++ b/airflow/timetables/trigger.py @@ -17,6 +17,8 @@ from __future__ import annotations import datetime +import math +import operator from typing import TYPE_CHECKING, Any from airflow.timetables._cron import CronMixin @@ -31,6 +33,34 @@ from airflow.timetables.base import TimeRestriction +def _serialize_interval(interval: datetime.timedelta | relativedelta) -> float | dict: + from airflow.serialization.serialized_objects import encode_relativedelta + + if isinstance(interval, datetime.timedelta): + return interval.total_seconds() + return encode_relativedelta(interval) + + +def _deserialize_interval(value: int | dict) -> datetime.timedelta | relativedelta: + from airflow.serialization.serialized_objects import decode_relativedelta + + if isinstance(value, dict): + return decode_relativedelta(value) + return datetime.timedelta(seconds=value) + + +def _serialize_run_immediately(value: bool | datetime.timedelta) -> bool | float: + if isinstance(value, datetime.timedelta): + return value.total_seconds() + return value + + +def _deserialize_run_immediately(value: bool | float) -> bool | datetime.timedelta: + if isinstance(value, float): + return datetime.timedelta(seconds=value) + return value + + class CronTriggerTimetable(CronMixin, Timetable): """ Timetable that triggers DAG runs according to a cron expression. 
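
The four module-level helpers added above centralize how ``interval`` and ``run_immediately`` are encoded: a plain timedelta becomes its total seconds, a relativedelta goes through ``encode_relativedelta``/``decode_relativedelta``, and for ``run_immediately`` a float decodes back to a timedelta while booleans pass through. A minimal standalone sketch of the timedelta/bool branches only (the relativedelta branch needs the serialized_objects codecs and is omitted here):

from __future__ import annotations

import datetime

# Standalone mirror of the timedelta/bool branches of the new helpers; the
# relativedelta branch (encode_relativedelta/decode_relativedelta) is omitted.
def serialize_interval(interval: datetime.timedelta) -> float:
    return interval.total_seconds()

def deserialize_interval(value: float) -> datetime.timedelta:
    return datetime.timedelta(seconds=value)

def serialize_run_immediately(value: bool | datetime.timedelta) -> bool | float:
    return value.total_seconds() if isinstance(value, datetime.timedelta) else value

def deserialize_run_immediately(value: bool | float) -> bool | datetime.timedelta:
    # Floats decode back to a timedelta; booleans pass through unchanged.
    return datetime.timedelta(seconds=value) if isinstance(value, float) else value

assert deserialize_interval(serialize_interval(datetime.timedelta(hours=1))) == datetime.timedelta(hours=1)
assert deserialize_run_immediately(serialize_run_immediately(True)) is True
assert deserialize_run_immediately(300.0) == datetime.timedelta(minutes=5)
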
@@ -77,48 +107,23 @@ def __init__( @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: - from airflow.serialization.serialized_objects import decode_relativedelta, decode_timezone - - interval: datetime.timedelta | relativedelta - if isinstance(data["interval"], dict): - interval = decode_relativedelta(data["interval"]) - else: - interval = datetime.timedelta(seconds=data["interval"]) - - immediate: bool | datetime.timedelta - if "immediate" not in data: - immediate = False - elif isinstance(data["immediate"], float): - immediate = datetime.timedelta(seconds=data["interval"]) - else: - immediate = data["immediate"] + from airflow.serialization.serialized_objects import decode_timezone return cls( data["expression"], timezone=decode_timezone(data["timezone"]), - interval=interval, - run_immediately=immediate, + interval=_deserialize_interval(data["interval"]), + run_immediately=_deserialize_run_immediately(data.get("run_immediately", False)), ) def serialize(self) -> dict[str, Any]: - from airflow.serialization.serialized_objects import encode_relativedelta, encode_timezone + from airflow.serialization.serialized_objects import encode_timezone - interval: float | dict[str, Any] - if isinstance(self._interval, datetime.timedelta): - interval = self._interval.total_seconds() - else: - interval = encode_relativedelta(self._interval) - timezone = encode_timezone(self._timezone) - immediate: bool | float - if isinstance(self.run_immediately, datetime.timedelta): - immediate = self.run_immediately.total_seconds() - else: - immediate = self.run_immediately return { "expression": self._expression, - "timezone": timezone, - "interval": interval, - "run_immediately": immediate, + "timezone": encode_timezone(self._timezone), + "interval": _serialize_interval(self._interval), + "run_immediately": _serialize_run_immediately(self.run_immediately), } def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: @@ -184,3 +189,95 @@ def _calc_first_run(self): return past_run_time else: return next_run_time + + +class MultipleCronTriggerTimetable(Timetable): + """ + Timetable that triggers DAG runs according to multiple cron expressions. + + This combines multiple ``CronTriggerTimetable`` instances underneath, and + triggers a DAG run whenever one of the timetables want to trigger a run. + + Only at most one run is triggered for any given time, even if more than one + timetable fires at the same time. 
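
A minimal usage sketch for the new timetable, assuming an environment where this patch is applied; the keyword-only ``timezone`` and the optional ``interval``/``run_immediately`` arguments mirror ``CronTriggerTimetable``, as the ``__init__`` just below shows. Like any timetable, the instance would be handed to a DAG through its ``schedule`` argument.

# Illustrative only -- requires an Airflow build that includes this patch.
from airflow.timetables.trigger import MultipleCronTriggerTimetable

# Fire at 06:00 and 18:00 UTC every day; if two expressions ever coincide,
# only a single run is produced for that time.
timetable = MultipleCronTriggerTimetable("0 6 * * *", "0 18 * * *", timezone="UTC")
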
+ """ + + def __init__( + self, + *crons: str, + timezone: str | Timezone | FixedTimezone, + interval: datetime.timedelta | relativedelta = datetime.timedelta(), + run_immediately: bool | datetime.timedelta = False, + ) -> None: + if not crons: + raise ValueError("cron expression required") + self._timetables = [ + CronTriggerTimetable(cron, timezone=timezone, interval=interval, run_immediately=run_immediately) + for cron in crons + ] + self.description = ", ".join(t.description for t in self._timetables) + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> Timetable: + from airflow.serialization.serialized_objects import decode_timezone + + return cls( + data["expressions"], + timezone=decode_timezone(data["timezone"]), + interval=_deserialize_interval(data["interval"]), + run_immediately=_deserialize_run_immediately(data["run_immediately"]), + ) + + def serialize(self) -> dict[str, Any]: + from airflow.serialization.serialized_objects import encode_timezone + + # All timetables share the same timezone, interval, and run_immediately + # values, so we can just use the first to represent them. + timetable = self._timetables[0] + return { + "expressions": [t._expression for t in self._timetables], + "timezone": encode_timezone(timetable._timezone), + "interval": _serialize_interval(timetable._interval), + "run_immediately": _serialize_run_immediately(timetable.run_immediately), + } + + @property + def summary(self) -> str: + return ", ".join(t.summary for t in self._timetables) + + def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: + return min( + (t.infer_manual_data_interval(run_after=run_after) for t in self._timetables), + key=operator.attrgetter("start"), + ) + + def next_dagrun_info( + self, + *, + last_automated_data_interval: DataInterval | None, + restriction: TimeRestriction, + ) -> DagRunInfo | None: + infos = ( + timetable.next_dagrun_info( + last_automated_data_interval=last_automated_data_interval, + restriction=restriction, + ) + for timetable in self._timetables + ) + return min(infos, key=self._dagrun_info_sort_key) + + @staticmethod + def _dagrun_info_sort_key(info: DagRunInfo | None) -> float: + """ + Sort key for DagRunInfo values. + + This is passed as the sort key to ``min`` in ``next_dagrun_info`` to + find the next closest run, ordered by logical date. + + The sort is done by simply returning the logical date converted to a + Unix timestamp. If the input is *None* (no next run), *inf* is returned + so it's selected last. 
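
The selection rule described here is just ``min`` over optional values, with ``+inf`` standing in for "no next run". A standalone illustration of the same pattern, with plain datetimes in place of ``DagRunInfo.logical_date``:

from __future__ import annotations

import math
from datetime import datetime, timezone

# Plain datetimes stand in for DagRunInfo.logical_date; None means "no next run".
candidates = [
    None,
    datetime(2025, 1, 2, 6, 0, tzinfo=timezone.utc),
    datetime(2025, 1, 1, 18, 0, tzinfo=timezone.utc),
]

def sort_key(dt: datetime | None) -> float:
    # +inf pushes "no next run" past every concrete timestamp.
    return math.inf if dt is None else dt.timestamp()

assert min(candidates, key=sort_key) == datetime(2025, 1, 1, 18, 0, tzinfo=timezone.utc)
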
+ """ + if info is None: + return math.inf + return info.logical_date.timestamp() diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 7a6d7eb780b24..dcb8ec827d072 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1715,6 +1715,9 @@ export type TaskInstanceServicePatchTaskInstanceDryRun1MutationResult = Awaited< >; export type PoolServicePatchPoolMutationResult = Awaited>; export type PoolServiceBulkPoolsMutationResult = Awaited>; +export type XcomServiceUpdateXcomEntryMutationResult = Awaited< + ReturnType +>; export type VariableServicePatchVariableMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index d838c9ccf5e5d..5a71695254bfa 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -52,6 +52,7 @@ import { TriggerDAGRunPostBody, VariableBody, XComCreateBody, + XComUpdateBody, } from "../requests/types.gen"; import * as Common from "./common"; @@ -3982,6 +3983,61 @@ export const usePoolServiceBulkPools = < mutationFn: ({ requestBody }) => PoolService.bulkPools({ requestBody }) as unknown as Promise, ...options, }); +/** + * Update Xcom Entry + * Update an existing XCom entry. + * @param data The data for the request. + * @param data.dagId + * @param data.taskId + * @param data.dagRunId + * @param data.xcomKey + * @param data.requestBody + * @returns XComResponseNative Successful Response + * @throws ApiError + */ +export const useXcomServiceUpdateXcomEntry = < + TData = Common.XcomServiceUpdateXcomEntryMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagId: string; + dagRunId: string; + requestBody: XComUpdateBody; + taskId: string; + xcomKey: string; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagId: string; + dagRunId: string; + requestBody: XComUpdateBody; + taskId: string; + xcomKey: string; + }, + TContext + >({ + mutationFn: ({ dagId, dagRunId, requestBody, taskId, xcomKey }) => + XcomService.updateXcomEntry({ + dagId, + dagRunId, + requestBody, + taskId, + xcomKey, + }) as unknown as Promise, + ...options, + }); /** * Patch Variable * Update a variable by key. 
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index e5ae3ca34878d..1d688c3442d27 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -4746,6 +4746,16 @@ export const $TaskInstanceHistoryResponse = { type: "string", title: "Executor Config", }, + dag_version: { + anyOf: [ + { + $ref: "#/components/schemas/DagVersionResponse", + }, + { + type: "null", + }, + ], + }, }, type: "object", required: [ @@ -4772,6 +4782,7 @@ export const $TaskInstanceHistoryResponse = { "pid", "executor", "executor_config", + "dag_version", ], title: "TaskInstanceHistoryResponse", description: "TaskInstanceHistory serializer for responses.", @@ -6227,3 +6238,21 @@ export const $XComResponseString = { title: "XComResponseString", description: "XCom response serializer with string return type.", } as const; + +export const $XComUpdateBody = { + properties: { + value: { + title: "Value", + }, + map_index: { + type: "integer", + title: "Map Index", + default: -1, + }, + }, + additionalProperties: false, + type: "object", + required: ["value"], + title: "XComUpdateBody", + description: "Payload serializer for updating an XCom entry.", +} as const; diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 69ab77337a1f9..585796aa2b2db 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -175,6 +175,8 @@ import type { GetProvidersResponse, GetXcomEntryData, GetXcomEntryResponse, + UpdateXcomEntryData, + UpdateXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, @@ -2981,6 +2983,40 @@ export class XcomService { }); } + /** + * Update Xcom Entry + * Update an existing XCom entry. + * @param data The data for the request. + * @param data.dagId + * @param data.taskId + * @param data.dagRunId + * @param data.xcomKey + * @param data.requestBody + * @returns XComResponseNative Successful Response + * @throws ApiError + */ + public static updateXcomEntry(data: UpdateXcomEntryData): CancelablePromise { + return __request(OpenAPI, { + method: "PATCH", + url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}", + path: { + dag_id: data.dagId, + task_id: data.taskId, + dag_run_id: data.dagRunId, + xcom_key: data.xcomKey, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } + /** * Get Xcom Entries * Get all XCom entries. diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 1283a77fe5548..a63ec17ffadcf 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -1239,6 +1239,7 @@ export type TaskInstanceHistoryResponse = { pid: number | null; executor: string | null; executor_config: string; + dag_version: DagVersionResponse | null; }; /** @@ -1541,6 +1542,14 @@ export type XComResponseString = { value: string | null; }; +/** + * Payload serializer for updating an XCom entry. 
+ */ +export type XComUpdateBody = { + value: unknown; + map_index?: number; +}; + export type NextRunAssetsData = { dagId: string; }; @@ -2319,6 +2328,16 @@ export type GetXcomEntryData = { export type GetXcomEntryResponse = XComResponseNative | XComResponseString; +export type UpdateXcomEntryData = { + dagId: string; + dagRunId: string; + requestBody: XComUpdateBody; + taskId: string; + xcomKey: string; +}; + +export type UpdateXcomEntryResponse = XComResponseNative; + export type GetXcomEntriesData = { dagId: string; dagRunId: string; @@ -4720,6 +4739,35 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; + patch: { + req: UpdateXcomEntryData; + res: { + /** + * Successful Response + */ + 200: XComResponseNative; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; }; "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries": { get: { diff --git a/airflow/ui/package.json b/airflow/ui/package.json index 322623e9db68b..4a816fc06da45 100644 --- a/airflow/ui/package.json +++ b/airflow/ui/package.json @@ -82,7 +82,7 @@ "typescript-eslint": "^8.5.0", "vite": "^5.4.12", "vite-plugin-css-injected-by-js": "^3.5.2", - "vitest": "^2.1.1", + "vitest": "^2.1.9", "web-worker": "^1.3.0" } } diff --git a/airflow/ui/pnpm-lock.yaml b/airflow/ui/pnpm-lock.yaml index 20ae58594252c..20ac155c03cad 100644 --- a/airflow/ui/pnpm-lock.yaml +++ b/airflow/ui/pnpm-lock.yaml @@ -149,7 +149,7 @@ importers: version: 3.7.0(@swc/helpers@0.5.13)(vite@5.4.12(@types/node@22.5.4)) '@vitest/coverage-v8': specifier: ^2.1.1 - version: 2.1.1(vitest@2.1.1(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4))) + version: 2.1.1(vitest@2.1.9(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4))) eslint: specifier: ^9.10.0 version: 9.10.0(jiti@1.21.6) @@ -202,8 +202,8 @@ importers: specifier: ^3.5.2 version: 3.5.2(vite@5.4.12(@types/node@22.5.4)) vitest: - specifier: ^2.1.1 - version: 2.1.1(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4)) + specifier: ^2.1.9 + version: 2.1.9(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4)) web-worker: specifier: ^1.3.0 version: 1.3.0 @@ -1368,14 +1368,13 @@ packages: '@vitest/browser': optional: true - '@vitest/expect@2.1.1': - resolution: {integrity: sha512-YeueunS0HiHiQxk+KEOnq/QMzlUuOzbU1Go+PgAsHvvv3tUkJPm9xWt+6ITNTlzsMXUjmgm5T+U7KBPK2qQV6w==} + '@vitest/expect@2.1.9': + resolution: {integrity: sha512-UJCIkTBenHeKT1TTlKMJWy1laZewsRIzYighyYiJKZreqtdxSos/S1t+ktRMQWu2CKqaarrkeszJx1cgC5tGZw==} - '@vitest/mocker@2.1.1': - resolution: {integrity: sha512-LNN5VwOEdJqCmJ/2XJBywB11DLlkbY0ooDJW3uRX5cZyYCrc4PI/ePX0iQhE3BiEGiQmK4GE7Q/PqCkkaiPnrA==} + '@vitest/mocker@2.1.9': + resolution: {integrity: sha512-tVL6uJgoUdi6icpxmdrn5YNo3g3Dxv+IHJBr0GXHaEdTcw3F+cPKnsXFhli6nO+f/6SDKPHEK1UN+k+TQv0Ehg==} peerDependencies: - '@vitest/spy': 2.1.1 - msw: ^2.3.5 + msw: ^2.4.9 vite: ^5.0.0 peerDependenciesMeta: msw: @@ -1383,20 +1382,20 @@ packages: vite: optional: true - '@vitest/pretty-format@2.1.1': - resolution: {integrity: sha512-SjxPFOtuINDUW8/UkElJYQSFtnWX7tMksSGW0vfjxMneFqxVr8YJ979QpMbDW7g+BIiq88RAGDjf7en6rvLPPQ==} + '@vitest/pretty-format@2.1.9': + resolution: {integrity: 
sha512-KhRIdGV2U9HOUzxfiHmY8IFHTdqtOhIzCpd8WRdJiE7D/HUcZVD0EgQCVjm+Q9gkUXWgBvMmTtZgIG48wq7sOQ==} - '@vitest/runner@2.1.1': - resolution: {integrity: sha512-uTPuY6PWOYitIkLPidaY5L3t0JJITdGTSwBtwMjKzo5O6RCOEncz9PUN+0pDidX8kTHYjO0EwUIvhlGpnGpxmA==} + '@vitest/runner@2.1.9': + resolution: {integrity: sha512-ZXSSqTFIrzduD63btIfEyOmNcBmQvgOVsPNPe0jYtESiXkhd8u2erDLnMxmGrDCwHCCHE7hxwRDCT3pt0esT4g==} - '@vitest/snapshot@2.1.1': - resolution: {integrity: sha512-BnSku1WFy7r4mm96ha2FzN99AZJgpZOWrAhtQfoxjUU5YMRpq1zmHRq7a5K9/NjqonebO7iVDla+VvZS8BOWMw==} + '@vitest/snapshot@2.1.9': + resolution: {integrity: sha512-oBO82rEjsxLNJincVhLhaxxZdEtV0EFHMK5Kmx5sJ6H9L183dHECjiefOAdnqpIgT5eZwT04PoggUnW88vOBNQ==} - '@vitest/spy@2.1.1': - resolution: {integrity: sha512-ZM39BnZ9t/xZ/nF4UwRH5il0Sw93QnZXd9NAZGRpIgj0yvVwPpLd702s/Cx955rGaMlyBQkZJ2Ir7qyY48VZ+g==} + '@vitest/spy@2.1.9': + resolution: {integrity: sha512-E1B35FwzXXTs9FHNK6bDszs7mtydNi5MIfUWpceJ8Xbfb1gBMscAnwLbEu+B44ed6W3XjL9/ehLPHR1fkf1KLQ==} - '@vitest/utils@2.1.1': - resolution: {integrity: sha512-Y6Q9TsI+qJ2CC0ZKj6VBb+T8UPz593N113nnUykqwANqhgf3QkZeHFlusgKLTqrnVHbj/XDKZcDHol+dxVT+rQ==} + '@vitest/utils@2.1.9': + resolution: {integrity: sha512-v0psaMSkNJ3A2NMrUEHFRzJtDPFn+/VWZ5WxImB21T9fjucJRmS7xCS3ppEnARb9y11OAzaD+P2Ps+b+BGX5iQ==} '@xyflow/react@12.3.5': resolution: {integrity: sha512-wAYqpicdrVo1rxCu0X3M9s3YIF45Agqfabw0IBryTGqjWvr2NyfciI8gIP4MB+NKpWWN5kxZ9tiZ9u8lwC7iAg==} @@ -1798,8 +1797,8 @@ packages: ccount@2.0.1: resolution: {integrity: sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==} - chai@5.1.1: - resolution: {integrity: sha512-pT1ZgP8rPNqUgieVaEY+ryQr6Q4HXNg8Ei9UnLUrjN4IA7dvQC5JB+/kxVcPNDHyBcc/26CXPkbNzq3qwrOEKA==} + chai@5.1.2: + resolution: {integrity: sha512-aGtmf24DW6MLHHG5gCx4zaI3uBq3KRtxeVs0DjFH6Z0rDNbsvTxFASFvdj79pxjxZ8/5u3PIiN3IwEIQkiiuPw==} engines: {node: '>=12'} chakra-react-select@6.0.0-next.2: @@ -2068,6 +2067,15 @@ packages: supports-color: optional: true + debug@4.4.0: + resolution: {integrity: sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + decode-named-character-reference@1.0.2: resolution: {integrity: sha512-O8x12RzrUF8xyVcY0KJowWsmaJxQbmy0/EtnNtHRpsOcT7dFk5W598coHqBVpmWo1oQQfsCqfCmkZN5DJrZVdg==} @@ -2172,6 +2180,9 @@ packages: resolution: {integrity: sha512-zoMwbCcH5hwUkKJkT8kDIBZSz9I6mVG//+lDCinLCGov4+r7NIy0ld8o03M0cJxl2spVf6ESYVS6/gpIfq1FFw==} engines: {node: '>= 0.4'} + es-module-lexer@1.6.0: + resolution: {integrity: sha512-qqnD1yMU6tk/jnaMosogGySTZP8YtUgAffA9nMN+E/rjxcfRQ6IEk7IiozUjgxKoFHBGjTLnrHB/YC45r/59EQ==} + es-object-atoms@1.0.0: resolution: {integrity: sha512-MZ4iQ6JwHOBQjahnjwaC1ZtIBH+2ohjamzAO3oaHcXYup7qxjF2fixyH+Q71voWHeOkI2q/TnJao/KfXYIZWbw==} engines: {node: '>= 0.4'} @@ -2328,6 +2339,10 @@ packages: resolution: {integrity: sha512-VyhnebXciFV2DESc+p6B+y0LjSm0krU4OgJN44qFAhBY0TJ+1V61tYD2+wHusZ6F9n5K+vl8k0sTy7PEfV4qpg==} engines: {node: '>=16.17'} + expect-type@1.1.0: + resolution: {integrity: sha512-bFi65yM+xZgk+u/KRIpekdSYkTB5W1pEf0Lt8Q8Msh7b+eQ7LXVtIB1Bkm4fvclDEL1b2CZkMhv2mOeF8tMdkA==} + engines: {node: '>=12.0.0'} + extend@3.0.2: resolution: {integrity: sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g==} @@ -2447,9 +2462,6 @@ packages: resolution: {integrity: 
sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==} engines: {node: 6.* || 8.* || >= 10.*} - get-func-name@2.0.2: - resolution: {integrity: sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==} - get-intrinsic@1.2.4: resolution: {integrity: sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==} engines: {node: '>= 0.4'} @@ -2898,8 +2910,8 @@ packages: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true - loupe@3.1.1: - resolution: {integrity: sha512-edNu/8D5MKVfGVFRhFf8aAxiTM6Wumfz5XsaatSxlD3w4R1d/WEKUTydCdPGbl9K7QG/Ca3GnDV2sIKIpXRQcw==} + loupe@3.1.3: + resolution: {integrity: sha512-kkIp7XSkP78ZxJEsSxW3712C6teJVoeHHwgo9zJ380de7IYyJ2ISlxojcH2pC5OFLewESmnRi/+XCDIEEVyoug==} lowlight@1.20.0: resolution: {integrity: sha512-8Ktj+prEb1RoCPkEOrPMYUN/nCggB7qAWe3a7OpMjWQkh3l2RD5wKRQ+o8Q8YuI9RG/xs95waaI/E6ym/7NsTw==} @@ -2918,6 +2930,9 @@ packages: magic-string@0.30.11: resolution: {integrity: sha512-+Wri9p0QHMy+545hKww7YAu5NyzF8iomPL/RQazugQ9+Ez4Ic3mERMd8ZTX5rfK944j+560ZJi8iAwgak1Ac7A==} + magic-string@0.30.17: + resolution: {integrity: sha512-sNPKHvyjVf7gyjwS4xGTaW/mCnF8wnjtifKBEhxfZ7E/S8tQ0rssrwGNn6q8JH/ohItJfSQp9mBtQYuTlH5QnA==} + magicast@0.3.5: resolution: {integrity: sha512-L0WhttDl+2BOsybvEOLK7fW3UA0OQ0IQ2d6Zl2x/a6vVRs3bAY0ECOSHHeL5jD+SbOpOCUEi0y1DgHEn9Qn1AQ==} @@ -3698,6 +3713,9 @@ packages: std-env@3.7.0: resolution: {integrity: sha512-JPbdCEQLj1w5GilpiHAx3qJvFndqybBysA3qUOnznweH4QbNYUsW/ea8QzSrnh0vNsezMMw5bcVool8lM0gwzg==} + std-env@3.8.0: + resolution: {integrity: sha512-Bc3YwwCB+OzldMxOXJIIvC6cPRWr/LxOp48CdQTOkPyk/t4JWWJbrilwBd7RJzKV8QW7tJkcgAmeuLLJugl5/w==} + stop-iteration-iterator@1.0.0: resolution: {integrity: sha512-iCGQj+0l0HOdZ2AEeBADlsRC+vsnDsZsbdSiH1yNSjcfKM7fdpCMfqAL/dwF5BLiw/XhRft/Wax6zQbhq2BcjQ==} engines: {node: '>= 0.4'} @@ -3799,11 +3817,11 @@ packages: tinybench@2.9.0: resolution: {integrity: sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==} - tinyexec@0.3.0: - resolution: {integrity: sha512-tVGE0mVJPGb0chKhqmsoosjsS+qUnJVGJpZgsHYQcGoPlG3B51R3PouqTgEGH2Dc9jjFyOqOpix6ZHNMXp1FZg==} + tinyexec@0.3.2: + resolution: {integrity: sha512-KQQR9yN7R5+OSwaK0XQoj22pwHoTlgYqmUscPYoknOoWCWfj/5/ABTMRi69FrKU5ffPVh5QcFikpWJI/P1ocHA==} - tinypool@1.0.1: - resolution: {integrity: sha512-URZYihUbRPcGv95En+sz6MfghfIc2OJ1sv/RmhWZLouPY0/8Vo80viwPvg3dlaS9fuq7fQMEfgRRK7BBZThBEA==} + tinypool@1.0.2: + resolution: {integrity: sha512-al6n+QEANGFOMf/dmUMsuS5/r9B06uwlyNjZZql/zv8J7ybHCgoihBNORZCY2mzUuAnomQa2JdhyHKzZxPCrFA==} engines: {node: ^18.0.0 || >=20.0.0} tinyrainbow@1.2.0: @@ -4005,8 +4023,8 @@ packages: vfile@6.0.3: resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} - vite-node@2.1.1: - resolution: {integrity: sha512-N/mGckI1suG/5wQI35XeR9rsMsPqKXzq1CdUndzVstBj/HvyxxGctwnK6WX43NGt5L3Z5tcRf83g4TITKJhPrA==} + vite-node@2.1.9: + resolution: {integrity: sha512-AM9aQ/IPrW/6ENLQg3AGY4K1N2TGZdR5e4gu/MmmR2xR3Ll1+dib+nook92g4TV3PXVyeyxdWwtaCAiUL0hMxA==} engines: {node: ^18.0.0 || >=20.0.0} hasBin: true @@ -4046,15 +4064,15 @@ packages: terser: optional: true - vitest@2.1.1: - resolution: {integrity: sha512-97We7/VC0e9X5zBVkvt7SGQMGrRtn3KtySFQG5fpaMlS+l62eeXRQO633AYhSTC3z7IMebnPPNjGXVGNRFlxBA==} + vitest@2.1.9: + resolution: {integrity: 
sha512-MSmPM9REYqDGBI8439mA4mWhV5sKmDlBKWIYbA3lRb2PTHACE0mgKwA8yQ2xq9vxDTuk4iPrECBAEW2aoFXY0Q==} engines: {node: ^18.0.0 || >=20.0.0} hasBin: true peerDependencies: '@edge-runtime/vm': '*' '@types/node': ^18.0.0 || >=20.0.0 - '@vitest/browser': 2.1.1 - '@vitest/ui': 2.1.1 + '@vitest/browser': 2.1.9 + '@vitest/ui': 2.1.9 happy-dom: '*' jsdom: '*' peerDependenciesMeta: @@ -4374,7 +4392,7 @@ snapshots: '@babel/helper-split-export-declaration': 7.24.7 '@babel/parser': 7.25.4 '@babel/types': 7.25.4 - debug: 4.3.7 + debug: 4.4.0 globals: 11.12.0 transitivePeerDependencies: - supports-color @@ -4386,7 +4404,7 @@ snapshots: '@babel/parser': 7.25.6 '@babel/template': 7.25.0 '@babel/types': 7.25.6 - debug: 4.3.7 + debug: 4.4.0 globals: 11.12.0 transitivePeerDependencies: - supports-color @@ -5182,7 +5200,7 @@ snapshots: '@typescript-eslint/types': 8.5.0 '@typescript-eslint/typescript-estree': 8.5.0(typescript@5.5.4) '@typescript-eslint/visitor-keys': 8.5.0 - debug: 4.3.7 + debug: 4.4.0 eslint: 9.10.0(jiti@1.21.6) optionalDependencies: typescript: 5.5.4 @@ -5203,7 +5221,7 @@ snapshots: dependencies: '@typescript-eslint/typescript-estree': 8.5.0(typescript@5.5.4) '@typescript-eslint/utils': 8.5.0(eslint@9.10.0(jiti@1.21.6))(typescript@5.5.4) - debug: 4.3.7 + debug: 4.4.0 ts-api-utils: 1.3.0(typescript@5.5.4) optionalDependencies: typescript: 5.5.4 @@ -5219,7 +5237,7 @@ snapshots: dependencies: '@typescript-eslint/types': 8.0.0-alpha.30 '@typescript-eslint/visitor-keys': 8.0.0-alpha.30 - debug: 4.3.6 + debug: 4.4.0 globby: 11.1.0 is-glob: 4.0.3 minimatch: 9.0.5 @@ -5234,7 +5252,7 @@ snapshots: dependencies: '@typescript-eslint/types': 8.5.0 '@typescript-eslint/visitor-keys': 8.5.0 - debug: 4.3.7 + debug: 4.4.0 fast-glob: 3.3.2 is-glob: 4.0.3 minimatch: 9.0.5 @@ -5696,7 +5714,7 @@ snapshots: transitivePeerDependencies: - '@swc/helpers' - '@vitest/coverage-v8@2.1.1(vitest@2.1.1(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4)))': + '@vitest/coverage-v8@2.1.1(vitest@2.1.9(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4)))': dependencies: '@ampproject/remapping': 2.3.0 '@bcoe/v8-coverage': 0.2.3 @@ -5710,49 +5728,49 @@ snapshots: std-env: 3.7.0 test-exclude: 7.0.1 tinyrainbow: 1.2.0 - vitest: 2.1.1(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4)) + vitest: 2.1.9(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4)) transitivePeerDependencies: - supports-color - '@vitest/expect@2.1.1': + '@vitest/expect@2.1.9': dependencies: - '@vitest/spy': 2.1.1 - '@vitest/utils': 2.1.1 - chai: 5.1.1 + '@vitest/spy': 2.1.9 + '@vitest/utils': 2.1.9 + chai: 5.1.2 tinyrainbow: 1.2.0 - '@vitest/mocker@2.1.1(@vitest/spy@2.1.1)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4))(vite@5.4.12(@types/node@22.5.4))': + '@vitest/mocker@2.1.9(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4))(vite@5.4.12(@types/node@22.5.4))': dependencies: - '@vitest/spy': 2.1.1 + '@vitest/spy': 2.1.9 estree-walker: 3.0.3 - magic-string: 0.30.11 + magic-string: 0.30.17 optionalDependencies: msw: 2.7.0(@types/node@22.5.4)(typescript@5.5.4) vite: 5.4.12(@types/node@22.5.4) - '@vitest/pretty-format@2.1.1': + '@vitest/pretty-format@2.1.9': dependencies: tinyrainbow: 1.2.0 - '@vitest/runner@2.1.1': + '@vitest/runner@2.1.9': dependencies: - '@vitest/utils': 2.1.1 + '@vitest/utils': 2.1.9 pathe: 1.1.2 - '@vitest/snapshot@2.1.1': + '@vitest/snapshot@2.1.9': dependencies: - '@vitest/pretty-format': 2.1.1 - 
magic-string: 0.30.11 + '@vitest/pretty-format': 2.1.9 + magic-string: 0.30.17 pathe: 1.1.2 - '@vitest/spy@2.1.1': + '@vitest/spy@2.1.9': dependencies: tinyspy: 3.0.2 - '@vitest/utils@2.1.1': + '@vitest/utils@2.1.9': dependencies: - '@vitest/pretty-format': 2.1.1 - loupe: 3.1.1 + '@vitest/pretty-format': 2.1.9 + loupe: 3.1.3 tinyrainbow: 1.2.0 '@xyflow/react@12.3.5(@types/react@18.3.5)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': @@ -6474,12 +6492,12 @@ snapshots: ccount@2.0.1: {} - chai@5.1.1: + chai@5.1.2: dependencies: assertion-error: 2.0.1 check-error: 2.1.1 deep-eql: 5.0.2 - loupe: 3.1.1 + loupe: 3.1.3 pathval: 2.0.0 chakra-react-select@6.0.0-next.2(@chakra-ui/react@3.1.1(@emotion/react@11.13.3(@types/react@18.3.5)(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(@types/react@18.3.5)(next-themes@0.3.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): @@ -6745,6 +6763,10 @@ snapshots: dependencies: ms: 2.1.3 + debug@4.4.0: + dependencies: + ms: 2.1.3 + decode-named-character-reference@1.0.2: dependencies: character-entities: 2.0.2 @@ -6921,6 +6943,8 @@ snapshots: iterator.prototype: 1.1.2 safe-array-concat: 1.1.2 + es-module-lexer@1.6.0: {} + es-object-atoms@1.0.0: dependencies: es-errors: 1.3.0 @@ -7155,6 +7179,8 @@ snapshots: signal-exit: 4.1.0 strip-final-newline: 3.0.0 + expect-type@1.1.0: {} + extend@3.0.2: {} fast-deep-equal@3.1.3: {} @@ -7287,8 +7313,6 @@ snapshots: get-caller-file@2.0.5: {} - get-func-name@2.0.2: {} - get-intrinsic@1.2.4: dependencies: es-errors: 1.3.0 @@ -7734,9 +7758,7 @@ snapshots: dependencies: js-tokens: 4.0.0 - loupe@3.1.1: - dependencies: - get-func-name: 2.0.2 + loupe@3.1.3: {} lowlight@1.20.0: dependencies: @@ -7753,6 +7775,10 @@ snapshots: dependencies: '@jridgewell/sourcemap-codec': 1.5.0 + magic-string@0.30.17: + dependencies: + '@jridgewell/sourcemap-codec': 1.5.0 + magicast@0.3.5: dependencies: '@babel/parser': 7.25.6 @@ -8096,7 +8122,7 @@ snapshots: micromark@4.0.1: dependencies: '@types/debug': 4.1.12 - debug: 4.3.7 + debug: 4.4.0 decode-named-character-reference: 1.0.2 devlop: 1.1.0 micromark-core-commonmark: 2.0.2 @@ -8817,6 +8843,8 @@ snapshots: std-env@3.7.0: {} + std-env@3.8.0: {} + stop-iteration-iterator@1.0.0: dependencies: internal-slot: 1.0.7 @@ -8944,9 +8972,9 @@ snapshots: tinybench@2.9.0: {} - tinyexec@0.3.0: {} + tinyexec@0.3.2: {} - tinypool@1.0.1: {} + tinypool@1.0.2: {} tinyrainbow@1.2.0: {} @@ -9154,10 +9182,11 @@ snapshots: '@types/unist': 3.0.3 vfile-message: 4.0.2 - vite-node@2.1.1(@types/node@22.5.4): + vite-node@2.1.9(@types/node@22.5.4): dependencies: cac: 6.7.14 - debug: 4.3.7 + debug: 4.4.0 + es-module-lexer: 1.6.0 pathe: 1.1.2 vite: 5.4.12(@types/node@22.5.4) transitivePeerDependencies: @@ -9184,26 +9213,27 @@ snapshots: '@types/node': 22.5.4 fsevents: 2.3.3 - vitest@2.1.1(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4)): - dependencies: - '@vitest/expect': 2.1.1 - '@vitest/mocker': 2.1.1(@vitest/spy@2.1.1)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4))(vite@5.4.12(@types/node@22.5.4)) - '@vitest/pretty-format': 2.1.1 - '@vitest/runner': 2.1.1 - '@vitest/snapshot': 2.1.1 - '@vitest/spy': 2.1.1 - '@vitest/utils': 2.1.1 - chai: 5.1.1 - debug: 4.3.7 - magic-string: 0.30.11 + vitest@2.1.9(@types/node@22.5.4)(happy-dom@15.10.2)(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4)): + dependencies: + '@vitest/expect': 2.1.9 + '@vitest/mocker': 2.1.9(msw@2.7.0(@types/node@22.5.4)(typescript@5.5.4))(vite@5.4.12(@types/node@22.5.4)) 
+ '@vitest/pretty-format': 2.1.9 + '@vitest/runner': 2.1.9 + '@vitest/snapshot': 2.1.9 + '@vitest/spy': 2.1.9 + '@vitest/utils': 2.1.9 + chai: 5.1.2 + debug: 4.4.0 + expect-type: 1.1.0 + magic-string: 0.30.17 pathe: 1.1.2 - std-env: 3.7.0 + std-env: 3.8.0 tinybench: 2.9.0 - tinyexec: 0.3.0 - tinypool: 1.0.1 + tinyexec: 0.3.2 + tinypool: 1.0.2 tinyrainbow: 1.2.0 vite: 5.4.12(@types/node@22.5.4) - vite-node: 2.1.1(@types/node@22.5.4) + vite-node: 2.1.9(@types/node@22.5.4) why-is-node-running: 2.3.0 optionalDependencies: '@types/node': 22.5.4 diff --git a/airflow/ui/src/components/SearchBar.tsx b/airflow/ui/src/components/SearchBar.tsx index 349b66ce9e2e6..5a55b6a2f96c0 100644 --- a/airflow/ui/src/components/SearchBar.tsx +++ b/airflow/ui/src/components/SearchBar.tsx @@ -16,11 +16,14 @@ * specific language governing permissions and limitations * under the License. */ -import { Button, Input, type ButtonProps } from "@chakra-ui/react"; -import { useState, type ChangeEvent } from "react"; +import { Button, Input, Kbd, type ButtonProps } from "@chakra-ui/react"; +import { useState, useRef, type ChangeEvent } from "react"; +import { useHotkeys } from "react-hotkeys-hook"; import { FiSearch } from "react-icons/fi"; import { useDebouncedCallback } from "use-debounce"; +import { getMetaKey } from "src/utils"; + import { CloseButton, InputGroup, type InputGroupProps } from "./ui"; const debounceDelay = 200; @@ -43,14 +46,23 @@ export const SearchBar = ({ placeHolder, }: Props) => { const handleSearchChange = useDebouncedCallback((val: string) => onChange(val), debounceDelay); - + const searchRef = useRef(null); const [value, setValue] = useState(defaultValue); + const metaKey = getMetaKey(); const onSearchChange = (event: ChangeEvent) => { setValue(event.target.value); handleSearchChange(event.target.value); }; + useHotkeys( + "mod+k", + () => { + searchRef.current?.focus(); + }, + { preventDefault: true }, + ); + return ( )} + {metaKey}+K } startElement={} @@ -83,6 +96,7 @@ export const SearchBar = ({ onChange={onSearchChange} placeholder={placeHolder} pr={150} + ref={searchRef} value={value} /> diff --git a/airflow/ui/src/constants/stateOptions.ts b/airflow/ui/src/constants/stateOptions.ts new file mode 100644 index 0000000000000..cc85cc1115ee9 --- /dev/null +++ b/airflow/ui/src/constants/stateOptions.ts @@ -0,0 +1,53 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +import { createListCollection } from "@chakra-ui/react"; + +import type { TaskInstanceState } from "openapi/requests/types.gen"; + +export const taskInstanceStateOptions = createListCollection<{ + label: string; + value: TaskInstanceState | "all" | "none"; +}>({ + items: [ + { label: "All States", value: "all" }, + { label: "Scheduled", value: "scheduled" }, + { label: "Queued", value: "queued" }, + { label: "Running", value: "running" }, + { label: "Success", value: "success" }, + { label: "Restarting", value: "restarting" }, + { label: "Failed", value: "failed" }, + { label: "Up For Retry", value: "up_for_retry" }, + { label: "Up For Reschedule", value: "up_for_reschedule" }, + { label: "Upstream failed", value: "upstream_failed" }, + { label: "Skipped", value: "skipped" }, + { label: "Deferred", value: "deferred" }, + { label: "Removed", value: "removed" }, + { label: "No Status", value: "none" }, + ], +}); + +export const dagRunStateOptions = createListCollection({ + items: [ + { label: "All States", value: "all" }, + { label: "Queued", value: "queued" }, + { label: "Running", value: "running" }, + { label: "Failed", value: "failed" }, + { label: "Success", value: "success" }, + ], +}); diff --git a/airflow/ui/src/pages/Dag/Runs/Runs.tsx b/airflow/ui/src/pages/Dag/Runs/Runs.tsx index 29a1e90fd8616..36b6195eeea39 100644 --- a/airflow/ui/src/pages/Dag/Runs/Runs.tsx +++ b/airflow/ui/src/pages/Dag/Runs/Runs.tsx @@ -16,15 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -import { - Box, - createListCollection, - Flex, - HStack, - Link, - type SelectValueChangeDetails, - Text, -} from "@chakra-ui/react"; +import { Box, Flex, HStack, Link, type SelectValueChangeDetails, Text } from "@chakra-ui/react"; import type { ColumnDef } from "@tanstack/react-table"; import { useCallback } from "react"; import { useParams, Link as RouterLink, useSearchParams } from "react-router-dom"; @@ -40,6 +32,7 @@ import { RunTypeIcon } from "src/components/RunTypeIcon"; import { StateBadge } from "src/components/StateBadge"; import Time from "src/components/Time"; import { Select } from "src/components/ui"; +import { taskInstanceStateOptions as stateOptions } from "src/constants/stateOptions"; import { capitalize, getDuration, useAutoRefresh, isStatePending } from "src/utils"; const columns: Array> = [ @@ -105,16 +98,6 @@ const columns: Array> = [ }, ]; -const stateOptions = createListCollection({ - items: [ - { label: "All States", value: "all" }, - { label: "Queued", value: "queued" }, - { label: "Running", value: "running" }, - { label: "Failed", value: "failed" }, - { label: "Success", value: "success" }, - ], -}); - const STATE_PARAM = "state"; export const Runs = () => { diff --git a/airflow/ui/src/pages/Pools/Pools.tsx b/airflow/ui/src/pages/Pools/Pools.tsx index 329d41484f7e8..163f1039719be 100644 --- a/airflow/ui/src/pages/Pools/Pools.tsx +++ b/airflow/ui/src/pages/Pools/Pools.tsx @@ -17,21 +17,50 @@ * under the License. 
*/ import { Box, Skeleton } from "@chakra-ui/react"; +import { useState } from "react"; +import { useSearchParams } from "react-router-dom"; import { usePoolServiceGetPools } from "openapi/queries"; import { ErrorAlert } from "src/components/ErrorAlert"; +import { SearchBar } from "src/components/SearchBar"; +import { type SearchParamsKeysType, SearchParamsKeys } from "src/constants/searchParams"; import PoolBar from "./PoolBar"; export const Pools = () => { - const { data, error, isLoading } = usePoolServiceGetPools(); + const [searchParams, setSearchParams] = useSearchParams(); + const { NAME_PATTERN: NAME_PATTERN_PARAM }: SearchParamsKeysType = SearchParamsKeys; + const [poolNamePattern, setPoolNamePattern] = useState(searchParams.get(NAME_PATTERN_PARAM) ?? undefined); + const { data, error, isLoading } = usePoolServiceGetPools({ + poolNamePattern: poolNamePattern ?? undefined, + }); - return isLoading ? ( - - ) : ( + const handleSearchChange = (value: string) => { + if (value) { + searchParams.set(NAME_PATTERN_PARAM, value); + } else { + searchParams.delete(NAME_PATTERN_PARAM); + } + setSearchParams(searchParams); + setPoolNamePattern(value); + }; + + return ( <> - {data?.pools.map((pool) => )} + + + {isLoading ? ( + + ) : ( + data?.pools.map((pool) => ) + )} + ); }; diff --git a/airflow/ui/src/pages/Run/Details.tsx b/airflow/ui/src/pages/Run/Details.tsx index e46f7357a8002..a063b3a5e6b64 100644 --- a/airflow/ui/src/pages/Run/Details.tsx +++ b/airflow/ui/src/pages/Run/Details.tsx @@ -118,7 +118,7 @@ export const Details = () => { {dagRun.external_trigger ? ( - Externally Trigger Source + External Trigger Source {dagRun.triggered_by} ) : undefined} diff --git a/airflow/ui/src/pages/Run/TaskInstances.tsx b/airflow/ui/src/pages/Run/TaskInstances.tsx index 7ba8be6e5bfe1..f05a83e8c56b7 100644 --- a/airflow/ui/src/pages/Run/TaskInstances.tsx +++ b/airflow/ui/src/pages/Run/TaskInstances.tsx @@ -16,14 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ -import { - Box, - Flex, - Link, - createListCollection, - HStack, - type SelectValueChangeDetails, -} from "@chakra-ui/react"; +import { Box, Flex, Link, HStack, type SelectValueChangeDetails } from "@chakra-ui/react"; import type { ColumnDef } from "@tanstack/react-table"; import { useCallback, useState } from "react"; import { Link as RouterLink, useParams, useSearchParams } from "react-router-dom"; @@ -40,6 +33,7 @@ import { StateBadge } from "src/components/StateBadge"; import Time from "src/components/Time"; import { Select } from "src/components/ui"; import { SearchParamsKeys, type SearchParamsKeysType } from "src/constants/searchParams"; +import { taskInstanceStateOptions as stateOptions } from "src/constants/stateOptions"; import { capitalize, getDuration, useAutoRefresh, isStatePending } from "src/utils"; import { getTaskInstanceLink } from "src/utils/links"; @@ -109,25 +103,6 @@ const columns: Array> = [ }, ]; -const stateOptions = createListCollection<{ label: string; value: TaskInstanceState | "all" | "none" }>({ - items: [ - { label: "All States", value: "all" }, - { label: "Scheduled", value: "scheduled" }, - { label: "Queued", value: "queued" }, - { label: "Running", value: "running" }, - { label: "Success", value: "success" }, - { label: "Restarting", value: "restarting" }, - { label: "Failed", value: "failed" }, - { label: "Up For Retry", value: "up_for_retry" }, - { label: "Up For Reschedule", value: "up_for_reschedule" }, - { label: "Upstream failed", value: "upstream_failed" }, - { label: "Skipped", value: "skipped" }, - { label: "Deferred", value: "deferred" }, - { label: "Removed", value: "removed" }, - { label: "No Status", value: "none" }, - ], -}); - const STATE_PARAM = "state"; export const TaskInstances = () => { diff --git a/airflow/ui/src/pages/Task/Instances.tsx b/airflow/ui/src/pages/Task/Instances.tsx index a45c7acebd156..912c588c93193 100644 --- a/airflow/ui/src/pages/Task/Instances.tsx +++ b/airflow/ui/src/pages/Task/Instances.tsx @@ -16,18 +16,21 @@ * specific language governing permissions and limitations * under the License. 
*/ -import { Box, Link } from "@chakra-ui/react"; +import { Box, Flex, HStack, Link, type SelectValueChangeDetails } from "@chakra-ui/react"; import type { ColumnDef } from "@tanstack/react-table"; -import { Link as RouterLink, useParams } from "react-router-dom"; +import { useCallback } from "react"; +import { Link as RouterLink, useParams, useSearchParams } from "react-router-dom"; import { useTaskInstanceServiceGetTaskInstances, useTaskServiceGetTask } from "openapi/queries"; -import type { TaskInstanceResponse } from "openapi/requests/types.gen"; +import type { TaskInstanceResponse, TaskInstanceState } from "openapi/requests/types.gen"; import { DataTable } from "src/components/DataTable"; import { useTableURLState } from "src/components/DataTable/useTableUrlState"; import { ErrorAlert } from "src/components/ErrorAlert"; import { StateBadge } from "src/components/StateBadge"; import Time from "src/components/Time"; -import { getDuration } from "src/utils"; +import { Select } from "src/components/ui"; +import { taskInstanceStateOptions as stateOptions } from "src/constants/stateOptions"; +import { capitalize, getDuration } from "src/utils"; import { getTaskInstanceLink } from "src/utils/links"; const columns = (isMapped?: boolean): Array> => [ @@ -79,26 +82,94 @@ const columns = (isMapped?: boolean): Array> => }, ]; +const STATE_PARAM = "state"; + export const Instances = () => { const { dagId = "", taskId } = useParams(); + const [searchParams, setSearchParams] = useSearchParams(); const { setTableURLState, tableURLState } = useTableURLState(); const { pagination, sorting } = tableURLState; const [sort] = sorting; const orderBy = sort ? `${sort.desc ? "-" : ""}${sort.id}` : "-start_date"; + const filteredState = searchParams.getAll(STATE_PARAM); + const hasFilteredState = filteredState.length > 0; const { data: task, error: taskError, isLoading: isTaskLoading } = useTaskServiceGetTask({ dagId, taskId }); + const handleStateChange = useCallback( + ({ value }: SelectValueChangeDetails) => { + const [val, ...rest] = value; + + if ((val === undefined || val === "all") && rest.length === 0) { + searchParams.delete(STATE_PARAM); + } else { + searchParams.delete(STATE_PARAM); + value.filter((state) => state !== "all").map((state) => searchParams.append(STATE_PARAM, state)); + } + setTableURLState({ + pagination: { ...pagination, pageIndex: 0 }, + sorting, + }); + setSearchParams(searchParams); + }, + [pagination, searchParams, setSearchParams, setTableURLState, sorting], + ); + const { data, error, isFetching, isLoading } = useTaskInstanceServiceGetTaskInstances({ dagId, dagRunId: "~", limit: pagination.pageSize, offset: pagination.pageIndex * pagination.pageSize, orderBy, + state: hasFilteredState ? filteredState : undefined, taskId, }); return ( - + + + + + + {() => + hasFilteredState ? ( + + {filteredState.map((state) => ( + + {state === "none" ? "No Status" : capitalize(state)} + + ))} + + ) : ( + "All States" + ) + } + + + + {stateOptions.items.map((option) => ( + + {option.value === "all" ? ( + option.label + ) : ( + {option.label} + )} + + ))} + + + + { const { dagId = "", runId = "", taskId = "" } = useParams(); const [searchParams, setSearchParams] = useSearchParams(); @@ -82,6 +85,13 @@ export const Details = () => { taskInstance={taskInstance} /> )} + + {taskInstance !== undefined && (taskInstance.trigger ?? taskInstance.triggerer_job) ? 
( + + ) : undefined} + + Task Instance Info + diff --git a/airflow/ui/src/pages/TaskInstance/ExtraLinks.tsx b/airflow/ui/src/pages/TaskInstance/ExtraLinks.tsx new file mode 100644 index 0000000000000..5990c6fdb5f59 --- /dev/null +++ b/airflow/ui/src/pages/TaskInstance/ExtraLinks.tsx @@ -0,0 +1,53 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { Box, Button, Heading, HStack } from "@chakra-ui/react"; +import { useParams, useSearchParams } from "react-router-dom"; + +import { useTaskInstanceServiceGetExtraLinks } from "openapi/queries"; + +export const ExtraLinks = () => { + const { dagId = "", runId = "", taskId = "" } = useParams(); + const [searchParams] = useSearchParams(); + const mapIndexParam = searchParams.get("map_index"); + const mapIndex = parseInt(mapIndexParam ?? "-1", 10); + + const { data } = useTaskInstanceServiceGetExtraLinks({ + dagId, + dagRunId: runId, + mapIndex, + taskId, + }); + + return data && Object.keys(data).length > 0 ? ( + + Extra Links + + {Object.entries(data).map(([key, value], _) => + value === null ? undefined : ( + + ), + )} + + + ) : undefined; +}; diff --git a/airflow/ui/src/pages/TaskInstance/Header.tsx b/airflow/ui/src/pages/TaskInstance/Header.tsx index a4a97fcbb1bd0..20e4066c3ad46 100644 --- a/airflow/ui/src/pages/TaskInstance/Header.tsx +++ b/airflow/ui/src/pages/TaskInstance/Header.tsx @@ -73,6 +73,9 @@ export const Header = ({ ); diff --git a/airflow/ui/src/pages/TaskInstance/TriggererInfo.tsx b/airflow/ui/src/pages/TaskInstance/TriggererInfo.tsx new file mode 100644 index 0000000000000..cc718810d58f7 --- /dev/null +++ b/airflow/ui/src/pages/TaskInstance/TriggererInfo.tsx @@ -0,0 +1,58 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +import { Box, Table, Heading } from "@chakra-ui/react"; + +import type { TaskInstanceResponse } from "openapi/requests/types.gen"; +import Time from "src/components/Time"; + +export const TriggererInfo = ({ taskInstance }: { readonly taskInstance: TaskInstanceResponse }) => ( + + + Triggerer Info + + + + + Trigger class + {taskInstance.trigger?.classpath} + + + Trigger ID + {taskInstance.trigger?.id} + + + Trigger creation time + + + + + Assigned triggerer + {taskInstance.triggerer_job?.hostname} + + + Latest triggerer heartbeat + + + + + + +); diff --git a/airflow/ui/src/queries/usePatchDagRun.ts b/airflow/ui/src/queries/usePatchDagRun.ts index 1b42c374304b1..0ed16f5ecab33 100644 --- a/airflow/ui/src/queries/usePatchDagRun.ts +++ b/airflow/ui/src/queries/usePatchDagRun.ts @@ -22,9 +22,12 @@ import { UseDagRunServiceGetDagRunKeyFn, useDagRunServiceGetDagRunsKey, useDagRunServicePatchDagRun, + useTaskInstanceServiceGetTaskInstancesKey, } from "openapi/queries"; import { toaster } from "src/components/ui"; +import { useClearDagRunDryRunKey } from "./useClearDagRunDryRun"; + const onError = () => { toaster.create({ description: "Patch Dag Run request failed", @@ -45,7 +48,12 @@ export const usePatchDagRun = ({ const queryClient = useQueryClient(); const onSuccessFn = async () => { - const queryKeys = [UseDagRunServiceGetDagRunKeyFn({ dagId, dagRunId }), [useDagRunServiceGetDagRunsKey]]; + const queryKeys = [ + UseDagRunServiceGetDagRunKeyFn({ dagId, dagRunId }), + [useDagRunServiceGetDagRunsKey], + [useTaskInstanceServiceGetTaskInstancesKey, { dagId, dagRunId }], + [useClearDagRunDryRunKey, dagId], + ]; await Promise.all(queryKeys.map((key) => queryClient.invalidateQueries({ queryKey: key }))); diff --git a/airflow/ui/src/queries/useTrigger.ts b/airflow/ui/src/queries/useTrigger.ts index 0a6e6f492abb9..2c56bda669539 100644 --- a/airflow/ui/src/queries/useTrigger.ts +++ b/airflow/ui/src/queries/useTrigger.ts @@ -20,15 +20,14 @@ import { useQueryClient } from "@tanstack/react-query"; import { useState } from "react"; import { - useDagRunServiceGetDagRunsKey, + UseDagRunServiceGetDagRunsKeyFn, useDagRunServiceTriggerDagRun, useDagServiceGetDagsKey, useDagsServiceRecentDagRunsKey, - useTaskInstanceServiceGetTaskInstancesKey, + UseTaskInstanceServiceGetTaskInstancesKeyFn, } from "openapi/queries"; import type { DagRunTriggerParams } from "src/components/TriggerDag/TriggerDAGForm"; import { toaster } from "src/components/ui"; -import { doQueryKeysMatch, type PartialQueryKey } from "src/utils"; export const useTrigger = ({ dagId, onSuccessConfirm }: { dagId: string; onSuccessConfirm: () => void }) => { const queryClient = useQueryClient(); @@ -37,14 +36,14 @@ export const useTrigger = ({ dagId, onSuccessConfirm }: { dagId: string; onSucce const [dateValidationError, setDateValidationError] = useState(undefined); const onSuccess = async () => { - const queryKeys: Array = [ - { baseKey: useDagServiceGetDagsKey }, - { baseKey: useDagsServiceRecentDagRunsKey }, - { baseKey: useDagRunServiceGetDagRunsKey, options: { dagIds: [dagId] } }, - { baseKey: useTaskInstanceServiceGetTaskInstancesKey, options: { dagId, dagRunId: "~" } }, + const queryKeys = [ + [useDagServiceGetDagsKey], + [useDagsServiceRecentDagRunsKey], + UseDagRunServiceGetDagRunsKeyFn({ dagId }, [{ dagId }]), + UseTaskInstanceServiceGetTaskInstancesKeyFn({ dagId, dagRunId: "~" }, [{ dagId, dagRunId: "~" }]), ]; - await queryClient.invalidateQueries({ predicate: (query) => doQueryKeysMatch(query, queryKeys) }); + await 
Promise.all(queryKeys.map((key) => queryClient.invalidateQueries({ queryKey: key }))); toaster.create({ description: "DAG run has been successfully triggered.", diff --git a/airflow/ui/src/utils/query.ts b/airflow/ui/src/utils/query.ts index 415d1a5c9fa9b..095229d2dac59 100644 --- a/airflow/ui/src/utils/query.ts +++ b/airflow/ui/src/utils/query.ts @@ -16,8 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -import type { Query } from "@tanstack/react-query"; - import { useDagServiceGetDagDetails } from "openapi/queries"; import type { TaskInstanceState } from "openapi/requests/types.gen"; import { useConfig } from "src/queries/useConfig"; @@ -32,26 +30,6 @@ export const isStatePending = (state?: TaskInstanceState | null) => state === "restarting" || !Boolean(state); -export type PartialQueryKey = { baseKey: string; options?: Record }; - -// This allows us to specify what query key values we actually care about and ignore the rest -// ex: match everything with this dagId and dagRunId but ignore anything related to pagination -export const doQueryKeysMatch = (query: Query, queryKeysToMatch: Array) => { - const [baseKey, options] = query.queryKey; - - const matchedKey = queryKeysToMatch.find((qk) => qk.baseKey === baseKey); - - if (!matchedKey) { - return false; - } - - return matchedKey.options - ? Object.entries(matchedKey.options).every( - ([key, value]) => typeof options === "object" && (options as Record)[key] === value, - ) - : true; -}; - export const useAutoRefresh = ({ dagId, isPaused }: { dagId?: string; isPaused?: boolean }) => { const autoRefreshInterval = useConfig("auto_refresh_interval") as number | undefined; const { data: dag } = useDagServiceGetDagDetails( diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index c1aba4da8f580..5c050c50a1dcb 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -35,9 +35,9 @@ from airflow import settings from airflow.exceptions import AirflowException +from airflow.sdk.execution_time.secrets_masker import should_hide_value_for_key from airflow.utils import cli_action_loggers, timezone from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler -from airflow.utils.log.secrets_masker import should_hide_value_for_key from airflow.utils.platform import getuser, is_terminal_support_colors T = TypeVar("T", bound=Callable) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 0415542c6ca8c..62308605dbb37 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -66,7 +66,6 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import Select, TextClause - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.types import OutletEventAccessorsProtocol # NOTE: Please keep this in sync with the following: @@ -293,24 +292,6 @@ def context_merge(context: Context, *args: Any, **kwargs: Any) -> None: context.update(*args, **kwargs) -def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: - """ - Update context after task unmapping. - - Since ``get_template_context()`` is called before unmapping, the context - contains information about the mapped task. We need to do some in-place - updates to ensure the template context reflects the unmapped task instead. 
- - :meta private: - """ - from airflow.sdk.definitions.param import process_params - - context["task"] = context["ti"].task = task - context["params"] = process_params( - context["dag"], task, context["dag_run"].conf, suppress_exception=False - ) - - def context_copy_partial(source: Context, keys: Container[str]) -> Context: """ Create a context by copying items under selected keys in ``source``. diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py index 3108657d30ac2..32d19d316844c 100644 --- a/airflow/utils/setup_teardown.py +++ b/airflow/utils/setup_teardown.py @@ -23,8 +23,8 @@ if TYPE_CHECKING: from airflow.models.taskmixin import DependencyMixin - from airflow.models.xcom_arg import PlainXComArg from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator + from airflow.sdk.definitions.xcom_arg import PlainXComArg class BaseSetupTeardownContext: @@ -335,7 +335,7 @@ class SetupTeardownContext(BaseSetupTeardownContext): @staticmethod def add_task(task: AbstractOperator | PlainXComArg): """Add task to context manager.""" - from airflow.models.xcom_arg import PlainXComArg + from airflow.sdk.definitions.xcom_arg import PlainXComArg if not SetupTeardownContext.active: raise AirflowException("Cannot add task to context outside the context manager.") diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index b74921f75d533..2ebb95709913c 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -19,61 +19,15 @@ from __future__ import annotations -import functools -import operator -from collections.abc import Iterator from typing import TYPE_CHECKING import airflow.sdk.definitions.taskgroup if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.typing_compat import TypeAlias TaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.TaskGroup - - -class MappedTaskGroup(airflow.sdk.definitions.taskgroup.MappedTaskGroup): # noqa: D101 - # TODO: Rename this to SerializedMappedTaskGroup perhaps? - - def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: - """ - Return mapped task groups in the hierarchy. - - Groups are returned from the closest to the outmost. If *self* is a - mapped task group, it is returned first. - - :meta private: - """ - group: TaskGroup | None = self - while group is not None: - if isinstance(group, MappedTaskGroup): - yield group - group = group.parent_group - - def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: - """ - Return the number of instances a task in this group should be mapped to at run time. - - This considers both literal and non-literal mapped arguments, and the - result is therefore available when all depended tasks have finished. The - return value should be identical to ``parse_time_mapped_ti_count`` if - all mapped arguments are literal. - - If this group is inside mapped task groups, all the nested counts are - multiplied and accounted. - - :meta private: - - :raise NotFullyPopulated: If upstream tasks are not all complete yet. - :return: Total number of mapped TIs this task should have. 
- """ - groups = self.iter_mapped_task_groups() - return functools.reduce( - operator.mul, - (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), - ) +MappedTaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.MappedTaskGroup def task_group_to_dict(task_item_or_group): diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py index 9a916da184c13..96a7f32e4c1a7 100644 --- a/airflow/www/decorators.py +++ b/airflow/www/decorators.py @@ -31,7 +31,7 @@ from airflow.api_fastapi.app import get_auth_manager from airflow.models import Log -from airflow.utils.log import secrets_masker +from airflow.sdk.execution_time import secrets_masker from airflow.utils.session import create_session T = TypeVar("T", bound=Callable) diff --git a/airflow/www/views.py b/airflow/www/views.py index 4ce5c6564619f..f99dcd9161f33 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -114,6 +114,7 @@ from airflow.plugins_manager import PLUGINS_ATTRIBUTES_TO_DUMP from airflow.providers_manager import ProvidersManager from airflow.sdk.definitions.asset import Asset, AssetAlias +from airflow.sdk.execution_time import secrets_masker from airflow.security import permissions from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS @@ -127,7 +128,6 @@ from airflow.utils.db import get_query_count from airflow.utils.docs import get_doc_url_for_provider, get_docs_url from airflow.utils.helpers import exactly_one -from airflow.utils.log import secrets_masker from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.net import get_hostname from airflow.utils.session import NEW_SESSION, create_session, provide_session diff --git a/chart/templates/_helpers.yaml b/chart/templates/_helpers.yaml index c76cefa843e34..718674f807441 100644 --- a/chart/templates/_helpers.yaml +++ b/chart/templates/_helpers.yaml @@ -620,8 +620,8 @@ server_tls_key_file = /etc/pgbouncer/server.key {{/* Create the name of the API server service account to use */}} -{{- define "apiServer.serviceAccountName" -}} - {{- include "_serviceAccountName" (merge (dict "key" "apiServer" "nameSuffix" "api-server" ) .) -}} +{{- define "_apiServer.serviceAccountName" -}} + {{- include "_serviceAccountName" (merge (dict "key" "_apiServer" "nameSuffix" "api-server" ) .) -}} {{- end }} {{/* Create the name of the redis service account to use */}} diff --git a/chart/templates/api-server/api-server-deployment.yaml b/chart/templates/api-server/api-server-deployment.yaml index b4cf2cd4461f1..3ce0fb9ac266f 100644 --- a/chart/templates/api-server/api-server-deployment.yaml +++ b/chart/templates/api-server/api-server-deployment.yaml @@ -21,15 +21,15 @@ ## Airflow API Server Deployment ################################# {{- if semverCompare ">=3.0.0" .Values.airflowVersion }} -{{- $nodeSelector := or .Values.apiServer.nodeSelector .Values.nodeSelector }} -{{- $affinity := or .Values.apiServer.affinity .Values.affinity }} -{{- $tolerations := or .Values.apiServer.tolerations .Values.tolerations }} -{{- $topologySpreadConstraints := or .Values.apiServer.topologySpreadConstraints .Values.topologySpreadConstraints }} -{{- $revisionHistoryLimit := or .Values.apiServer.revisionHistoryLimit .Values.revisionHistoryLimit }} -{{- $securityContext := include "airflowPodSecurityContext" (list . .Values.apiServer) }} -{{- $containerSecurityContext := include "containerSecurityContext" (list . 
.Values.apiServer) }} -{{- $containerSecurityContextWaitForMigrations := include "containerSecurityContext" (list . .Values.apiServer.waitForMigrations) }} -{{- $containerLifecycleHooks := or .Values.apiServer.containerLifecycleHooks .Values.containerLifecycleHooks }} +{{- $nodeSelector := or .Values._apiServer.nodeSelector .Values.nodeSelector }} +{{- $affinity := or .Values._apiServer.affinity .Values.affinity }} +{{- $tolerations := or .Values._apiServer.tolerations .Values.tolerations }} +{{- $topologySpreadConstraints := or .Values._apiServer.topologySpreadConstraints .Values.topologySpreadConstraints }} +{{- $revisionHistoryLimit := or .Values._apiServer.revisionHistoryLimit .Values.revisionHistoryLimit }} +{{- $securityContext := include "airflowPodSecurityContext" (list . .Values._apiServer) }} +{{- $containerSecurityContext := include "containerSecurityContext" (list . .Values._apiServer) }} +{{- $containerSecurityContextWaitForMigrations := include "containerSecurityContext" (list . .Values._apiServer.waitForMigrations) }} +{{- $containerLifecycleHooks := or .Values._apiServer.containerLifecycleHooks .Values.containerLifecycleHooks }} apiVersion: apps/v1 kind: Deployment metadata: @@ -43,17 +43,17 @@ metadata: {{- with .Values.labels }} {{- toYaml . | nindent 4 }} {{- end }} - {{- if .Values.apiServer.annotations }} - annotations: {{- toYaml .Values.apiServer.annotations | nindent 4 }} + {{- if .Values._apiServer.annotations }} + annotations: {{- toYaml .Values._apiServer.annotations | nindent 4 }} {{- end }} spec: - replicas: {{ .Values.apiServer.replicas }} + replicas: {{ .Values._apiServer.replicas }} {{- if $revisionHistoryLimit }} revisionHistoryLimit: {{ $revisionHistoryLimit }} {{- end }} strategy: - {{- if .Values.apiServer.strategy }} - {{- toYaml .Values.apiServer.strategy | nindent 4 }} + {{- if .Values._apiServer.strategy }} + {{- toYaml .Values._apiServer.strategy | nindent 4 }} {{- else }} # Here we define the rolling update strategy # - maxSurge define how many pod we can add at a time @@ -78,8 +78,8 @@ spec: tier: airflow component: api-server release: {{ .Release.Name }} - {{- if or (.Values.labels) (.Values.apiServer.labels) }} - {{- mustMerge .Values.apiServer.labels .Values.labels | toYaml | nindent 8 }} + {{- if or (.Values.labels) (.Values._apiServer.labels) }} + {{- mustMerge .Values._apiServer.labels .Values.labels | toYaml | nindent 8 }} {{- end }} annotations: checksum/metadata-secret: {{ include (print $.Template.BasePath "/secrets/metadata-connection-secret.yaml") . | sha256sum }} @@ -90,16 +90,16 @@ spec: {{- if .Values.airflowPodAnnotations }} {{- toYaml .Values.airflowPodAnnotations | nindent 8 }} {{- end }} - {{- if .Values.apiServer.podAnnotations }} - {{- toYaml .Values.apiServer.podAnnotations | nindent 8 }} + {{- if .Values._apiServer.podAnnotations }} + {{- toYaml .Values._apiServer.podAnnotations | nindent 8 }} {{- end }} spec: - {{- if .Values.apiServer.hostAliases }} - hostAliases: {{- toYaml .Values.apiServer.hostAliases | nindent 8 }} + {{- if .Values._apiServer.hostAliases }} + hostAliases: {{- toYaml .Values._apiServer.hostAliases | nindent 8 }} {{- end }} - serviceAccountName: {{ include "apiServer.serviceAccountName" . }} - {{- if .Values.apiServer.priorityClassName }} - priorityClassName: {{ .Values.apiServer.priorityClassName }} + serviceAccountName: {{ include "_apiServer.serviceAccountName" . 
}} + {{- if .Values._apiServer.priorityClassName }} + priorityClassName: {{ .Values._apiServer.priorityClassName }} {{- end }} {{- if .Values.schedulerName }} schedulerName: {{ .Values.schedulerName }} @@ -127,9 +127,9 @@ spec: - name: {{ template "registry_secret" . }} {{- end }} initContainers: - {{- if .Values.apiServer.waitForMigrations.enabled }} + {{- if .Values._apiServer.waitForMigrations.enabled }} - name: wait-for-airflow-migrations - resources: {{- toYaml .Values.apiServer.resources | nindent 12 }} + resources: {{- toYaml .Values._apiServer.resources | nindent 12 }} image: {{ template "airflow_image_for_migrations" . }} imagePullPolicy: {{ .Values.images.airflow.pullPolicy }} securityContext: {{ $containerSecurityContextWaitForMigrations | nindent 12 }} @@ -138,20 +138,20 @@ spec: {{- if .Values.volumeMounts }} {{- toYaml .Values.volumeMounts | nindent 12 }} {{- end }} - {{- if .Values.apiServer.extraVolumeMounts }} - {{- tpl (toYaml .Values.apiServer.extraVolumeMounts) . | nindent 12 }} + {{- if .Values._apiServer.extraVolumeMounts }} + {{- tpl (toYaml .Values._apiServer.extraVolumeMounts) . | nindent 12 }} {{- end }} args: {{- include "wait-for-migrations-command" . | indent 10 }} envFrom: {{- include "custom_airflow_environment_from" . | default "\n []" | indent 10 }} env: {{- include "custom_airflow_environment" . | indent 10 }} {{- include "standard_airflow_environment" . | indent 10 }} - {{- if .Values.apiServer.waitForMigrations.env }} - {{- tpl (toYaml .Values.apiServer.waitForMigrations.env) $ | nindent 12 }} + {{- if .Values._apiServer.waitForMigrations.env }} + {{- tpl (toYaml .Values._apiServer.waitForMigrations.env) $ | nindent 12 }} {{- end }} {{- end }} - {{- if .Values.apiServer.extraInitContainers }} - {{- toYaml .Values.apiServer.extraInitContainers | nindent 8 }} + {{- if .Values._apiServer.extraInitContainers }} + {{- toYaml .Values._apiServer.extraInitContainers | nindent 8 }} {{- end }} containers: - name: api-server @@ -161,13 +161,13 @@ spec: {{- if $containerLifecycleHooks }} lifecycle: {{- tpl (toYaml $containerLifecycleHooks) . | nindent 12 }} {{- end }} - {{- if .Values.apiServer.command }} - command: {{ tpl (toYaml .Values.apiServer.command) . | nindent 12 }} + {{- if .Values._apiServer.command }} + command: {{ tpl (toYaml .Values._apiServer.command) . | nindent 12 }} {{- end }} - {{- if .Values.apiServer.args }} - args: {{- tpl (toYaml .Values.apiServer.args) . | nindent 12 }} + {{- if .Values._apiServer.args }} + args: {{- tpl (toYaml .Values._apiServer.args) . | nindent 12 }} {{- end }} - resources: {{- toYaml .Values.apiServer.resources | nindent 12 }} + resources: {{- toYaml .Values._apiServer.resources | nindent 12 }} volumeMounts: {{- include "airflow_config_mount" . | nindent 12 }} {{- if .Values.logs.persistence.enabled }} @@ -177,48 +177,48 @@ spec: {{- if .Values.volumeMounts }} {{- toYaml .Values.volumeMounts | nindent 12 }} {{- end }} - {{- if .Values.apiServer.extraVolumeMounts }} - {{- tpl (toYaml .Values.apiServer.extraVolumeMounts) . | nindent 12 }} + {{- if .Values._apiServer.extraVolumeMounts }} + {{- tpl (toYaml .Values._apiServer.extraVolumeMounts) . 
| nindent 12 }} {{- end }} ports: - name: api-server - containerPort: {{ .Values.ports.apiServer }} + containerPort: {{ .Values.ports._apiServer }} livenessProbe: httpGet: path: /public/version - port: {{ .Values.ports.apiServer }} - scheme: {{ .Values.apiServer.livenessProbe.scheme | default "http" }} - initialDelaySeconds: {{ .Values.apiServer.livenessProbe.initialDelaySeconds }} - timeoutSeconds: {{ .Values.apiServer.livenessProbe.timeoutSeconds }} - failureThreshold: {{ .Values.apiServer.livenessProbe.failureThreshold }} - periodSeconds: {{ .Values.apiServer.livenessProbe.periodSeconds }} + port: {{ .Values.ports._apiServer }} + scheme: {{ .Values._apiServer.livenessProbe.scheme | default "http" }} + initialDelaySeconds: {{ .Values._apiServer.livenessProbe.initialDelaySeconds }} + timeoutSeconds: {{ .Values._apiServer.livenessProbe.timeoutSeconds }} + failureThreshold: {{ .Values._apiServer.livenessProbe.failureThreshold }} + periodSeconds: {{ .Values._apiServer.livenessProbe.periodSeconds }} readinessProbe: httpGet: path: /public/version - port: {{ .Values.ports.apiServer }} - scheme: {{ .Values.apiServer.readinessProbe.scheme | default "http" }} - initialDelaySeconds: {{ .Values.apiServer.readinessProbe.initialDelaySeconds }} - timeoutSeconds: {{ .Values.apiServer.readinessProbe.timeoutSeconds }} - failureThreshold: {{ .Values.apiServer.readinessProbe.failureThreshold }} - periodSeconds: {{ .Values.apiServer.readinessProbe.periodSeconds }} + port: {{ .Values.ports._apiServer }} + scheme: {{ .Values._apiServer.readinessProbe.scheme | default "http" }} + initialDelaySeconds: {{ .Values._apiServer.readinessProbe.initialDelaySeconds }} + timeoutSeconds: {{ .Values._apiServer.readinessProbe.timeoutSeconds }} + failureThreshold: {{ .Values._apiServer.readinessProbe.failureThreshold }} + periodSeconds: {{ .Values._apiServer.readinessProbe.periodSeconds }} startupProbe: httpGet: path: /public/version - port: {{ .Values.ports.apiServer }} - scheme: {{ .Values.apiServer.startupProbe.scheme | default "http" }} - timeoutSeconds: {{ .Values.apiServer.startupProbe.timeoutSeconds }} - failureThreshold: {{ .Values.apiServer.startupProbe.failureThreshold }} - periodSeconds: {{ .Values.apiServer.startupProbe.periodSeconds }} + port: {{ .Values.ports._apiServer }} + scheme: {{ .Values._apiServer.startupProbe.scheme | default "http" }} + timeoutSeconds: {{ .Values._apiServer.startupProbe.timeoutSeconds }} + failureThreshold: {{ .Values._apiServer.startupProbe.failureThreshold }} + periodSeconds: {{ .Values._apiServer.startupProbe.periodSeconds }} envFrom: {{- include "custom_airflow_environment_from" . | default "\n []" | indent 10 }} env: {{- include "custom_airflow_environment" . | indent 10 }} {{- include "standard_airflow_environment" . | indent 10 }} - {{- include "container_extra_envs" (list . .Values.apiServer.env) | indent 10 }} + {{- include "container_extra_envs" (list . .Values._apiServer.env) | indent 10 }} {{- if and (.Values.dags.gitSync.enabled) (not .Values.dags.persistence.enabled) (semverCompare "<2.0.0" .Values.airflowVersion) }} {{- include "git_sync_container" . | nindent 8 }} {{- end }} - {{- if .Values.apiServer.extraContainers }} - {{- tpl (toYaml .Values.apiServer.extraContainers) . | nindent 8 }} + {{- if .Values._apiServer.extraContainers }} + {{- tpl (toYaml .Values._apiServer.extraContainers) . 
| nindent 8 }} {{- end }} volumes: - name: config @@ -234,7 +234,7 @@ spec: {{- if .Values.volumes }} {{- toYaml .Values.volumes | nindent 8 }} {{- end }} - {{- if .Values.apiServer.extraVolumes }} - {{- tpl (toYaml .Values.apiServer.extraVolumes) . | nindent 8 }} + {{- if .Values._apiServer.extraVolumes }} + {{- tpl (toYaml .Values._apiServer.extraVolumes) . | nindent 8 }} {{- end }} {{- end }} diff --git a/chart/templates/api-server/api-server-networkpolicy.yaml b/chart/templates/api-server/api-server-networkpolicy.yaml index af4601811200e..d648ea4baa25b 100644 --- a/chart/templates/api-server/api-server-networkpolicy.yaml +++ b/chart/templates/api-server/api-server-networkpolicy.yaml @@ -32,8 +32,8 @@ metadata: release: {{ .Release.Name }} chart: "{{ .Chart.Name }}-{{ .Chart.Version }}" heritage: {{ .Release.Service }} - {{- if or (.Values.labels) (.Values.apiServer.labels) }} - {{- mustMerge .Values.apiServer.labels .Values.labels | toYaml | nindent 4 }} + {{- if or (.Values.labels) (.Values._apiServer.labels) }} + {{- mustMerge .Values._apiServer.labels .Values.labels | toYaml | nindent 4 }} {{- end }} spec: podSelector: @@ -43,11 +43,11 @@ spec: release: {{ .Release.Name }} policyTypes: - Ingress - {{- if .Values.apiServer.networkPolicy.ingress.from }} + {{- if .Values._apiServer.networkPolicy.ingress.from }} ingress: - - from: {{- toYaml .Values.apiServer.networkPolicy.ingress.from | nindent 6 }} + - from: {{- toYaml .Values._apiServer.networkPolicy.ingress.from | nindent 6 }} ports: - {{ range .Values.apiServer.networkPolicy.ingress.ports }} + {{ range .Values._apiServer.networkPolicy.ingress.ports }} - {{- range $key, $val := . }} {{ $key }}: {{ tpl (toString $val) $ }} diff --git a/chart/templates/api-server/api-server-poddisruptionbudget.yaml b/chart/templates/api-server/api-server-poddisruptionbudget.yaml index 7d0b162e41ea9..c8d9249e4acef 100644 --- a/chart/templates/api-server/api-server-poddisruptionbudget.yaml +++ b/chart/templates/api-server/api-server-poddisruptionbudget.yaml @@ -21,7 +21,7 @@ ## Airflow api-server PodDisruptionBudget ################################# {{- if semverCompare ">=3.0.0" .Values.airflowVersion }} -{{- if .Values.apiServer.podDisruptionBudget.enabled }} +{{- if .Values._apiServer.podDisruptionBudget.enabled }} apiVersion: policy/v1 kind: PodDisruptionBudget metadata: @@ -32,8 +32,8 @@ metadata: release: {{ .Release.Name }} chart: {{ .Chart.Name }} heritage: {{ .Release.Service }} - {{- if or (.Values.labels) (.Values.apiServer.labels) }} - {{- mustMerge .Values.apiServer.labels .Values.labels | toYaml | nindent 4 }} + {{- if or (.Values.labels) (.Values._apiServer.labels) }} + {{- mustMerge .Values._apiServer.labels .Values.labels | toYaml | nindent 4 }} {{- end }} spec: selector: @@ -41,6 +41,6 @@ spec: tier: airflow component: api-server release: {{ .Release.Name }} - {{- toYaml .Values.apiServer.podDisruptionBudget.config | nindent 2 }} + {{- toYaml .Values._apiServer.podDisruptionBudget.config | nindent 2 }} {{- end }} {{- end }} diff --git a/chart/templates/api-server/api-server-service.yaml b/chart/templates/api-server/api-server-service.yaml index 0a360aee08539..71ad37f2ff7a9 100644 --- a/chart/templates/api-server/api-server-service.yaml +++ b/chart/templates/api-server/api-server-service.yaml @@ -31,29 +31,29 @@ metadata: release: {{ .Release.Name }} chart: "{{ .Chart.Name }}-{{ .Chart.Version }}" heritage: {{ .Release.Service }} - {{- if or (.Values.labels) (.Values.apiServer.labels) }} - {{- mustMerge .Values.apiServer.labels 
.Values.labels | toYaml | nindent 4 }} + {{- if or (.Values.labels) (.Values._apiServer.labels) }} + {{- mustMerge .Values._apiServer.labels .Values.labels | toYaml | nindent 4 }} {{- end }} - {{- with .Values.apiServer.service.annotations }} + {{- with .Values._apiServer.service.annotations }} annotations: {{- toYaml . | nindent 4 }} {{- end }} spec: - type: {{ .Values.apiServer.service.type }} + type: {{ .Values._apiServer.service.type }} selector: tier: airflow component: api-server release: {{ .Release.Name }} ports: - {{ range .Values.apiServer.service.ports }} + {{ range .Values._apiServer.service.ports }} - {{- range $key, $val := . }} {{ $key }}: {{ tpl (toString $val) $ }} {{- end }} {{- end }} - {{- if .Values.apiServer.service.loadBalancerIP }} - loadBalancerIP: {{ .Values.apiServer.service.loadBalancerIP }} + {{- if .Values._apiServer.service.loadBalancerIP }} + loadBalancerIP: {{ .Values._apiServer.service.loadBalancerIP }} {{- end }} - {{- if .Values.apiServer.service.loadBalancerSourceRanges }} - loadBalancerSourceRanges: {{- toYaml .Values.apiServer.service.loadBalancerSourceRanges | nindent 4 }} + {{- if .Values._apiServer.service.loadBalancerSourceRanges }} + loadBalancerSourceRanges: {{- toYaml .Values._apiServer.service.loadBalancerSourceRanges | nindent 4 }} {{- end }} {{- end }} diff --git a/chart/templates/api-server/api-server-serviceaccount.yaml b/chart/templates/api-server/api-server-serviceaccount.yaml index 3b864d01602fa..b797a7caadf2c 100644 --- a/chart/templates/api-server/api-server-serviceaccount.yaml +++ b/chart/templates/api-server/api-server-serviceaccount.yaml @@ -20,22 +20,22 @@ ###################################### ## Airflow api-server ServiceAccount ###################################### -{{- if and .Values.apiServer.serviceAccount.create (semverCompare ">=3.0.0" .Values.airflowVersion) }} +{{- if and .Values._apiServer.serviceAccount.create (semverCompare ">=3.0.0" .Values.airflowVersion) }} apiVersion: v1 kind: ServiceAccount -automountServiceAccountToken: {{ .Values.apiServer.serviceAccount.automountServiceAccountToken }} +automountServiceAccountToken: {{ .Values._apiServer.serviceAccount.automountServiceAccountToken }} metadata: - name: {{ include "apiServer.serviceAccountName" . }} + name: {{ include "_apiServer.serviceAccountName" . }} labels: tier: airflow component: api-server release: {{ .Release.Name }} chart: "{{ .Chart.Name }}-{{ .Chart.Version }}" heritage: {{ .Release.Service }} - {{- if or (.Values.labels) (.Values.apiServer.labels) }} - {{- mustMerge .Values.apiServer.labels .Values.labels | toYaml | nindent 4 }} + {{- if or (.Values.labels) (.Values._apiServer.labels) }} + {{- mustMerge .Values._apiServer.labels .Values.labels | toYaml | nindent 4 }} {{- end }} - {{- with .Values.apiServer.serviceAccount.annotations }} + {{- with .Values._apiServer.serviceAccount.annotations }} annotations: {{- toYaml . | nindent 4 }} {{- end }} {{- end }} diff --git a/chart/templates/configmaps/configmap.yaml b/chart/templates/configmaps/configmap.yaml index 119e43cce0a5e..ab4f13d918ad4 100644 --- a/chart/templates/configmaps/configmap.yaml +++ b/chart/templates/configmaps/configmap.yaml @@ -42,7 +42,7 @@ data: {{- if semverCompare ">=3.0.0" .Values.airflowVersion -}} {{- $config := merge .Values.config ( dict "workers" dict )}} {{- if not (hasKey $config.workers "execution_api_server_url") -}} - {{- $_ := set $config.workers "execution_api_server_url" (printf "http://%s-api-server:%d/execution/" (include "airflow.fullname" .) 
(int .Values.ports.apiServer)) -}} + {{- $_ := set $config.workers "execution_api_server_url" (printf "http://%s-api-server:%d/execution/" (include "airflow.fullname" .) (int .Values.ports._apiServer)) -}} {{- end -}} {{- end -}} # These are system-specified config overrides. diff --git a/chart/templates/scheduler/scheduler-serviceaccount.yaml b/chart/templates/scheduler/scheduler-serviceaccount.yaml index 641fdb82a5e7a..0f4f8cfaa67e0 100644 --- a/chart/templates/scheduler/scheduler-serviceaccount.yaml +++ b/chart/templates/scheduler/scheduler-serviceaccount.yaml @@ -23,7 +23,7 @@ {{- if and .Values.scheduler.enabled .Values.scheduler.serviceAccount.create }} apiVersion: v1 kind: ServiceAccount -{{- if contains "CeleryExecutor" .Values.executor }} +{{- if eq .Values.executor "CeleryExecutor" }} automountServiceAccountToken: {{ .Values.scheduler.serviceAccount.automountServiceAccountToken }} {{- end }} metadata: diff --git a/chart/values.schema.json b/chart/values.schema.json index 036dc63617a52..4564f43dd26fd 100644 --- a/chart/values.schema.json +++ b/chart/values.schema.json @@ -657,7 +657,7 @@ } }, "airflowLocalSettings": { - "description": "`airflow_local_settings` file as a string (can be templated).", + "description": "`airflow_local_settings` file as a string (templated).", "type": [ "string", "null" @@ -1102,7 +1102,7 @@ } }, "extraEnv": { - "description": "Extra env 'items' that will be added to the definition of Airflow containers; a string is expected (can be templated).", + "description": "Extra env 'items' that will be added to the definition of Airflow containers; a string is expected (templated).", "type": [ "null", "string" @@ -1114,7 +1114,7 @@ ] }, "extraEnvFrom": { - "description": "Extra envFrom 'items' that will be added to the definition of Airflow containers; a string is expected (can be templated).", + "description": "Extra envFrom 'items' that will be added to the definition of Airflow containers; a string is expected (templated).", "type": [ "null", "string" @@ -1168,7 +1168,7 @@ "x-docsSection": "Kubernetes", "default": {}, "additionalProperties": { - "description": "Name of the secret (can be templated).", + "description": "Name of the secret (templated).", "type": "object", "minProperties": 1, "additionalProperties": false, @@ -1199,11 +1199,11 @@ "default": true }, "data": { - "description": "Content **as string** for the 'data' item of the secret (can be templated)", + "description": "Content **as string** for the 'data' item of the secret (templated)", "type": "string" }, "stringData": { - "description": "Content **as string** for the 'stringData' item of the secret (can be templated)", + "description": "Content **as string** for the 'stringData' item of the secret (templated)", "type": "string" } } @@ -1223,7 +1223,7 @@ "x-docsSection": "Kubernetes", "default": {}, "additionalProperties": { - "description": "Name of the configMap (can be templated).", + "description": "Name of the configMap (templated).", "type": "object", "minProperties": 1, "additionalProperties": false, @@ -1250,7 +1250,7 @@ "default": true }, "data": { - "description": "Content **as string** for the 'data' item of the configmap (can be templated)", + "description": "Content **as string** for the 'data' item of the configmap (templated)", "type": "string" } } @@ -1790,7 +1790,7 @@ "default": "100Gi" }, "storageClassName": { - "description": "If using a custom StorageClass, pass name ref to all StatefulSets here (can be templated).", + "description": "If using a custom StorageClass, pass 
name ref to all StatefulSets here (templated).", "type": [ "string", "null" @@ -3161,7 +3161,7 @@ "description": "PersistentVolumeClaim retention policy to be used in the lifecycle of a StatefulSet" }, "storageClassName": { - "description": "If using a custom StorageClass, pass name ref to all StatefulSets here (can be templated).", + "description": "If using a custom StorageClass, pass name ref to all StatefulSets here (templated).", "type": [ "string", "null" @@ -4676,8 +4676,8 @@ } } }, - "apiServer": { - "description": "Airflow API server settings.", + "_apiServer": { + "description": "Airflow API server settings. Experimental / for dev purpose only.", "type": "object", "x-docsSection": "API Server", "additionalProperties": false, @@ -4949,7 +4949,7 @@ }, "default": [ { - "port": "{{ .Values.ports.apiServer }}" + "port": "{{ .Values.ports._apiServer }}" } ], "examples": [ @@ -5176,7 +5176,7 @@ "default": [ { "name": "api-server", - "port": "{{ .Values.ports.apiServer }}" + "port": "{{ .Values.ports._apiServer }}" } ], "examples": [ @@ -5884,7 +5884,7 @@ } }, "webserverConfig": { - "description": "This string (can be templated) will be mounted into the Airflow webserver as a custom `webserver_config.py`. You can bake a `webserver_config.py` in to your image instead or specify a configmap containing the webserver_config.py.", + "description": "This string (templated) will be mounted into the Airflow webserver as a custom `webserver_config.py`. You can bake a `webserver_config.py` in to your image instead or specify a configmap containing the webserver_config.py.", "type": [ "string", "null" @@ -7736,7 +7736,7 @@ "default": "1Gi" }, "storageClassName": { - "description": "If using a custom StorageClass, pass name ref to all StatefulSets here (can be templated).", + "description": "If using a custom StorageClass, pass name ref to all StatefulSets here (templated).", "type": [ "string", "null" @@ -8175,7 +8175,7 @@ "type": "integer", "default": 8080 }, - "apiServer": { + "_apiServer": { "description": "API server port.", "type": "integer", "default": 9091 @@ -8668,7 +8668,7 @@ "default": "1Gi" }, "storageClassName": { - "description": "If using a custom StorageClass, pass name here (can be templated).", + "description": "If using a custom StorageClass, pass name here (templated).", "type": [ "string", "null" @@ -8922,7 +8922,7 @@ ] }, "envFrom": { - "description": "Extra envFrom 'items' that will be added to the definition of Airflow gitSync containers; a string or array are expected (can be templated).", + "description": "Extra envFrom 'items' that will be added to the definition of Airflow gitSync containers; a string or array are expected (templated).", "type": [ "null", "string" @@ -8977,7 +8977,7 @@ "default": "100Gi" }, "storageClassName": { - "description": "If using a custom StorageClass, pass name here (can be templated).", + "description": "If using a custom StorageClass, pass name here (templated).", "type": [ "string", "null" diff --git a/chart/values.yaml b/chart/values.yaml index 852d6c2c792a6..89718d2a34e37 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -160,7 +160,7 @@ ingress: # The hostnames or hosts configuration for the web Ingress hosts: [] - # # The hostname for the web Ingress (can be templated) + # # The hostname for the web Ingress (templated) # - name: "" # # configs for web Ingress TLS # tls: @@ -204,7 +204,7 @@ ingress: # The hostnames or hosts configuration for the flower Ingress hosts: [] - # # The hostname for the flower Ingress (can be templated) + 
# # The hostname for the flower Ingress (templated) # - name: "" # tls: # # Enable TLS termination for the flower Ingress @@ -241,7 +241,7 @@ ingress: # The hostnames or hosts configuration for the statsd Ingress hosts: [] - # # The hostname for the statsd Ingress (can be templated) + # # The hostname for the statsd Ingress (templated) # - name: "" # tls: # # Enable TLS termination for the statsd Ingress @@ -271,7 +271,7 @@ ingress: # The hostnames or hosts configuration for the pgbouncer Ingress hosts: [] - # # The hostname for the statsd Ingress (can be templated) + # # The hostname for the statsd Ingress (templated) # - name: "" # tls: # # Enable TLS termination for the pgbouncer Ingress @@ -295,7 +295,7 @@ airflowPodAnnotations: {} # main Airflow configmap airflowConfigAnnotations: {} -# `airflow_local_settings` file as a string (can be templated). +# `airflow_local_settings` file as a string (templated). airflowLocalSettings: |- {{- if semverCompare ">=2.2.0" .Values.airflowVersion }} {{- if not (or .Values.webserverSecretKey .Values.webserverSecretKeySecretName) }} @@ -386,9 +386,9 @@ priorityClasses: [] # Extra secrets that will be managed by the chart # (You can use them with extraEnv or extraEnvFrom or some of the extraVolumes values). # The format for secret data is "key/value" where -# * key (can be templated) is the name of the secret that will be created +# * key (templated) is the name of the secret that will be created # * value: an object with the standard 'data' or 'stringData' key (or both). -# The value associated with those keys must be a string (can be templated) +# The value associated with those keys must be a string (templated) extraSecrets: {} # eg: # extraSecrets: @@ -413,9 +413,9 @@ extraSecrets: {} # Extra ConfigMaps that will be managed by the chart # (You can use them with extraEnv or extraEnvFrom or some of the extraVolumes values). # The format for configmap data is "key/value" where -# * key (can be templated) is the name of the configmap that will be created +# * key (templated) is the name of the configmap that will be created # * value: an object with the standard 'data' key. -# The value associated with this keys must be a string (can be templated) +# The value associated with this keys must be a string (templated) extraConfigMaps: {} # eg: # extraConfigMaps: @@ -427,7 +427,7 @@ extraConfigMaps: {} # AIRFLOW_VAR_KUBERNETES_NAMESPACE: "{{ .Release.Namespace }}" # Extra env 'items' that will be added to the definition of airflow containers -# a string is expected (can be templated). +# a string is expected (templated). # TODO: difference from `env`? This is a templated string. Probably should template `env` and remove this. extraEnv: ~ # eg: @@ -436,7 +436,7 @@ extraEnv: ~ # value: 'True' # Extra envFrom 'items' that will be added to the definition of airflow containers -# A string is expected (can be templated). +# A string is expected (templated). 
extraEnvFrom: ~ # eg: # extraEnvFrom: | @@ -1246,7 +1246,7 @@ migrateDatabaseJob: applyCustomEnv: true env: [] -apiServer: +_apiServer: # Labels specific to workers objects and pods labels: {} @@ -1274,7 +1274,7 @@ apiServer: annotations: {} ports: - name: api-server - port: "{{ .Values.ports.apiServer }}" + port: "{{ .Values.ports._apiServer }}" loadBalancerIP: ~ ## Limit load balancer source ips to list of CIDRs @@ -1313,7 +1313,7 @@ apiServer: from: [] # Ports for webserver NetworkPolicy ingress (if `from` is set) ports: - - port: "{{ .Values.ports.apiServer }}" + - port: "{{ .Values.ports._apiServer }}" resources: {} # limits: @@ -1505,7 +1505,7 @@ webserver: extraVolumes: [] extraVolumeMounts: [] - # This string (can be templated) will be mounted into the Airflow Webserver + # This string (templated) will be mounted into the Airflow Webserver # as a custom webserver_config.py. You can bake a webserver_config.py in to # your image instead or specify a configmap containing the # webserver_config.py. @@ -2529,7 +2529,7 @@ ports: statsdScrape: 9102 pgbouncer: 6543 pgbouncerScrape: 9127 - apiServer: 9091 + _apiServer: 9091 # Define any ResourceQuotas for namespace quotas: {} diff --git a/dev/breeze/src/airflow_breeze/commands/kubernetes_commands.py b/dev/breeze/src/airflow_breeze/commands/kubernetes_commands.py index 82450192422b4..ac6d070a4d4e3 100644 --- a/dev/breeze/src/airflow_breeze/commands/kubernetes_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/kubernetes_commands.py @@ -42,7 +42,12 @@ option_verbose, ) from airflow_breeze.commands.production_image_commands import run_build_production_image -from airflow_breeze.global_constants import ALLOWED_EXECUTORS, ALLOWED_KUBERNETES_VERSIONS +from airflow_breeze.global_constants import ( + ALLOWED_EXECUTORS, + ALLOWED_KUBERNETES_VERSIONS, + CELERY_EXECUTOR, + KUBERNETES_EXECUTOR, +) from airflow_breeze.params.build_prod_params import BuildProdParams from airflow_breeze.utils.ci_group import ci_group from airflow_breeze.utils.click_utils import BreezeGroup @@ -596,7 +601,7 @@ def _rebuild_k8s_image( COPY --chown=airflow:0 airflow/example_dags/ /opt/airflow/dags/ -COPY --chown=airflow:0 providers/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/ /opt/airflow/pod_templates/ +COPY --chown=airflow:0 providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/ /opt/airflow/pod_templates/ ENV GUNICORN_CMD_ARGS='--preload' AIRFLOW__WEBSERVER__WORKER_REFRESH_INTERVAL=0 """ @@ -782,6 +787,12 @@ def upload_k8s_image( if return_code == 0: get_console().print("\n[warning]NEXT STEP:[/][info] You might now deploy airflow by:\n") get_console().print("\nbreeze k8s deploy-airflow\n") + get_console().print( + "\n[warning]Note:[/]\nIf you want to run tests with [info]--executor KubernetesExecutor[/], you should deploy airflow with [info]--multi-namespace-mode --executor KubernetesExecutor[/] flag.\n" + ) + get_console().print( + "\nbreeze k8s deploy-airflow --multi-namespace-mode --executor KubernetesExecutor\n" + ) sys.exit(return_code) @@ -1406,6 +1417,31 @@ def _get_parallel_test_args( return combo_titles, combos, pytest_args, short_combo_titles +def _is_deployed_with_same_executor(python: str, kubernetes_version: str, executor: str) -> bool: + """Check if the current cluster is deployed with the same executor that the current tests are using. + + This is especially useful when running tests with executors like KubernetesExecutor, CeleryExecutor, etc. 
+ It verifies by checking the label of the airflow-scheduler deployment. + """ + result = run_command_with_k8s_env( + [ + "kubectl", + "get", + "deployment", + "-n", + "airflow", + "airflow-scheduler", + "-o", + "jsonpath='{.metadata.labels.executor}'", + ], + python=python, + kubernetes_version=kubernetes_version, + capture_output=True, + check=False, + ) + return executor == result.stdout.decode().strip().replace("'", "") + + def _run_tests( python: str, kubernetes_version: str, @@ -1422,7 +1458,17 @@ def _run_tests( extra_shell_args.append("--no-rcs") elif shell_binary.endswith("bash"): extra_shell_args.extend(["--norc", "--noprofile"]) - the_tests: list[str] = ["kubernetes_tests/"] + if ( + executor == KUBERNETES_EXECUTOR or executor == CELERY_EXECUTOR + ) and not _is_deployed_with_same_executor(python, kubernetes_version, executor): + get_console(output=output).print( + f"[warning]{executor} not deployed. Please deploy airflow with {executor} first." + ) + get_console(output=output).print( + f"[info]You can deploy airflow with {executor} by running:[/]\nbreeze k8s configure-cluster\nbreeze k8s deploy-airflow --multi-namespace-mode --executor {executor}" + ) + return 1, f"Tests {kubectl_cluster_name}" + the_tests: list[str] = ["kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor"] command_to_run = " ".join([quote(arg) for arg in ["uv", "run", "pytest", *the_tests, *test_args]]) get_console(output).print(f"[info] Command to run:[/] {command_to_run}") result = run_command( diff --git a/dev/breeze/src/airflow_breeze/commands/production_image_commands.py b/dev/breeze/src/airflow_breeze/commands/production_image_commands.py index 0f5e95f54b79b..3076b677c6aaa 100644 --- a/dev/breeze/src/airflow_breeze/commands/production_image_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/production_image_commands.py @@ -81,7 +81,11 @@ option_airflow_constraints_location, option_airflow_constraints_mode_prod, ) -from airflow_breeze.global_constants import ALLOWED_INSTALLATION_METHODS, DEFAULT_EXTRAS +from airflow_breeze.global_constants import ( + ALLOWED_INSTALLATION_METHODS, + CONSTRAINTS_SOURCE_PROVIDERS, + DEFAULT_EXTRAS, +) from airflow_breeze.params.build_prod_params import BuildProdParams from airflow_breeze.utils.ci_group import ci_group from airflow_breeze.utils.click_utils import BreezeGroup @@ -330,6 +334,10 @@ def run_build(prod_image_params: BuildProdParams) -> None: get_console().print(f"[error]Error when building image! 
{info}") sys.exit(return_code) + if not install_airflow_version and not airflow_constraints_location: + get_console().print(f"[yellow]Using {CONSTRAINTS_SOURCE_PROVIDERS} constraints mode[/]") + airflow_constraints_mode = CONSTRAINTS_SOURCE_PROVIDERS + perform_environment_checks() check_remote_ghcr_io_commands() base_build_params = BuildProdParams( diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index 7df3b551a5b7d..0296c996b9c05 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -148,9 +148,13 @@ START_AIRFLOW_DEFAULT_ALLOWED_EXECUTOR = START_AIRFLOW_ALLOWED_EXECUTORS[0] ALLOWED_CELERY_EXECUTORS = [CELERY_EXECUTOR, CELERY_K8S_EXECUTOR] +CONSTRAINTS_SOURCE_PROVIDERS = "constraints-source-providers" +CONSTRAINTS = "constraints" +CONSTRAINTS_NO_PROVIDERS = "constraints-no-providers" + ALLOWED_KIND_OPERATIONS = ["start", "stop", "restart", "status", "deploy", "test", "shell", "k9s"] -ALLOWED_CONSTRAINTS_MODES_CI = ["constraints-source-providers", "constraints", "constraints-no-providers"] -ALLOWED_CONSTRAINTS_MODES_PROD = ["constraints", "constraints-no-providers", "constraints-source-providers"] +ALLOWED_CONSTRAINTS_MODES_CI = [CONSTRAINTS_SOURCE_PROVIDERS, CONSTRAINTS, CONSTRAINTS_NO_PROVIDERS] +ALLOWED_CONSTRAINTS_MODES_PROD = [CONSTRAINTS, CONSTRAINTS_NO_PROVIDERS, CONSTRAINTS_SOURCE_PROVIDERS] ALLOWED_CELERY_BROKERS = ["rabbitmq", "redis"] DEFAULT_CELERY_BROKER = ALLOWED_CELERY_BROKERS[1] diff --git a/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py b/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py index 6b705148c94cb..98936310c3c37 100644 --- a/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py +++ b/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py @@ -1231,6 +1231,7 @@ def _regenerate_pyproject_toml(context: dict[str, Any], provider_details: Provid trim_blocks=True, keep_trailing_newline=True, ) + get_pyproject_toml_path.write_text(get_pyproject_toml_content) get_console().print( f"[info]Generated {get_pyproject_toml_path} for the {provider_details.provider_id} provider\n" diff --git a/dev/breeze/src/airflow_breeze/templates/pyproject_TEMPLATE.toml.jinja2 b/dev/breeze/src/airflow_breeze/templates/pyproject_TEMPLATE.toml.jinja2 index 5da149fa0d542..62560249e2331 100644 --- a/dev/breeze/src/airflow_breeze/templates/pyproject_TEMPLATE.toml.jinja2 +++ b/dev/breeze/src/airflow_breeze/templates/pyproject_TEMPLATE.toml.jinja2 @@ -68,7 +68,7 @@ classifiers = [ {% endfor %} "Topic :: System :: Monitoring", ] -requires-python = "~=3.9" +requires-python = "{{ REQUIRES_PYTHON }}" # The dependencies should be modified in place in the generated file # Any change in the dependencies is preserved when the file is regenerated diff --git a/dev/breeze/src/airflow_breeze/utils/packages.py b/dev/breeze/src/airflow_breeze/utils/packages.py index fec5ed898d697..022f093871e37 100644 --- a/dev/breeze/src/airflow_breeze/utils/packages.py +++ b/dev/breeze/src/airflow_breeze/utils/packages.py @@ -32,6 +32,7 @@ from airflow_breeze.global_constants import ( ALLOWED_PYTHON_MAJOR_MINOR_VERSIONS, + DEFAULT_PYTHON_MAJOR_MINOR_VERSION, PROVIDER_DEPENDENCIES, PROVIDER_RUNTIME_DATA_SCHEMA_PATH, REGULAR_DOC_PACKAGES, @@ -788,6 +789,12 @@ def get_provider_jinja_context( p for p in ALLOWED_PYTHON_MAJOR_MINOR_VERSIONS if p not in provider_details.excluded_python_versions ] 
cross_providers_dependencies = get_cross_provider_dependent_packages(provider_package_id=provider_id) + + # Most providers require the same python versions, but some may have exclusions + requires_python_version: str = f"~={DEFAULT_PYTHON_MAJOR_MINOR_VERSION}" + for excluded_python_version in provider_details.excluded_python_versions: + requires_python_version += f",!={excluded_python_version}" + context: dict[str, Any] = { "PROVIDER_ID": provider_details.provider_id, "PACKAGE_PIP_NAME": get_pip_package_name(provider_details.provider_id), @@ -825,6 +832,7 @@ def get_provider_jinja_context( "PIP_REQUIREMENTS_TABLE_RST": convert_pip_requirements_to_table( get_provider_requirements(provider_id), markdown=False ), + "REQUIRES_PYTHON": requires_python_version, } return context diff --git a/dev/moving_providers/move_providers.py b/dev/moving_providers/move_providers.py index 13c1ef20cb586..da1213e4ffbc1 100755 --- a/dev/moving_providers/move_providers.py +++ b/dev/moving_providers/move_providers.py @@ -134,7 +134,11 @@ def _do_stuff( if from_path.exists(): shutil.move(from_path, to_path) console.print(f"\n[yellow]Moved {from_path} -> {to_path}\n") - if remove_empty_parent_dir and len([path for path in from_path.parent.iterdir()]) == 0: + if ( + remove_empty_parent_dir + and from_path.exists() + and len([path for path in from_path.parent.iterdir()]) == 0 + ): console.print(f"\n[yellow]Removed also empty parent dir {from_path.parent}\n") from_path.parent.rmdir() return diff --git a/docs/.gitignore b/docs/.gitignore index 2125d28bc1c31..ce8d1e7b430fd 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -6,8 +6,10 @@ apache-airflow-providers-apache-beam apache-airflow-providers-apache-cassandra apache-airflow-providers-apache-drill apache-airflow-providers-apache-druid +apache-airflow-providers-apache-flink apache-airflow-providers-apache-hive apache-airflow-providers-apache-iceberg +apache-airflow-providers-apache-impala apache-airflow-providers-apache-kafka apache-airflow-providers-apache-kylin apache-airflow-providers-apache-livy @@ -18,6 +20,8 @@ apache-airflow-providers-apprise apache-airflow-providers-asana apache-airflow-providers-atlassian-jira apache-airflow-providers-celery +apache-airflow-providers-cloudant +apache-airflow-providers-cncf-kubernetes apache-airflow-providers-cohere apache-airflow-providers-common-compat apache-airflow-providers-common-io @@ -41,6 +45,7 @@ apache-airflow-providers-jdbc apache-airflow-providers-influxdb apache-airflow-providers-microsoft-mssql apache-airflow-providers-microsoft-psrp +apache-airflow-providers-microsoft-winrm apache-airflow-providers-mongo apache-airflow-providers-openlineage apache-airflow-providers-hashicorp diff --git a/docs/apache-airflow-providers-apache-flink/changelog.rst b/docs/apache-airflow-providers-apache-flink/changelog.rst deleted file mode 100644 index 07ffea0939e6f..0000000000000 --- a/docs/apache-airflow-providers-apache-flink/changelog.rst +++ /dev/null @@ -1,25 +0,0 @@ - - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. 
Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE - OVERWRITTEN WHEN PREPARING PACKAGES. - - .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE - `PROVIDER_CHANGELOG_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY - -.. include:: ../../providers/src/airflow/providers/apache/flink/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-apache-impala/changelog.rst b/docs/apache-airflow-providers-apache-impala/changelog.rst deleted file mode 100644 index ad7e0972ce927..0000000000000 --- a/docs/apache-airflow-providers-apache-impala/changelog.rst +++ /dev/null @@ -1,25 +0,0 @@ - - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE - OVERWRITTEN WHEN PREPARING PACKAGES. - - .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE - `PROVIDER_CHANGELOG_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY - -.. include:: ../../providers/src/airflow/providers/apache/impala/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-cloudant/changelog.rst b/docs/apache-airflow-providers-cloudant/changelog.rst deleted file mode 100644 index d969e082c17b2..0000000000000 --- a/docs/apache-airflow-providers-cloudant/changelog.rst +++ /dev/null @@ -1,25 +0,0 @@ - - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE - OVERWRITTEN WHEN PREPARING PACKAGES. - - .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE - `PROVIDER_CHANGELOG_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY - -.. 
include:: ../../providers/src/airflow/providers/cloudant/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-cncf-kubernetes/changelog.rst b/docs/apache-airflow-providers-cncf-kubernetes/changelog.rst deleted file mode 100644 index 6ad86cec6753c..0000000000000 --- a/docs/apache-airflow-providers-cncf-kubernetes/changelog.rst +++ /dev/null @@ -1,25 +0,0 @@ - - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE - OVERWRITTEN WHEN PREPARING PACKAGES. - - .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE - `PROVIDER_CHANGELOG_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY - -.. include:: ../../providers/src/airflow/providers/cncf/kubernetes/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-microsoft-winrm/changelog.rst b/docs/apache-airflow-providers-microsoft-winrm/changelog.rst deleted file mode 100644 index fb0faf44d10fb..0000000000000 --- a/docs/apache-airflow-providers-microsoft-winrm/changelog.rst +++ /dev/null @@ -1,25 +0,0 @@ - - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE - OVERWRITTEN WHEN PREPARING PACKAGES. - - .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE - `PROVIDER_CHANGELOG_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY - -.. include:: ../../providers/src/airflow/providers/microsoft/winrm/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-yandex/changelog.rst b/docs/apache-airflow-providers-yandex/changelog.rst deleted file mode 100644 index 9bcad616eb83d..0000000000000 --- a/docs/apache-airflow-providers-yandex/changelog.rst +++ /dev/null @@ -1,25 +0,0 @@ - - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. 
The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE - OVERWRITTEN WHEN PREPARING PACKAGES. - - .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE - `PROVIDER_CHANGELOG_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY - -.. include:: ../../providers/src/airflow/providers/yandex/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-yandex/commits.rst b/docs/apache-airflow-providers-yandex/commits.rst deleted file mode 100644 index 78d8068f45cbc..0000000000000 --- a/docs/apache-airflow-providers-yandex/commits.rst +++ /dev/null @@ -1,492 +0,0 @@ - - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE - OVERWRITTEN WHEN PREPARING PACKAGES. - - .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE - `PROVIDER_COMMITS_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY - - .. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME! - -Package apache-airflow-providers-yandex ------------------------------------------------------- - -This package is for Yandex, including: - - - `Yandex.Cloud `__ - - -This is detailed commit list of changes for versions provider package: ``yandex``. -For high-level changelog, see :doc:`package information including changelog `. - - - -4.0.0 -..... 
- -Latest change: 2024-12-19 - -================================================================================================= =========== ======================================================================================== -Commit Committed Subject -================================================================================================= =========== ======================================================================================== -`35b927fe17 `_ 2024-12-19 ``Update path of example dags in docs (#45069)`` -`4b38bed76c `_ 2024-12-16 ``Bump min version of Providers to 2.9 (#44956)`` -`e786c78f52 `_ 2024-12-07 ``Remove Provider Deprecations in Yandex provider (#44754)`` -`1275fec92f `_ 2024-11-24 ``Use Python 3.9 as target version for Ruff & Black rules (#44298)`` -`4dfae23532 `_ 2024-11-15 ``Update DAG example links in multiple providers documents (#44034)`` -`a53d9f6d25 `_ 2024-11-14 ``Prepare docs for Nov 1st wave of providers (#44011)`` -`857ca4c06c `_ 2024-10-09 ``Split providers out of the main "airflow/" tree into a UV workspace project (#42505)`` -================================================================================================= =========== ======================================================================================== - -3.12.0 -...... - -Latest change: 2024-08-19 - -================================================================================================= =========== ======================================================================= -Commit Committed Subject -================================================================================================= =========== ======================================================================= -`75fb7acbac `_ 2024-08-19 ``Prepare docs for Aug 2nd wave of providers (#41559)`` -`fcbff15bda `_ 2024-08-12 ``Bump minimum Airflow version in providers to Airflow 2.8.0 (#41396)`` -`2daa5bd01a `_ 2024-08-04 ``providers/yandex: fix typing (#40997)`` -`d23881c648 `_ 2024-08-03 ``Prepare docs for Aug 1st wave of providers (#41230)`` -`09a7bd1d58 `_ 2024-07-09 ``Prepare docs 1st wave July 2024 (#40644)`` -`a62bd83188 `_ 2024-06-27 ``Enable enforcing pydocstyle rule D213 in ruff. (#40448)`` -================================================================================================= =========== ======================================================================= - -3.11.2 -...... - -Latest change: 2024-06-22 - -================================================================================================= =========== ========================================================================= -Commit Committed Subject -================================================================================================= =========== ========================================================================= -`6e5ae26382 `_ 2024-06-22 ``Prepare docs 2nd wave June 2024 (#40273)`` -`0d5bb60981 `_ 2024-06-17 ``Fix typos in Providers docs and Yandex hook (#40277)`` -`53e6739e67 `_ 2024-06-01 ``Limit yandex provider to avoid mypy errors (#39990)`` -`8173693a70 `_ 2024-05-31 ``Remove upper-binding in yandex after dataproc issue is fixed (#39974)`` -`b8a83b2293 `_ 2024-05-31 ``Workaround new yandexcloud breaking dataproc integration (#39964)`` -================================================================================================= =========== ========================================================================= - -3.11.1 -...... 
- -Latest change: 2024-05-26 - -================================================================================================= =========== ========================================================================= -Commit Committed Subject -================================================================================================= =========== ========================================================================= -`34500f3a2f `_ 2024-05-26 ``Prepare docs 3rd wave May 2024 (#39738)`` -`e0dd075d1b `_ 2024-05-15 `` AIP-21: yandexcloud: rename files, emit deprecation warning (#39618)`` -`defe4590e9 `_ 2024-05-11 ``yandex provider: bump version for yq http client package (#39548)`` -`2b1a2f8d56 `_ 2024-05-11 ``Reapply templates for all providers (#39554)`` -`2c05187b07 `_ 2024-05-10 ``Faster 'airflow_version' imports (#39552)`` -`05945a47f3 `_ 2024-05-09 ``add doc about Yandex Query operator (#39445)`` -`73918925ed `_ 2024-05-08 ``Simplify 'airflow_version' imports (#39497)`` -================================================================================================= =========== ========================================================================= - -3.11.0 -...... - -Latest change: 2024-05-01 - -================================================================================================= =========== ======================================================================= -Commit Committed Subject -================================================================================================= =========== ======================================================================= -`fe4605a10e `_ 2024-05-01 ``Prepare docs 1st wave May 2024 (#39328)`` -`ead9b00f7c `_ 2024-04-25 ``Bump minimum Airflow version in providers to Airflow 2.7.0 (#39240)`` -================================================================================================= =========== ======================================================================= - -3.10.0 -...... - -Latest change: 2024-04-13 - -================================================================================================= =========== ================================================================== -Commit Committed Subject -================================================================================================= =========== ================================================================== -`f9dcc82fb6 `_ 2024-04-13 ``Prepare docs 1st wave (RC2) April 2024 (#38995)`` -`5fa80b6aea `_ 2024-04-10 ``Prepare docs 1st wave (RC1) April 2024 (#38863)`` -`a9a6976dd2 `_ 2024-03-28 ``docs: yandex provider grammatical improvements (#38589)`` -`30817a5c6d `_ 2024-03-22 ``support iam token from metadata, simplify code (#38411)`` -`390bec1c82 `_ 2024-03-20 ``Add Yandex Query support from Yandex.Cloud (#37458)`` -`0a74928894 `_ 2024-03-18 ``Bump ruff to 0.3.3 (#38240)`` -`c0b849ad2b `_ 2024-03-11 ``Avoid use of 'assert' outside of the tests (#37718)`` -`83316b8158 `_ 2024-03-04 ``Prepare docs 1st wave (RC1) March 2024 (#37876)`` -`5a0be392e6 `_ 2024-02-16 ``Add comment about versions updated by release manager (#37488)`` -================================================================================================= =========== ================================================================== - -3.9.0 -..... 
- -Latest change: 2024-02-12 - -================================================================================================= =========== ================================================================================== -Commit Committed Subject -================================================================================================= =========== ================================================================================== -`bfb054e9e8 `_ 2024-02-12 ``Prepare docs 1st wave of Providers February 2024 (#37326)`` -`08036e5df5 `_ 2024-02-08 ``D401 Support in Providers (simple) (#37258)`` -`cea58c1111 `_ 2024-02-02 ``fix: using endpoint from connection if not specified (#37076)`` -`3ec781946a `_ 2024-02-01 ``Add secrets-backends section into the Yandex provider yaml definition (#37065)`` -`0e752383a8 `_ 2024-01-31 ``docs: update description in airflow provider.yaml (#37096)`` -================================================================================================= =========== ================================================================================== - -3.8.0 -..... - -Latest change: 2024-01-26 - -================================================================================================= =========== ==================================================================================================================== -Commit Committed Subject -================================================================================================= =========== ==================================================================================================================== -`cead3da4a6 `_ 2024-01-26 ``Add docs for RC2 wave of providers for 2nd round of Jan 2024 (#37019)`` -`0b680c9492 `_ 2024-01-26 ``Revert "Provide the logger_name param in providers hooks in order to override the logger name (#36675)" (#37015)`` -`12ccb5f0ac `_ 2024-01-25 ``feat: add Yandex Cloud Lockbox secrets backend (#36449)`` -`2b4da0101f `_ 2024-01-22 ``Prepare docs 2nd wave of Providers January 2024 (#36945)`` -`6ff96af480 `_ 2024-01-18 ``Fix stacklevel in warnings.warn into the providers (#36831)`` -`6bd450da1e `_ 2024-01-10 ``Provide the logger_name param in providers hooks in order to override the logger name (#36675)`` -`19ebcac239 `_ 2024-01-07 ``Prepare docs 1st wave of Providers January 2024 (#36640)`` -`6937ae7647 `_ 2023-12-30 ``Speed up autocompletion of Breeze by simplifying provider state (#36499)`` -================================================================================================= =========== ==================================================================================================================== - -3.7.1 -..... 
- -Latest change: 2023-12-23 - -================================================================================================= =========== ================================================================================== -Commit Committed Subject -================================================================================================= =========== ================================================================================== -`b15d5578da `_ 2023-12-23 ``Re-apply updated version numbers to 2nd wave of providers in December (#36380)`` -`f5883d6e7b `_ 2023-12-23 ``Prepare 2nd wave of providers in December (#36373)`` -`cd476acd8f `_ 2023-12-11 ``Follow BaseHook connection fields method signature in child classes (#36086)`` -================================================================================================= =========== ================================================================================== - -3.7.0 -..... - -Latest change: 2023-12-08 - -================================================================================================= =========== ======================================================================= -Commit Committed Subject -================================================================================================= =========== ======================================================================= -`999b70178a `_ 2023-12-08 ``Prepare docs 1st wave of Providers December 2023 (#36112)`` -`d0918d77ee `_ 2023-12-07 ``Bump minimum Airflow version in providers to Airflow 2.6.0 (#36017)`` -`0b23d5601c `_ 2023-11-24 ``Prepare docs 2nd wave of Providers November 2023 (#35836)`` -`99534e47f3 `_ 2023-11-19 ``Use reproducible builds for provider packages (#35693)`` -`99df205f42 `_ 2023-11-16 ``Fix and reapply templates for provider documentation (#35686)`` -================================================================================================= =========== ======================================================================= - -3.6.0 -..... - -Latest change: 2023-11-08 - -================================================================================================= =========== ================================================================== -Commit Committed Subject -================================================================================================= =========== ================================================================== -`1b059c57d6 `_ 2023-11-08 ``Prepare docs 1st wave of Providers November 2023 (#35537)`` -`706878ec35 `_ 2023-11-04 ``Remove empty lines in generated changelog (#35436)`` -`052e26ad47 `_ 2023-11-04 ``Change security.rst to use includes in providers (#35435)`` -`09880741cb `_ 2023-11-03 ``Add configuration files for yandex (#35420)`` -`0b850a97e8 `_ 2023-11-03 ``Yandex dataproc deduce default service account (#35059)`` -`d1c58d86de `_ 2023-10-28 ``Prepare docs 3rd wave of Providers October 2023 - FIX (#35233)`` -`3592ff4046 `_ 2023-10-28 ``Prepare docs 3rd wave of Providers October 2023 (#35187)`` -`dd7ba3cae1 `_ 2023-10-19 ``Pre-upgrade 'ruff==0.0.292' changes in providers (#35053)`` -================================================================================================= =========== ================================================================== - -3.5.0 -..... 
- -Latest change: 2023-10-13 - -================================================================================================= =========== =============================================================== -Commit Committed Subject -================================================================================================= =========== =============================================================== -`e9987d5059 `_ 2023-10-13 ``Prepare docs 1st wave of Providers in October 2023 (#34916)`` -`0c8e30e43b `_ 2023-10-05 ``Bump min airflow version of providers (#34728)`` -================================================================================================= =========== =============================================================== - -3.4.0 -..... - -Latest change: 2023-08-26 - -================================================================================================= =========== ====================================================================== -Commit Committed Subject -================================================================================================= =========== ====================================================================== -`c077d19060 `_ 2023-08-26 ``Prepare docs for Aug 2023 3rd wave of Providers (#33730)`` -`2ae1c10bfa `_ 2023-08-23 ``add support for Yandex Dataproc cluster labels (#29811)`` -`2b43fa473f `_ 2023-08-22 ``Resume yandex provider (#33574)`` -`73b90c48b1 `_ 2023-07-21 ``Allow configuration to be contributed by providers (#32604)`` -`3878fe6fab `_ 2023-07-05 ``Remove spurious headers for provider changelogs (#32373)`` -`09d4718d3a `_ 2023-06-27 ``Improve provider documentation and README structure (#32125)`` -`8b146152d6 `_ 2023-06-20 ``Add note about dropping Python 3.7 for providers (#32015)`` -`a59076eaee `_ 2023-06-02 ``Add D400 pydocstyle check - Providers (#31427)`` -`abea189022 `_ 2023-05-18 ``Use '__version__' in providers not 'version' (#31393)`` -`0a30706aa7 `_ 2023-05-03 ``Use 'AirflowProviderDeprecationWarning' in providers (#30975)`` -`eef5bc7f16 `_ 2023-05-03 ``Add full automation for min Airflow version for providers (#30994)`` -`a7eb32a5b2 `_ 2023-04-30 ``Bump minimum Airflow version in providers (#30917)`` -`b4d6e83686 `_ 2023-04-19 ``Suspend Yandex provider due to protobuf limitation (#30667)`` -`d23a3bbed8 `_ 2023-04-04 ``Add mechanism to suspend providers (#30422)`` -================================================================================================= =========== ====================================================================== - -3.3.0 -..... - -Latest change: 2023-03-03 - -================================================================================================= =========== ===================================================================== -Commit Committed Subject -================================================================================================= =========== ===================================================================== -`fcd3c0149f `_ 2023-03-03 ``Prepare docs for 03/2023 wave of Providers (#29878)`` -`1768872a00 `_ 2023-02-22 ``support Yandex SDK feature "endpoint" (#29635)`` -`2b92c3c74d `_ 2023-01-05 ``Fix providers documentation formatting (#28754)`` -`c8e348dcb0 `_ 2022-12-05 ``Add automated version replacement in example dag indexes (#28090)`` -================================================================================================= =========== ===================================================================== - -3.2.0 -..... 
- -Latest change: 2022-11-26 - -================================================================================================= =========== ==================================================================================== -Commit Committed Subject -================================================================================================= =========== ==================================================================================== -`25bdbc8e67 `_ 2022-11-26 ``Updated docs for RC3 wave of providers (#27937)`` -`2e20e9f7eb `_ 2022-11-24 ``Prepare for follow-up relase for November providers (#27774)`` -`12c3c39d1a `_ 2022-11-15 ``pRepare docs for November 2022 wave of Providers (#27613)`` -`78b8ea2f22 `_ 2022-10-24 ``Move min airflow version to 2.3.0 for all providers (#27196)`` -`2a34dc9e84 `_ 2022-10-23 ``Enable string normalization in python formatting - providers (#27205)`` -`837e463ae8 `_ 2022-10-22 ``Allow no extra prefix in yandex hook (#27040)`` -`f8db64c35c `_ 2022-09-28 ``Update docs for September Provider's release (#26731)`` -`06acf40a43 `_ 2022-09-13 ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` -================================================================================================= =========== ==================================================================================== - -3.1.0 -..... - -Latest change: 2022-08-10 - -================================================================================================= =========== =============================================================================== -Commit Committed Subject -================================================================================================= =========== =============================================================================== -`e5ac6c7cfb `_ 2022-08-10 ``Prepare docs for new providers release (August 2022) (#25618)`` -`a61e0c1df7 `_ 2022-07-29 ``YandexCloud provider: Support new Yandex SDK features for DataProc (#25158)`` -`d2459a241b `_ 2022-07-13 ``Add documentation for July 2022 Provider's release (#25030)`` -`0de31bd73a `_ 2022-06-29 ``Move provider dependencies to inside provider folders (#24672)`` -`510a6bab45 `_ 2022-06-28 ``Remove 'hook-class-names' from provider.yaml (#24702)`` -`08b675cf66 `_ 2022-06-13 ``Fix links to sources for examples (#24386)`` -================================================================================================= =========== =============================================================================== - -3.0.0 -..... 
- -Latest change: 2022-06-09 - -================================================================================================= =========== ================================================================================== -Commit Committed Subject -================================================================================================= =========== ================================================================================== -`dcdcf3a2b8 `_ 2022-06-09 ``Update release notes for RC2 release of Providers for May 2022 (#24307)`` -`717a7588bc `_ 2022-06-07 ``Update package description to remove double min-airflow specification (#24292)`` -`aeabe994b3 `_ 2022-06-07 ``Prepare docs for May 2022 provider's release (#24231)`` -`333e98759b `_ 2022-06-07 ``Fix link broken after #24082 (#24276)`` -`027b707d21 `_ 2022-06-05 ``Add explanatory note for contributors about updating Changelog (#24229)`` -`65ad2aed26 `_ 2022-06-01 ``Migrate Yandex example DAGs to new design AIP-47 (#24082)`` -================================================================================================= =========== ================================================================================== - -2.2.3 -..... - -Latest change: 2022-03-22 - -================================================================================================= =========== ============================================================== -Commit Committed Subject -================================================================================================= =========== ============================================================== -`d7dbfb7e26 `_ 2022-03-22 ``Add documentation for bugfix release of Providers (#22383)`` -================================================================================================= =========== ============================================================== - -2.2.2 -..... - -Latest change: 2022-03-14 - -================================================================================================= =========== ==================================================================== -Commit Committed Subject -================================================================================================= =========== ==================================================================== -`16adc035b1 `_ 2022-03-14 ``Add documentation for Classifier release for March 2022 (#22226)`` -================================================================================================= =========== ==================================================================== - -2.2.1 -..... - -Latest change: 2022-03-07 - -================================================================================================= =========== ========================================================================== -Commit Committed Subject -================================================================================================= =========== ========================================================================== -`f5b96315fe `_ 2022-03-07 ``Add documentation for Feb Providers release (#22056)`` -`6126c4e40f `_ 2022-03-07 ``Fix spelling (#22054)`` -`d94fa37830 `_ 2022-02-08 ``Fixed changelog for January 2022 (delayed) provider's release (#21439)`` -`6c3a67d4fc `_ 2022-02-05 ``Add documentation for January 2021 providers release (#21257)`` -`cb73053211 `_ 2022-01-27 ``Add optional features in providers. 
(#21074)`` -`602abe8394 `_ 2022-01-20 ``Remove ':type' lines now sphinx-autoapi supports typehints (#20951)`` -================================================================================================= =========== ========================================================================== - -2.2.0 -..... - -Latest change: 2021-12-31 - -================================================================================================= =========== ============================================================================================================ -Commit Committed Subject -================================================================================================= =========== ============================================================================================================ -`f77417eb0d `_ 2021-12-31 ``Fix K8S changelog to be PyPI-compatible (#20614)`` -`97496ba2b4 `_ 2021-12-31 ``Update documentation for provider December 2021 release (#20523)`` -`d56e7b56bb `_ 2021-12-30 ``Fix template_fields type to have MyPy friendly Sequence type (#20571)`` -`a0821235fb `_ 2021-12-30 ``Use typed Context EVERYWHERE (#20565)`` -`6e51608f28 `_ 2021-12-16 ``Fix mypy for providers: elasticsearch, oracle, yandex (#20344)`` -`41c49c7ff6 `_ 2021-12-14 ``YandexCloud provider: Support new Yandex SDK features: log_group_id, user-agent, maven packages (#20103)`` -`9a469d813f `_ 2021-11-30 ``Capitalize names in docs (#19893)`` -`853576d901 `_ 2021-11-30 ``Update documentation for November 2021 provider's release (#19882)`` -`d9567eb106 `_ 2021-10-29 ``Prepare documentation for October Provider's release (#19321)`` -`f5ad26dcdd `_ 2021-10-21 ``Fixup string concatenations (#19099)`` -`840ea3efb9 `_ 2021-09-30 ``Update documentation for September providers release (#18613)`` -`ef037e7021 `_ 2021-09-29 ``Static start_date and default arg cleanup for misc. provider example DAGs (#18597)`` -`e25eea052f `_ 2021-09-19 ``Inclusive Language (#18349)`` -`1cb456cba1 `_ 2021-09-12 ``Add official download page for providers (#18187)`` -`046f02e5a7 `_ 2021-09-09 ``fix misspelling (#18121)`` -================================================================================================= =========== ============================================================================================================ - -2.1.0 -..... - -Latest change: 2021-08-30 - -================================================================================================= =========== ============================================================================= -Commit Committed Subject -================================================================================================= =========== ============================================================================= -`0a68588479 `_ 2021-08-30 ``Add August 2021 Provider's documentation (#17890)`` -`be75dcd39c `_ 2021-08-23 ``Update description about the new ''connection-types'' provider meta-data`` -`76ed2a49c6 `_ 2021-08-19 ``Import Hooks lazily individually in providers manager (#17682)`` -`e3089dd5d0 `_ 2021-08-02 ``Add autoscaling subcluster support and remove defaults (#17033)`` -`87f408b1e7 `_ 2021-07-26 ``Prepares docs for Rc2 release of July providers (#17116)`` -`0dbd0f420c `_ 2021-07-26 ``Remove/refactor default_args pattern for miscellaneous providers (#16872)`` -`b916b75079 `_ 2021-07-15 ``Prepare documentation for July release of providers. 
(#17015)`` -`866a601b76 `_ 2021-06-28 ``Removes pylint from our toolchain (#16682)`` -================================================================================================= =========== ============================================================================= - -2.0.0 -..... - -Latest change: 2021-06-18 - -================================================================================================= =========== ======================================================================= -Commit Committed Subject -================================================================================================= =========== ======================================================================= -`bbc627a3da `_ 2021-06-18 ``Prepares documentation for rc2 release of Providers (#16501)`` -`cbf8001d76 `_ 2021-06-16 ``Synchronizes updated changelog after buggfix release (#16464)`` -`1fba5402bb `_ 2021-06-15 ``More documentation update for June providers release (#16405)`` -`9c94b72d44 `_ 2021-06-07 ``Updated documentation for June 2021 provider release (#16294)`` -`1e647029e4 `_ 2021-06-01 ``Rename the main branch of the Airflow repo to be 'main' (#16149)`` -`37681bca00 `_ 2021-05-07 ``Auto-apply apply_default decorator (#15667)`` -`807ad32ce5 `_ 2021-05-01 ``Prepares provider release after PIP 21 compatibility (#15576)`` -`40a2476a5d `_ 2021-04-28 ``Adds interactivity when generating provider documentation. (#15518)`` -`a7ca1b3b0b `_ 2021-03-26 ``Fix Sphinx Issues with Docstrings (#14968)`` -`e172bd0e16 `_ 2021-03-22 ``Update docstrings to adhere to sphinx standards (#14918)`` -`68e4c4dcb0 `_ 2021-03-20 ``Remove Backport Providers (#14886)`` -`6e6526a0f6 `_ 2021-03-13 ``Update documentation for broken package releases (#14734)`` -================================================================================================= =========== ======================================================================= - -1.0.1 -..... - -Latest change: 2021-02-04 - -================================================================================================= =========== ========================================================= -Commit Committed Subject -================================================================================================= =========== ========================================================= -`88bdcfa0df `_ 2021-02-04 ``Prepare to release a new wave of providers. (#14013)`` -`ac2f72c98d `_ 2021-02-01 ``Implement provider versioning tools (#13767)`` -`3fd5ef3555 `_ 2021-01-21 ``Add missing logos for integrations (#13717)`` -`295d66f914 `_ 2020-12-30 ``Fix Grammar in PIP warning (#13380)`` -`6cf76d7ac0 `_ 2020-12-18 ``Fix typo in pip upgrade command :( (#13148)`` -`f6448b4e48 `_ 2020-12-15 ``Add link to PyPI Repository to provider docs (#13064)`` -================================================================================================= =========== ========================================================= - -1.0.0 -..... 
- -Latest change: 2020-12-09 - -================================================================================================= =========== ====================================================================================================================================================================== -Commit Committed Subject -================================================================================================= =========== ====================================================================================================================================================================== -`32971a1a2d `_ 2020-12-09 ``Updates providers versions to 1.0.0 (#12955)`` -`b40dffa085 `_ 2020-12-08 ``Rename remaing modules to match AIP-21 (#12917)`` -`9b39f24780 `_ 2020-12-08 ``Add support for dynamic connection form fields per provider (#12558)`` -`bd90136aaf `_ 2020-11-30 ``Move operator guides to provider documentation packages (#12681)`` -`de3b1e687b `_ 2020-11-28 ``Move connection guides to provider documentation packages (#12653)`` -`ef4af21351 `_ 2020-11-22 ``Move providers docs to separate package + Spell-check in a common job with docs-build (#12527)`` -`f2569de7d1 `_ 2020-11-22 ``Add example DAGs to provider docs (#12528)`` -`c34ef853c8 `_ 2020-11-20 ``Separate out documentation building per provider (#12444)`` -`0080354502 `_ 2020-11-18 ``Update provider READMEs for 1.0.0b2 batch release (#12449)`` -`ae7cb4a1e2 `_ 2020-11-17 ``Update wrong commit hash in backport provider changes (#12390)`` -`6889a333cf `_ 2020-11-15 ``Improvements for operators and hooks ref docs (#12366)`` -`7825e8f590 `_ 2020-11-13 ``Docs installation improvements (#12304)`` -`85a18e13d9 `_ 2020-11-09 ``Point at pypi project pages for cross-dependency of provider packages (#12212)`` -`59eb5de78c `_ 2020-11-09 ``Update provider READMEs for up-coming 1.0.0beta1 releases (#12206)`` -`b2a28d1590 `_ 2020-11-09 ``Moves provider packages scripts to dev (#12082)`` -`4e8f9cc8d0 `_ 2020-11-03 ``Enable Black - Python Auto Formmatter (#9550)`` -`8c42cf1b00 `_ 2020-11-03 ``Use PyUpgrade to use Python 3.6 features (#11447)`` -`5a439e84eb `_ 2020-10-26 ``Prepare providers release 0.0.2a1 (#11855)`` -`872b1566a1 `_ 2020-10-25 ``Generated backport providers readmes/setup for 2020.10.29 (#11826)`` -`349b0811c3 `_ 2020-10-20 ``Add D200 pydocstyle check (#11688)`` -`16e7129719 `_ 2020-10-13 ``Added support for provider packages for Airflow 2.0 (#11487)`` -`0a0e1af800 `_ 2020-10-03 ``Fix Broken Markdown links in Providers README TOC (#11249)`` -`ca4238eb4d `_ 2020-10-02 ``Fixed month in backport packages to October (#11242)`` -`5220e4c384 `_ 2020-10-02 ``Prepare Backport release 2020.09.07 (#11238)`` -`5093245d6f `_ 2020-09-30 ``Strict type coverage for Oracle and Yandex provider (#11198)`` -`9549274d11 `_ 2020-09-09 ``Upgrade black to 20.8b1 (#10818)`` -`fdd9b6f65b `_ 2020-08-25 ``Enable Black on Providers Packages (#10543)`` -`3696c34c28 `_ 2020-08-24 ``Fix typo in the word "release" (#10528)`` -`ee7ca128a1 `_ 2020-08-22 ``Fix broken Markdown refernces in Providers README (#10483)`` -`f6734b3b85 `_ 2020-08-12 ``Enable Sphinx spellcheck for doc generation (#10280)`` -`cdec301254 `_ 2020-08-07 ``Add correct signature to all operators and sensors (#10205)`` -`aeea71274d `_ 2020-08-02 ``Remove 'args' parameter from provider operator constructors (#10097)`` -`7d24b088cd `_ 2020-07-25 ``Stop using start_date in default_args in example_dags (2) (#9985)`` -`d0e7db4024 `_ 2020-06-19 ``Fixed release number for fresh release 
(#9408)`` -`12af6a0800 `_ 2020-06-19 ``Final cleanup for 2020.6.23rc1 release preparation (#9404)`` -`c7e5bce57f `_ 2020-06-19 ``Prepare backport release candidate for 2020.6.23rc1 (#9370)`` -`40bf8f28f9 `_ 2020-06-18 ``Detect automatically the lack of reference to the guide in the operator descriptions (#9290)`` -`f6bd817a3a `_ 2020-06-16 ``Introduce 'transfers' packages (#9320)`` -`0b0e4f7a4c `_ 2020-05-26 ``Preparing for RC3 relase of backports (#9026)`` -`00642a46d0 `_ 2020-05-26 ``Fixed name of 20 remaining wrongly named operators. (#8994)`` -`1d36b0303b `_ 2020-05-23 ``Fix references in docs (#8984)`` -`375d1ca229 `_ 2020-05-19 ``Release candidate 2 for backport packages 2020.05.20 (#8898)`` -`12c5e5d8ae `_ 2020-05-17 ``Prepare release candidate for backport packages (#8891)`` -`f3521fb0e3 `_ 2020-05-16 ``Regenerate readme files for backport package release (#8886)`` -`92585ca4cb `_ 2020-05-15 ``Added automated release notes generation for backport operators (#8807)`` -`59a4f26699 `_ 2020-04-17 ``stop rendering some class docs in wrong place (#8095)`` -`3320e432a1 `_ 2020-02-24 ``[AIRFLOW-6817] Lazy-load 'airflow.DAG' to keep user-facing API untouched (#7517)`` -`4d03e33c11 `_ 2020-02-22 ``[AIRFLOW-6817] remove imports from 'airflow/__init__.py', replaced implicit imports with explicit imports, added entry to 'UPDATING.MD' - squashed/rebased (#7456)`` -`9cbd7de6d1 `_ 2020-02-18 ``[AIRFLOW-6792] Remove _operator/_hook/_sensor in providers package and add tests (#7412)`` -`ee1ab7697c `_ 2020-02-14 ``[AIRFLOW-6531] Initial Yandex.Cloud Dataproc support (#7252)`` -================================================================================================= =========== ====================================================================================================================================================================== diff --git a/docs/apache-airflow-providers-yandex/configurations-ref.rst b/docs/apache-airflow-providers-yandex/configurations-ref.rst deleted file mode 100644 index 5885c9d91b6e8..0000000000000 --- a/docs/apache-airflow-providers-yandex/configurations-ref.rst +++ /dev/null @@ -1,18 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. include:: ../exts/includes/providers-configurations-ref.rst diff --git a/docs/apache-airflow-providers-yandex/connections/yandexcloud.rst b/docs/apache-airflow-providers-yandex/connections/yandexcloud.rst deleted file mode 100644 index b1d8b4074c295..0000000000000 --- a/docs/apache-airflow-providers-yandex/connections/yandexcloud.rst +++ /dev/null @@ -1,95 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. 
The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. _yandex_cloud_connection: - -Yandex.Cloud Connection -======================= - -The Yandex.Cloud connection type enables the authentication in Yandex.Cloud services. - -Configuring the Connection --------------------------- - -Service account auth JSON - JSON object as a string. - - Example: ``{"id": "...", "service_account_id": "...", "private_key": "..."}`` - -Service account auth JSON file path - Path to the file containing service account auth JSON. - - Example: ``/home/airflow/authorized_key.json`` - -OAuth Token - User account OAuth token as a string. - - Example: ``y3_Vd3eub7w9bIut67GHeL345gfb5GAnd3dZnf08FR1vjeUFve7Yi8hGvc`` - -SSH public key (optional) - The key will be placed to all created Compute nodes, allowing you to have a root shell there. - -Folder ID (optional) - A folder is an entity to separate different projects within the cloud. - - If specified, this ID will be used by default when creating nodes and clusters. - - See `this guide `__ for details. - -Endpoint (optional) - Use this setting to configure your API endpoint. - - Leave blank to use default `endpoints `__. - -Default Connection IDs ----------------------- - -All hooks and operators related to Yandex.Cloud use the ``yandexcloud_default`` connection by default. - -Authenticating to Yandex.Cloud ------------------------------- - -Using authorized keys to authorize as a service account -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Before you start, make sure you have `created `__ -a Yandex Cloud `service account `__. - -First, you need to create an `authorized key `__ -for your service account and save the generated JSON file with both public and private key parts. - -Then, you need to specify the key in the ``Service account auth JSON`` field. - -Alternatively, you can specify the path to the JSON file in the ``Service account auth JSON file path`` field. - -Using an OAuth token to authorize as a user account -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -First, you need to create -an `OAuth token `__ for your user account. -Your token will look like this: ``y3_Vd3eub7w9bIut67GHeL345gfb5GAnd3dZnf08FR1vjeUFve7Yi8hGvc``. - -Then you need to specify your token in the ``OAuth Token`` field. - -Using metadata service -~~~~~~~~~~~~~~~~~~~~~~ - -If you do not specify any credentials, the connection will attempt to use -the `metadata service `__ for authentication. - -To do this, you need to `link `__ -your service account with your VM. diff --git a/docs/apache-airflow-providers-yandex/index.rst b/docs/apache-airflow-providers-yandex/index.rst deleted file mode 100644 index 249542f23edc3..0000000000000 --- a/docs/apache-airflow-providers-yandex/index.rst +++ /dev/null @@ -1,110 +0,0 @@ - - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. 
The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -``apache-airflow-providers-yandex`` -=================================== - - -.. toctree:: - :hidden: - :maxdepth: 1 - :caption: Basics - - Home - Changelog - Security - -.. toctree:: - :hidden: - :maxdepth: 1 - :caption: Guides - - Configuration - Connection types - Lockbox Secret Backend - Operators - -.. toctree:: - :hidden: - :maxdepth: 1 - :caption: References - - Python API <_api/airflow/providers/yandex/index> - -.. toctree:: - :hidden: - :maxdepth: 1 - :caption: System tests - - System Tests <_api/tests/system/yandex/index> - -.. toctree:: - :hidden: - :maxdepth: 1 - :caption: Resources - - Example DAGs - PyPI Repository - Installing from sources - -.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME! - - -.. toctree:: - :hidden: - :maxdepth: 1 - :caption: Commits - - Detailed list of commits - - -apache-airflow-providers-yandex package ------------------------------------------------------- - -This package is for Yandex, including: - - - `Yandex.Cloud `__ - - -Release: 4.0.0 - -Provider package ----------------- - -This package is for the ``yandex`` provider. -All classes for this package are included in the ``airflow.providers.yandex`` python package. - -Installation ------------- - -You can install this package on top of an existing Airflow 2 installation via -``pip install apache-airflow-providers-yandex``. -For the minimum Airflow version supported, see ``Requirements`` below. - -Requirements ------------- - -The minimum Apache Airflow version supported by this provider package is ``2.9.0``. - -======================= ================== -PIP package Version required -======================= ================== -``apache-airflow`` ``>=2.9.0`` -``yandexcloud`` ``>=0.308.0`` -``yandex-query-client`` ``>=0.1.4`` -======================= ================== diff --git a/docs/apache-airflow-providers-yandex/installing-providers-from-sources.rst b/docs/apache-airflow-providers-yandex/installing-providers-from-sources.rst deleted file mode 100644 index b4e730f4ff21a..0000000000000 --- a/docs/apache-airflow-providers-yandex/installing-providers-from-sources.rst +++ /dev/null @@ -1,18 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. 
include:: ../exts/includes/installing-providers-from-sources.rst diff --git a/docs/apache-airflow-providers-yandex/operators/dataproc.rst b/docs/apache-airflow-providers-yandex/operators/dataproc.rst deleted file mode 100644 index b7188e2ea52f6..0000000000000 --- a/docs/apache-airflow-providers-yandex/operators/dataproc.rst +++ /dev/null @@ -1,37 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - -Yandex.Cloud Data Proc Operators -================================ - -`Yandex Data Proc `__ is a service -that helps you deploy Apache Hadoop®* and Apache Spark™ clusters in the Yandex Cloud infrastructure. - -With Data Proc, you can manage the cluster size and node capacity, -as well as work with various Apache® services, -such as Spark, HDFS, YARN, Hive, HBase, Oozie, Sqoop, Flume, Tez, and Zeppelin. - -Apache Hadoop is used for storing and analyzing structured and unstructured big data. - -Apache Spark is a tool for quick data processing -that can be integrated with Apache Hadoop and other storage systems. - -Using the operators -^^^^^^^^^^^^^^^^^^^ -To learn how to use Data Proc operators, -see `example DAGs `_. diff --git a/docs/apache-airflow-providers-yandex/operators/index.rst b/docs/apache-airflow-providers-yandex/operators/index.rst deleted file mode 100644 index 12b05418e100f..0000000000000 --- a/docs/apache-airflow-providers-yandex/operators/index.rst +++ /dev/null @@ -1,28 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - - -Yandex.Cloud Operators -====================== - - -.. toctree:: - :maxdepth: 1 - :glob: - - * diff --git a/docs/apache-airflow-providers-yandex/operators/yq.rst b/docs/apache-airflow-providers-yandex/operators/yq.rst deleted file mode 100644 index 23bd4ac336160..0000000000000 --- a/docs/apache-airflow-providers-yandex/operators/yq.rst +++ /dev/null @@ -1,28 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. 
The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - -Yandex Query Operators -====================== -`Yandex Query `__ is a service in the Yandex Cloud to process data from different sources such as -`Object Storage `__, `MDB ClickHouse `__, -`MDB PostgreSQL `__, `Yandex DataStreams `__ using SQL scripts. - -Using the operators -^^^^^^^^^^^^^^^^^^^ -To learn how to use Yandex Query operator, -see `example DAG `__. diff --git a/docs/apache-airflow-providers-yandex/secrets-backends/yandex-cloud-lockbox-secret-backend.rst b/docs/apache-airflow-providers-yandex/secrets-backends/yandex-cloud-lockbox-secret-backend.rst deleted file mode 100644 index f30346b24da80..0000000000000 --- a/docs/apache-airflow-providers-yandex/secrets-backends/yandex-cloud-lockbox-secret-backend.rst +++ /dev/null @@ -1,298 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - -Yandex.Cloud Lockbox Secret Backend -=================================== - -This topic describes how to configure Apache Airflow to use `Yandex Lockbox `__ -as a secret backend and how to manage secrets. - -Getting started ---------------- - -Before you start, make sure you have installed the ``yandex`` provider in your Apache Airflow installation: - -.. code-block:: bash - - pip install apache-airflow-providers-yandex - -Enabling the Yandex Lockbox secret backend ------------------------------------------- - -To enable Yandex Lockbox as a secret backend, -specify :py:class:`~airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend` -as your ``backend`` in the ``[secrets]`` section of the ``airflow.cfg`` file. - -Here is a sample configuration: - -.. code-block:: ini - - [secrets] - backend = airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend - -You can also set this with an environment variable: - -.. code-block:: bash - - export AIRFLOW__SECRETS__BACKEND=airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend - -You can verify whether the configuration options have been set up correctly -using the ``airflow config get-value`` command: - -.. 
code-block:: console - - $ airflow config get-value secrets backend - airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend - -Backend parameters ------------------- - -The next step is to configure the backend parameters using the ``backend_kwargs`` option, -which allows you to provide the following parameters: - -* ``yc_oauth_token``: Specifies the user account OAuth token to connect to Yandex Lockbox. The parameter value should look like ``y3_xx123``. -* ``yc_sa_key_json``: Specifies the service account key in JSON. The parameter value should look like ``{"id": "...", "service_account_id": "...", "private_key": "..."}``. -* ``yc_sa_key_json_path``: Specifies the path to a JSON file containing the service account key. The parameter value should look like ``/home/airflow/authorized_key.json``, while the file content should have the following format: ``{"id": "...", "service_account_id": "...", "private_key": "..."}``. -* ``yc_connection_id``: Specifies the connection ID to connect to Yandex Lockbox. The default value is ``yandexcloud_default``. -* ``folder_id``: Specifies the folder ID to search for Yandex Lockbox secrets in. If set to ``None`` (``null`` in JSON), the requests will use the connection ``folder_id``, if specified. -* ``connections_prefix``: Specifies the prefix of the secret to read to get connections. If set to ``None`` (``null`` in JSON), the requests for connections will not be sent to Yandex Lockbox. The default value is ``airflow/connections``. -* ``variables_prefix``: Specifies the prefix of the secret to read to get variables. If set to ``None`` (``null`` in JSON), the requests for variables will not be sent to Yandex Lockbox. The default value is ``airflow/variables``. -* ``config_prefix``: Specifies the prefix of the secret to read to get configurations. If set to ``None`` (``null`` in JSON), the requests for configurations will not be sent to Yandex Lockbox. The default value is ``airflow/config``. -* ``sep``: Specifies the separator to concatenate ``secret_prefix`` and ``secret_id``. The default value is ``/``. -* ``endpoint``: Specifies the API endpoint. If set to ``None`` (``null`` in JSON), the requests will use the connection endpoint, if specified; otherwise, they will use the default endpoint. - -Make sure to provide all options as a JSON dictionary. - -For example, if you want to set ``connections_prefix`` to ``"example-connections-prefix"`` -and ``variables_prefix`` to ``"example-variables-prefix"``, -your configuration file should look like this: - -.. code-block:: ini - - [secrets] - backend = airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend - backend_kwargs = {"connections_prefix": "example-connections-prefix", "variables_prefix": "example-variables-prefix"} - -Setting up credentials ----------------------- - -You need to specify credentials or the ID of the ``yandexcloud`` connection to connect to Yandex Lockbox. - -The credentials will be used with the following priority: - -* OAuth token -* Service account key in JSON from file -* Service account key in JSON -* Yandex Cloud connection - -If you do not specify any credentials, the system will use the default connection ID: ``yandexcloud_default``. - -Using an OAuth token to authorize as a user account -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -First, you need to create -an `OAuth token `__ for your user account. -Your token will look like this: ``y3_Vd3eub7w9bIut67GHeL345gfb5GAnd3dZnf08FR1vjeUFve7Yi8hGvc``. 
- -Then, you need to specify the ``folder_id`` and your token in ``backend_kwargs``: - -.. code-block:: ini - - [secrets] - backend_kwargs = {"folder_id": "b1g66mft1vo1n4vbn57j", "yc_oauth_token": "y3_Vd3eub7w9bIut67GHeL345gfb5GAnd3dZnf08FR1vjeUFve7Yi8hGvc"} - -Using authorized keys to authorize as a service account -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Before you start, make sure you have `created `__ -a Yandex Cloud `service account `__ -with the ``lockbox.viewer`` and ``lockbox.payloadViewer`` permissions. - -First, you need to create an `authorized key `__ -for your service account and save the generated JSON file with both public and private key parts. - -Then, you need to specify the ``folder_id`` and key in ``backend_kwargs``: - -.. code-block:: ini - - [secrets] - backend_kwargs = {"folder_id": "b1g66mft1vo1n4vbn57j", "yc_sa_key_json": {"id": "...", "service_account_id": "...", "private_key": "..."}} - -Alternatively, you can specify the path to the JSON file in ``backend_kwargs``: - -.. code-block:: ini - - [secrets] - backend_kwargs = {"folder_id": "b1g66mft1vo1n4vbn57j", "yc_sa_key_json_path": "/home/airflow/authorized_key.json"} - -Using Yandex Cloud connection for authorization -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -First, you need to create a :ref:`Yandex Cloud connection `. - -Then, you need to specify the ``connection_id`` in ``backend_kwargs``: - -.. code-block:: ini - - [secrets] - backend_kwargs = {"yc_connection_id": "my_yc_connection"} - -If you do not specify any credentials, -Lockbox Secret Backend will try to use the default connection ID: ``yandexcloud_default``. - -Lockbox Secret Backend will try to use the default folder ID from your connection. -You can also specify the ``folder_id`` in the ``backend_kwargs``: - -.. code-block:: ini - - [secrets] - backend_kwargs = {"folder_id": "b1g66mft1vo1n4vbn57j", "yc_connection_id": "my_yc_connection"} - -Storing and retrieving connections ----------------------------------- - -To store a connection, you need to `create a secret `__ -with a name in the following format: ``{connections_prefix}{sep}{connection_name}``. - -The payload must contain a text value with any key. - -Storing a connection as a URI -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The main way to save connections is using a :ref:`connection URI representation `, such as -``mysql://myname:mypassword@myhost.com?this_param=some+val&that_param=other+val%2A``. - -Here is an example of creating a secret with the ``yc`` CLI: - -.. code-block:: console - - $ yc lockbox secret create \ - --name airflow/connections/mysqldb \ - --payload '[{"key": "value", "text_value": "mysql://myname:mypassword@myhost.com?this_param=some+val&that_param=other+val%2A"}]' - done (1s) - name: airflow/connections/mysqldb - -Storing a connection as JSON -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Another way to store connections is using JSON format: - -.. code-block:: json - - { - "conn_type": "mysql", - "host": "host.com", - "login": "myname", - "password": "mypassword", - "extra": { - "this_param": "some val", - "that_param": "other val*" - } - } - -Here is an example of creating a secret with the ``yc`` CLI: - -..
code-block:: console - - $ yc lockbox secret create \ - --name airflow/connections/my_sql_db_json \ - --payload '[{"key": "value", "text_value": "{\"conn_type\": \"mysql\", \"host\": \"host.com\", \"login\": \"myname\", \"password\": \"mypassword\", \"extra\": {\"this_param\": \"some val\", \"that_param\": \"other val*\"}}"}]' - done (1s) - name: airflow/connections/my_sql_db_json - -Retrieving connection -~~~~~~~~~~~~~~~~~~~~~ - -To check that the connection is correctly read from the Lockbox Secret Backend, you can use ``airflow connections get``: - -.. code-block:: console - - $ airflow connections get mysqldb -o json - [{"id": null, "conn_id": "mysqldb", "conn_type": "mysql", "description": null, "host": "host.com", "schema": "", "login": "myname", "password": "mypassword", "port": null, "is_encrypted": "False", "is_extra_encrypted": "False", "extra_dejson": {"this_param": "some val", "that_param": "other val*"}, "get_uri": "mysql://myname:mypassword@myhost.com/?this_param=some+val&that_param=other+val%2A"}] - -Storing and retrieving variables --------------------------------- - -To store a variable, you need to `create a secret `__ -with a name in the following format: ``{variables_prefix}{sep}{variable_name}``. -The payload must contain a text value with any key. - -Here is what a variable value may look like: ``some_secret_data``. - -Here is an example of creating a secret with the ``yc`` CLI: - -.. code-block:: console - - $ yc lockbox secret create \ - --name airflow/variables/my_variable \ - --payload '[{"key": "value", "text_value": "some_secret_data"}]' - done (1s) - name: airflow/variables/my_variable - -To check that the variable is correctly read from the Lockbox Secret Backend, you can use ``airflow variables get``: - -.. code-block:: console - - $ airflow variables get my_variable - some_secret_data - -Storing and retrieving configs ------------------------------- - -Lockbox Secret Backend is also suitable for storing sensitive configurations. - -For example, we will store a secret for ``sentry.sentry_dsn`` -and use ``sentry_dsn_value`` as the config value name. - -To store a config, you need to `create a secret `__ -with a name in the following format: ``{config_prefix}{sep}{config_value_name}``. -The payload must contain a text value with any key. - -Here is an example of creating a secret with the ``yc`` CLI: - -.. code-block:: console - - $ yc lockbox secret create \ - --name airflow/config/sentry_dsn_value \ - --payload '[{"key": "value", "text_value": "https://public@sentry.example.com/1"}]' - done (1s) - name: airflow/config/sentry_dsn_value - -Then, we need to specify the config value name as ``{key}_secret`` in the Apache Airflow configuration: - -.. code-block:: ini - - [sentry] - sentry_dsn_secret = sentry_dsn_value - -To check that the config value is correctly read from the Lockbox Secret Backend, you can use ``airflow config get-value``: - -.. code-block:: console - - $ airflow config get-value sentry sentry_dsn - https://public@sentry.example.com/1 - -Cleaning up your secret ------------------------ - -You can easily delete your secret with the ``yc`` CLI: - -.. code-block:: console - - $ yc lockbox secret delete --name airflow/connections/mysqldb - name: airflow/connections/mysqldb diff --git a/docs/apache-airflow-providers-yandex/security.rst b/docs/apache-airflow-providers-yandex/security.rst deleted file mode 100644 index afa13dac6fc9b..0000000000000 --- a/docs/apache-airflow-providers-yandex/security.rst +++ /dev/null @@ -1,18 +0,0 @@ - ..
Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. include:: ../exts/includes/security.rst diff --git a/docs/apache-airflow/authoring-and-scheduling/timetable.rst b/docs/apache-airflow/authoring-and-scheduling/timetable.rst index b84b0b55768f4..79d5e50777008 100644 --- a/docs/apache-airflow/authoring-and-scheduling/timetable.rst +++ b/docs/apache-airflow/authoring-and-scheduling/timetable.rst @@ -105,6 +105,45 @@ must be a :class:`datetime.timedelta` or ``dateutil.relativedelta.relativedelta` pass +.. _MultipleCronTriggerTimetable: + +MultipleCronTriggerTimetable +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This is similar to CronTriggerTimetable_ except it takes multiple cron expressions. A DAG run is scheduled whenever any of the expressions matches the time. It is particularly useful when the desired schedule cannot be expressed by one single cron expression. + +.. code-block:: python + + from airflow.timetables.trigger import MultipleCronTriggerTimetable + + + # At 1:10 and 2:40 each day. + @dag(schedule=MultipleCronTriggerTimetable("10 1 * * *", "40 2 * * *", timezone="UTC"), ...) + def example_dag(): + pass + +The same optional ``interval`` argument as CronTriggerTimetable_ is also available. + +.. code-block:: python + + from datetime import timedelta + + from airflow.timetables.trigger import MultipleCronTriggerTimetable + + + @dag( + schedule=MultipleCronTriggerTimetable( + "10 1 * * *", + "40 2 * * *", + timezone="UTC", + interval=timedelta(hours=1), + ), + ..., + ) + def example_dag(): + pass + + .. _DeltaDataIntervalTimetable: DeltaDataIntervalTimetable diff --git a/docs/apache-airflow/howto/variable.rst b/docs/apache-airflow/howto/variable.rst index b4b395dd63c2c..5e0017fb1c09b 100644 --- a/docs/apache-airflow/howto/variable.rst +++ b/docs/apache-airflow/howto/variable.rst @@ -37,7 +37,7 @@ Storing Variables in Environment Variables Airflow Variables can also be created and managed using Environment Variables. The environment variable naming convention is :envvar:`AIRFLOW_VAR_{VARIABLE_NAME}`, all uppercase. -So if your variable key is ``FOO`` then the variable name should be ``AIRFLOW_VAR_FOO``. +So if your variable key is ``foo`` then the variable name should be ``AIRFLOW_VAR_FOO``. 
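To see the convention end to end, the snippet below sets the environment variable from Python and reads it back as an Airflow Variable. This is a minimal sketch with placeholder names and values; it assumes a working Airflow installation so that ``airflow.models.Variable`` can be imported.

.. code-block:: python

    import os

    # Equivalent to running `export AIRFLOW_VAR_FOO=some_value` in the shell.
    os.environ["AIRFLOW_VAR_FOO"] = "some_value"

    from airflow.models import Variable

    # The variable key ``foo`` maps to the environment variable AIRFLOW_VAR_FOO.
    print(Variable.get("foo"))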
For example, diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index e3606c3476353..9a124863bdd4a 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -232f2f252ce0d3889fa5a9ceb00c88788e12083a6ea0c155c74d3fe61ad02412 \ No newline at end of file +76818a684a0e05c1fd3ecee6c74b204c9a8d59b22966c62ba08089312fbd6ff4 \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index 7efa5c1512485..94a49bb5dba18 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -1864,7 +1864,6 @@ logical_date [TIMESTAMP] - NOT NULL queued_at diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index c39bcfb3a55f7..027327806f316 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -92,7 +92,7 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``522625f6d606`` | ``1cdc775ca98f`` | ``3.0.0`` | Add tables for backfill. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ -| ``1cdc775ca98f`` | ``a2c32e6c7729`` | ``3.0.0`` | Drop ``execution_date`` unique constraint on DagRun. | +| ``1cdc775ca98f`` | ``a2c32e6c7729`` | ``3.0.0`` | Make logical_date nullable. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``a2c32e6c7729`` | ``0bfc26bc256e`` | ``3.0.0`` | Add triggered_by field to DagRun. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/docs/apache-airflow/security/secrets/mask-sensitive-values.rst b/docs/apache-airflow/security/secrets/mask-sensitive-values.rst index a66900f3dcdad..60c7de44b379c 100644 --- a/docs/apache-airflow/security/secrets/mask-sensitive-values.rst +++ b/docs/apache-airflow/security/secrets/mask-sensitive-values.rst @@ -58,7 +58,7 @@ your DAG file or operator's ``execute`` function using the ``mask_secret`` funct @task def my_func(): - from airflow.utils.log.secrets_masker import mask_secret + from airflow.sdk.execution_time.secrets_masker import mask_secret mask_secret("custom_value") @@ -71,7 +71,7 @@ or class MyOperator(BaseOperator): def execute(self, context): - from airflow.utils.log.secrets_masker import mask_secret + from airflow.sdk.execution_time.secrets_masker import mask_secret mask_secret("custom_value") diff --git a/docs/apache-airflow/start.rst b/docs/apache-airflow/start.rst index ea8d624d74d9b..ff51325d0c842 100644 --- a/docs/apache-airflow/start.rst +++ b/docs/apache-airflow/start.rst @@ -24,7 +24,7 @@ This quick start guide will help you bootstrap an Airflow standalone instance on .. note:: - Successful installation requires a Python 3 environment. Starting with Airflow 2.7.0, Airflow supports Python 3.9, 3.10, 3.11 and 3.12. + Successful installation requires a Python 3 environment. Starting with Airflow 2.7.0, Airflow supports Python 3.9, 3.10, 3.11, and 3.12. Only ``pip`` installation is currently officially supported. 
@@ -44,7 +44,27 @@ This quick start guide will help you bootstrap an Airflow standalone instance on The installation of Airflow is straightforward if you follow the instructions below. Airflow uses constraint files to enable reproducible installation, so using ``pip`` and constraint files is recommended. -1. Set Airflow Home (optional): +1. **(Recommended) Create and Activate a Virtual Environment**: + + To avoid issues such as the ``externally-managed-environment`` error, particularly on modern Linux distributions like Ubuntu 22.04+ and Debian 12+, it is highly recommended to install Airflow inside a Python virtual environment. This approach prevents conflicts with system-level Python packages and ensures smooth installation. + + For more details on this error, see the Python Packaging Authority's explanation in the `PEP 668 documentation `_. + + .. code-block:: bash + + # Create a virtual environment in your desired directory + python3 -m venv airflow_venv + + # Activate the virtual environment + source airflow_venv/bin/activate + + # Upgrade pip within the virtual environment + pip install --upgrade pip + + # Optional: Deactivate the virtual environment when done + deactivate + +2. **Set Airflow Home (optional)**: Airflow requires a home directory, and uses ``~/airflow`` by default, but you can set a different location if you prefer. The ``AIRFLOW_HOME`` environment variable is used to inform Airflow of the desired location. This step of setting the environment variable should be done before installing Airflow so that the installation process knows where to store the necessary files. @@ -52,7 +72,8 @@ constraint files to enable reproducible installation, so using ``pip`` and const export AIRFLOW_HOME=~/airflow -2. Install Airflow using the constraints file, which is determined based on the URL we pass: + +3. Install Airflow using the constraints file, which is determined based on the URL we pass: .. code-block:: bash :substitutions: @@ -69,7 +90,7 @@ constraint files to enable reproducible installation, so using ``pip`` and const pip install "apache-airflow==${AIRFLOW_VERSION}" --constraint "${CONSTRAINT_URL}" -3. Run Airflow Standalone: +4. Run Airflow Standalone: The ``airflow standalone`` command initializes the database, creates a user, and starts all components. @@ -77,7 +98,7 @@ constraint files to enable reproducible installation, so using ``pip`` and const airflow standalone -4. Access the Airflow UI: +5. Access the Airflow UI: Visit ``localhost:8080`` in your browser and log in with the admin account details shown in the terminal. Enable the ``example_bash_operator`` DAG in the home page. diff --git a/docs/apache-airflow/templates-ref.rst b/docs/apache-airflow/templates-ref.rst index cf7b015141b94..c652de23912a9 100644 --- a/docs/apache-airflow/templates-ref.rst +++ b/docs/apache-airflow/templates-ref.rst @@ -41,18 +41,9 @@ Variable Type Description ``{{ logical_date }}`` `pendulum.DateTime`_ | A date-time that logically identifies the current DAG run. This value does not contain any semantics, but is simply a value for identification. | Use ``data_interval_start`` and ``data_interval_end`` instead if you want a value that has real-world semantics, | such as to get a slice of rows from the database based on timestamps. -``{{ ds }}`` str | The DAG run's logical date as ``YYYY-MM-DD``. - | Same as ``{{ logical_date | ds }}``. -``{{ ds_nodash }}`` str Same as ``{{ logical_date | ds_nodash }}``. ``{{ exception }}`` None | str | | Error occurred while running task instance. 
Exception | KeyboardInterrupt | -``{{ ts }}`` str | Same as ``{{ logical_date | ts }}``. - | Example: ``2018-01-01T00:00:00+00:00``. -``{{ ts_nodash_with_tz }}`` str | Same as ``{{ logical_date | ts_nodash_with_tz }}``. - | Example: ``20180101T000000+0000``. -``{{ ts_nodash }}`` str | Same as ``{{ logical_date | ts_nodash }}``. - | Example: ``20180101T000000``. ``{{ prev_data_interval_start_success }}`` `pendulum.DateTime`_ | Start of the data interval of the prior successful :class:`~airflow.models.dagrun.DagRun`. | ``None`` | Added in version 2.2. ``{{ prev_data_interval_end_success }}`` `pendulum.DateTime`_ | End of the data interval of the prior successful :class:`~airflow.models.dagrun.DagRun`. @@ -92,6 +83,22 @@ Variable Type Description | Added in version 2.4. =========================================== ===================== =================================================================== +The following are only available when the DagRun has a ``logical_date`` + +=========================================== ===================== =================================================================== +Variable Type Description +=========================================== ===================== =================================================================== +``{{ ds }}`` str | The DAG run's logical date as ``YYYY-MM-DD``. + | Same as ``{{ logical_date | ds }}``. +``{{ ds_nodash }}`` str Same as ``{{ logical_date | ds_nodash }}``. +``{{ ts }}`` str | Same as ``{{ logical_date | ts }}``. + | Example: ``2018-01-01T00:00:00+00:00``. +``{{ ts_nodash_with_tz }}`` str | Same as ``{{ logical_date | ts_nodash_with_tz }}``. + | Example: ``20180101T000000+0000``. +``{{ ts_nodash }}`` str | Same as ``{{ logical_date | ts_nodash }}``. + | Example: ``20180101T000000``. +=========================================== ===================== =================================================================== + .. note:: The DAG run's logical date, and values derived from it, such as ``ds`` and diff --git a/docs/apache-airflow/tutorial/taskflow.rst b/docs/apache-airflow/tutorial/taskflow.rst index 8c77c5ddc6718..455465f545c47 100644 --- a/docs/apache-airflow/tutorial/taskflow.rst +++ b/docs/apache-airflow/tutorial/taskflow.rst @@ -338,7 +338,7 @@ Below is an example of using the ``@task.kubernetes`` decorator to run a Python .. _taskflow/kubernetes_example: -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_decorator.py +.. 
exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_decorator.py :language: python :dedent: 4 :start-after: [START howto_operator_kubernetes] diff --git a/docs/integration-logos/yandex/Yandex-Cloud.png b/docs/integration-logos/yandex/Yandex-Cloud.png deleted file mode 100644 index 33e8e1b71b2c9..0000000000000 Binary files a/docs/integration-logos/yandex/Yandex-Cloud.png and /dev/null differ diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index dbd8c44015bbc..787ced4abb28f 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -352,7 +352,7 @@ "cloudant": { "deps": [ "apache-airflow>=2.9.0", - "ibmcloudant==0.9.1 ; python_version >= \"3.10\"" + "ibmcloudant==0.9.1;python_version>=\"3.10\"" ], "devel-deps": [], "plugins": [], @@ -526,7 +526,7 @@ "deps": [ "apache-airflow>=2.10.0", "pydantic>=2.10.2", - "retryhttp>=1.2.0" + "retryhttp>=1.2.0,!=1.3.0" ], "devel-deps": [], "plugins": [ diff --git a/helm_tests/airflow_core/test_api_server.py b/helm_tests/airflow_core/test_api_server.py index 91e6418f1db5e..1d6bea96673ed 100644 --- a/helm_tests/airflow_core/test_api_server.py +++ b/helm_tests/airflow_core/test_api_server.py @@ -39,9 +39,9 @@ class TestAPIServerDeployment: [(8, 10), (10, 8), (8, None), (None, 10), (None, None)], ) def test_revision_history_limit(self, revision_history_limit, global_revision_history_limit): - values = {"apiServer": {}} + values = {"_apiServer": {}} if revision_history_limit: - values["apiServer"]["revisionHistoryLimit"] = revision_history_limit + values["_apiServer"]["revisionHistoryLimit"] = revision_history_limit if global_revision_history_limit: values["revisionHistoryLimit"] = global_revision_history_limit docs = render_chart( @@ -54,7 +54,7 @@ def test_revision_history_limit(self, revision_history_limit, global_revision_hi def test_should_add_scheme_to_liveness_and_readiness_and_startup_probes(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "livenessProbe": {"scheme": "HTTPS"}, "readinessProbe": {"scheme": "HTTPS"}, "startupProbe": {"scheme": "HTTPS"}, @@ -77,7 +77,7 @@ def test_should_add_extra_containers(self): docs = render_chart( values={ "executor": "CeleryExecutor", - "apiServer": { + "_apiServer": { "extraContainers": [ {"name": "{{.Chart.Name}}", "image": "test-registry/test-repo:test-tag"} ], @@ -94,7 +94,7 @@ def test_should_add_extra_containers(self): def test_should_add_extraEnvs(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "env": [{"name": "TEST_ENV_1", "value": "test_env_1"}], }, }, @@ -108,7 +108,7 @@ def test_should_add_extraEnvs(self): def test_should_add_extra_volume_and_extra_volume_mount(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "extraVolumes": [{"name": "test-volume-{{ .Chart.Name }}", "emptyDir": {}}], "extraVolumeMounts": [ {"name": "test-volume-{{ .Chart.Name }}", "mountPath": "/opt/test"} @@ -146,7 +146,7 @@ def test_should_add_global_volume_and_global_volume_mount(self): def test_should_add_extraEnvs_to_wait_for_migration_container(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "waitForMigrations": { "env": [{"name": "TEST_ENV_1", "value": "test_env_1"}], }, @@ -171,7 +171,7 @@ def test_wait_for_migration_airflow_version(self): def test_disable_wait_for_migration(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "waitForMigrations": {"enabled": False}, }, }, @@ -185,7 +185,7 @@ def 
test_disable_wait_for_migration(self): def test_should_add_extra_init_containers(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "extraInitContainers": [ {"name": "test-init-container", "image": "test-registry/test-repo:test-tag"} ], @@ -202,7 +202,7 @@ def test_should_add_extra_init_containers(self): def test_should_add_component_specific_labels(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "labels": {"test_label": "test_label_value"}, }, }, @@ -215,7 +215,7 @@ def test_should_add_component_specific_labels(self): def test_should_create_valid_affinity_tolerations_and_node_selector(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "affinity": { "nodeAffinity": { "requiredDuringSchedulingIgnoredDuringExecution": { @@ -300,7 +300,7 @@ def test_affinity_tolerations_topology_spread_constraints_and_node_selector_prec } docs = render_chart( values={ - "apiServer": { + "_apiServer": { "affinity": expected_affinity, "tolerations": [ {"key": "dynamic-pods", "operator": "Equal", "value": "true", "effect": "NoSchedule"} @@ -412,7 +412,7 @@ def test_config_volumes(self): def testapi_server_resources_are_configurable(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "resources": { "limits": {"cpu": "200m", "memory": "128Mi"}, "requests": {"cpu": "300m", "memory": "169Mi"}, @@ -447,7 +447,7 @@ def testapi_server_resources_are_configurable(self): def test_api_server_security_contexts_are_configurable(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "securityContexts": { "pod": { "fsGroup": 1000, @@ -480,7 +480,7 @@ def test_api_server_security_context_legacy(self): with pytest.raises(CalledProcessError, match="Additional property securityContext is not allowed"): render_chart( values={ - "apiServer": { + "_apiServer": { "securityContext": { "fsGroup": 1000, "runAsGroup": 1001, @@ -518,7 +518,7 @@ def test_api_server_resources_are_not_added_by_default(self): ) def test_update_strategy(self, airflow_version, strategy, expected_strategy): docs = render_chart( - values={"airflowVersion": airflow_version, "apiServer": {"strategy": expected_strategy}}, + values={"airflowVersion": airflow_version, "_apiServer": {"strategy": expected_strategy}}, show_only=["templates/api-server/api-server-deployment.yaml"], ) @@ -540,7 +540,7 @@ def test_default_command_and_args(self): @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( - values={"apiServer": {"command": command, "args": args}}, + values={"_apiServer": {"command": command, "args": args}}, show_only=["templates/api-server/api-server-deployment.yaml"], ) @@ -550,7 +550,7 @@ def test_command_and_args_overrides(self, command, args): def test_command_and_args_overrides_are_templated(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "command": ["{{ .Release.Name }}"], "args": ["{{ .Release.Service }}"], } @@ -563,7 +563,7 @@ def test_command_and_args_overrides_are_templated(self): def test_should_add_component_specific_annotations(self): docs = render_chart( - values={"apiServer": {"annotations": {"test_annotation": "test_annotation_value"}}}, + values={"_apiServer": {"annotations": {"test_annotation": "test_annotation_value"}}}, show_only=["templates/api-server/api-server-deployment.yaml"], ) assert "annotations" in jmespath.search("metadata", docs[0]) @@ -571,7 +571,7 @@ def test_should_add_component_specific_annotations(self): def 
test_api_server_pod_hostaliases(self): docs = render_chart( - values={"apiServer": {"hostAliases": [{"ip": "127.0.0.1", "hostnames": ["foo.local"]}]}}, + values={"_apiServer": {"hostAliases": [{"ip": "127.0.0.1", "hostnames": ["foo.local"]}]}}, show_only=["templates/api-server/api-server-deployment.yaml"], ) @@ -600,8 +600,8 @@ def test_default_service(self): def test_overrides(self): docs = render_chart( values={ - "ports": {"apiServer": 9000}, - "apiServer": { + "ports": {"_apiServer": 9000}, + "_apiServer": { "service": { "type": "LoadBalancer", "loadBalancerIP": "127.0.0.1", @@ -628,7 +628,7 @@ def test_overrides(self): { "name": "{{ .Release.Name }}", "protocol": "UDP", - "port": "{{ .Values.ports.apiServer }}", + "port": "{{ .Values.ports._apiServer }}", } ], [{"name": "release-name", "protocol": "UDP", "port": 9091}], @@ -636,7 +636,7 @@ def test_overrides(self): ([{"name": "only_sidecar", "port": "{{ int 9000 }}"}], [{"name": "only_sidecar", "port": 9000}]), ( [ - {"name": "api-server", "port": "{{ .Values.ports.apiServer }}"}, + {"name": "api-server", "port": "{{ .Values.ports._apiServer }}"}, {"name": "sidecar", "port": 80, "targetPort": "sidecar"}, ], [ @@ -648,7 +648,7 @@ def test_overrides(self): ) def test_ports_overrides(self, ports, expected_ports): docs = render_chart( - values={"apiServer": {"service": {"ports": ports}}}, + values={"_apiServer": {"service": {"ports": ports}}}, show_only=["templates/api-server/api-server-service.yaml"], ) @@ -656,7 +656,7 @@ def test_ports_overrides(self, ports, expected_ports): def test_should_add_component_specific_labels(self): docs = render_chart( - values={"apiServer": {"labels": {"test_label": "test_label_value"}}}, + values={"_apiServer": {"labels": {"test_label": "test_label_value"}}}, show_only=["templates/api-server/api-server-service.yaml"], ) assert "test_label" in jmespath.search("metadata.labels", docs[0]) @@ -677,7 +677,7 @@ def test_should_add_component_specific_labels(self): ) def test_nodeport_service(self, ports, expected_ports): docs = render_chart( - values={"apiServer": {"service": {"type": "NodePort", "ports": ports}}}, + values={"_apiServer": {"service": {"type": "NodePort", "ports": ports}}}, show_only=["templates/api-server/api-server-service.yaml"], ) @@ -698,7 +698,7 @@ def test_defaults(self): docs = render_chart( values={ "networkPolicies": {"enabled": True}, - "apiServer": { + "_apiServer": { "networkPolicy": { "ingress": { "from": [{"namespaceSelector": {"matchLabels": {"release": "myrelease"}}}] @@ -722,7 +722,7 @@ def test_defaults(self): ([{"port": "sidecar"}], [{"port": "sidecar"}]), ( [ - {"port": "{{ .Values.ports.apiServer }}"}, + {"port": "{{ .Values.ports._apiServer }}"}, {"port": 80}, ], [ @@ -736,7 +736,7 @@ def test_ports_overrides(self, ports, expected_ports): docs = render_chart( values={ "networkPolicies": {"enabled": True}, - "apiServer": { + "_apiServer": { "networkPolicy": { "ingress": { "from": [{"namespaceSelector": {"matchLabels": {"release": "myrelease"}}}], @@ -754,7 +754,7 @@ def test_should_add_component_specific_labels(self): docs = render_chart( values={ "networkPolicies": {"enabled": True}, - "apiServer": { + "_apiServer": { "labels": {"test_label": "test_label_value"}, }, }, @@ -770,7 +770,7 @@ class TestAPIServerServiceAccount: def test_should_add_component_specific_labels(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "serviceAccount": {"create": True}, "labels": {"test_label": "test_label_value"}, }, @@ -783,7 +783,7 @@ def 
test_should_add_component_specific_labels(self): def test_default_automount_service_account_token(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "serviceAccount": {"create": True}, }, }, @@ -794,7 +794,7 @@ def test_default_automount_service_account_token(self): def test_overridden_automount_service_account_token(self): docs = render_chart( values={ - "apiServer": { + "_apiServer": { "serviceAccount": {"create": True, "automountServiceAccountToken": False}, }, }, diff --git a/helm_tests/airflow_core/test_scheduler.py b/helm_tests/airflow_core/test_scheduler.py index d21393cc8476c..ce8434c944e6a 100644 --- a/helm_tests/airflow_core/test_scheduler.py +++ b/helm_tests/airflow_core/test_scheduler.py @@ -1022,7 +1022,7 @@ def test_should_add_component_specific_labels(self): ("CeleryKubernetesExecutor", None), ("KubernetesExecutor", None), ("LocalKubernetesExecutor", None), - ("CeleryExecutor,KubernetesExecutor", True), + ("CeleryExecutor,KubernetesExecutor", None), ], ) def test_default_automount_service_account_token(self, executor, default_automount_service_account): @@ -1045,6 +1045,7 @@ def test_default_automount_service_account_token(self, executor, default_automou ("CeleryKubernetesExecutor", False, None), ("KubernetesExecutor", False, None), ("LocalKubernetesExecutor", False, None), + ("CeleryExecutor,KubernetesExecutor", False, None), ], ) def test_overridden_automount_service_account_token( diff --git a/helm_tests/security/test_rbac.py b/helm_tests/security/test_rbac.py index 0eecb9aadeb4d..c494ebc7ee6e3 100644 --- a/helm_tests/security/test_rbac.py +++ b/helm_tests/security/test_rbac.py @@ -153,7 +153,7 @@ def test_deployments_no_rbac_no_sa(self, version): "scheduler": {"serviceAccount": {"create": False}}, "dagProcessor": {"serviceAccount": {"create": False}}, "webserver": {"serviceAccount": {"create": False}}, - "apiServer": {"serviceAccount": {"create": False}}, + "_apiServer": {"serviceAccount": {"create": False}}, "workers": {"serviceAccount": {"create": False}}, "triggerer": {"serviceAccount": {"create": False}}, "statsd": {"serviceAccount": {"create": False}}, @@ -206,7 +206,7 @@ def test_deployments_with_rbac_no_sa(self, version): "scheduler": {"serviceAccount": {"create": False}}, "dagProcessor": {"serviceAccount": {"create": False}}, "webserver": {"serviceAccount": {"create": False}}, - "apiServer": {"serviceAccount": {"create": False}}, + "_apiServer": {"serviceAccount": {"create": False}}, "workers": {"serviceAccount": {"create": False}}, "triggerer": {"serviceAccount": {"create": False}}, "flower": {"enabled": True, "serviceAccount": {"create": False}}, @@ -267,7 +267,7 @@ def test_service_account_custom_names(self): "scheduler": {"serviceAccount": {"name": CUSTOM_SCHEDULER_NAME}}, "dagProcessor": {"serviceAccount": {"name": CUSTOM_DAG_PROCESSOR_NAME}}, "webserver": {"serviceAccount": {"name": CUSTOM_WEBSERVER_NAME}}, - "apiServer": {"serviceAccount": {"name": CUSTOM_API_SERVER_NAME}}, + "_apiServer": {"serviceAccount": {"name": CUSTOM_API_SERVER_NAME}}, "workers": {"serviceAccount": {"name": CUSTOM_WORKER_NAME}}, "triggerer": {"serviceAccount": {"name": CUSTOM_TRIGGERER_NAME}}, "flower": {"enabled": True, "serviceAccount": {"name": CUSTOM_FLOWER_NAME}}, @@ -306,7 +306,7 @@ def test_service_account_custom_names_in_objects(self): "scheduler": {"serviceAccount": {"name": CUSTOM_SCHEDULER_NAME}}, "dagProcessor": {"serviceAccount": {"name": CUSTOM_DAG_PROCESSOR_NAME}}, "webserver": {"serviceAccount": {"name": CUSTOM_WEBSERVER_NAME}}, - "apiServer": 
{"serviceAccount": {"name": CUSTOM_API_SERVER_NAME}}, + "_apiServer": {"serviceAccount": {"name": CUSTOM_API_SERVER_NAME}}, "workers": {"serviceAccount": {"name": CUSTOM_WORKER_NAME}}, "triggerer": {"serviceAccount": {"name": CUSTOM_TRIGGERER_NAME}}, "flower": {"enabled": True, "serviceAccount": {"name": CUSTOM_FLOWER_NAME}}, diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index 27a267b4486c3..cadd8954cfb5e 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -196,6 +196,14 @@ def monitor_task(self, host, dag_run_id, dag_id, task_id, expected_final_state, print(f"The expected state is wrong {state} != {expected_final_state} (expected)!") assert state == expected_final_state + @staticmethod + def ensure_deployment_health(deployment_name: str, namespace: str = "airflow"): + """Watch the deployment until it is healthy.""" + deployment_rollout_status = check_output( + ["kubectl", "rollout", "status", "deployment", deployment_name, "-n", namespace, "--watch"] + ).decode() + assert "successfully rolled out" in deployment_rollout_status + def ensure_dag_expected_state(self, host, logical_date, dag_id, expected_final_state, timeout): tries = 0 state = "" diff --git a/kubernetes_tests/test_kubernetes_executor.py b/kubernetes_tests/test_kubernetes_executor.py index a0bd832156189..622a4daaa0df0 100644 --- a/kubernetes_tests/test_kubernetes_executor.py +++ b/kubernetes_tests/test_kubernetes_executor.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import time - import pytest from kubernetes_tests.test_base import ( @@ -59,8 +57,7 @@ def test_integration_run_dag_with_scheduler_failure(self): dag_run_id, logical_date = self.start_job_in_kubernetes(dag_id, self.host) self._delete_airflow_pod("scheduler") - - time.sleep(10) # give time for pod to restart + self.ensure_deployment_health("airflow-scheduler") # Wait some time for the operator to complete self.monitor_task( diff --git a/newsfragments/42404.significant.rst b/newsfragments/42404.significant.rst index c9d1212a3f204..da57b01bde197 100644 --- a/newsfragments/42404.significant.rst +++ b/newsfragments/42404.significant.rst @@ -4,6 +4,8 @@ The shift towards using ``run_id`` as the sole identifier for DAG runs eliminate - Removed ``logical_date`` arguments from public APIs and Python functions related to DAG run lookups. - ``run_id`` is now the exclusive identifier for DAG runs in these contexts. +- ``ds``, ``ds_nodash``, ``ts``, ``ts_nodash``, ``ts_nodash_with_tz`` (and ``logical_date``) will no longer exist for non-scheduled DAG runs (i.e. manually triggered runs) +- ``task_instance_key_str`` template variable has changed to use ``run_id``, not the logical_date. 
This means its value will change compared to 2.x, even for old runs * Types of change diff --git a/newsfragments/46375.significant.rst b/newsfragments/46375.significant.rst new file mode 100644 index 0000000000000..fd3a91ad7157b --- /dev/null +++ b/newsfragments/46375.significant.rst @@ -0,0 +1,22 @@ +``SecretsMasker`` has now been moved into the task SDK to be consumed by DAG authors and users + +Any occurrences of the ``secrets_masker`` module will have to be updated from ``airflow.utils.log.secrets_masker`` to the new path: ``airflow.sdk.execution_time.secrets_masker`` + +* Types of change + + * [ ] Dag changes + * [ ] Config changes + * [ ] API changes + * [ ] CLI changes + * [x] Behaviour changes + * [ ] Plugin changes + * [ ] Dependency changes + * [ ] Code interface changes + +* Migration rules needed + + * ruff + + * AIR302 + + * [ ] ``airflow.utils.log.secrets_masker`` -> ``airflow.sdk.execution_time.secrets_masker`` diff --git a/newsfragments/46408.significant.rst b/newsfragments/46408.significant.rst new file mode 100644 index 0000000000000..5088f81d4d8b1 --- /dev/null +++ b/newsfragments/46408.significant.rst @@ -0,0 +1,30 @@ +DAG processor related config options removed + +The following configuration options have been removed: + +- ``[logging] dag_processor_manager_log_location`` +- ``[logging] dag_processor_manager_log_stdout`` +- ``[logging] log_processor_filename_template`` + +If these config options are still present, they will have no effect any longer. + +* Types of change + + * [ ] Dag changes + * [x] Config changes + * [ ] API changes + * [ ] CLI changes + * [ ] Behaviour changes + * [ ] Plugin changes + * [ ] Dependency changes + * [ ] Code interface changes + +.. List the migration rules needed for this change (see https://github.com/apache/airflow/issues/41641) + +* Migration rules needed + + * ``airflow config lint`` + + * Remove ``[logging] dag_processor_manager_log_location`` + * Remove ``[logging] dag_processor_manager_log_stdout`` + * Remove ``[logging] log_processor_filename_template`` diff --git a/providers/apache/flink/README.rst b/providers/apache/flink/README.rst new file mode 100644 index 0000000000000..bd30dfb93b8e2 --- /dev/null +++ b/providers/apache/flink/README.rst @@ -0,0 +1,82 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + + .. IF YOU WANT TO MODIFY TEMPLATE FOR THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + `PROVIDER_README_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +Package ``apache-airflow-providers-apache-flink`` + +Release: ``1.6.0`` + + +`Apache Flink `__ + + +Provider package +---------------- + +This is a provider package for ``apache.flink`` provider.
All classes for this provider package +are in ``airflow.providers.apache.flink`` python package. + +You can find package information and changelog for the provider +in the `documentation `_. + +Installation +------------ + +You can install this package on top of an existing Airflow 2 installation (see ``Requirements`` below +for the minimum Airflow version supported) via +``pip install apache-airflow-providers-apache-flink`` + +The package supports the following python versions: 3.9,3.10,3.11,3.12 + +Requirements +------------ + +============================================ ================== +PIP package Version required +============================================ ================== +``apache-airflow`` ``>=2.9.0`` +``cryptography`` ``>=41.0.0`` +``apache-airflow-providers-cncf-kubernetes`` ``>=5.1.0`` +============================================ ================== + +Cross provider package dependencies +----------------------------------- + +Those are dependencies that might be needed in order to use all the features of the package. +You need to install the specified provider packages in order to use them. + +You can install such cross-provider dependencies when installing from PyPI. For example: + +.. code-block:: bash + + pip install apache-airflow-providers-apache-flink[cncf.kubernetes] + + +====================================================================================================================== =================== +Dependent package Extra +====================================================================================================================== =================== +`apache-airflow-providers-cncf-kubernetes `_ ``cncf.kubernetes`` +====================================================================================================================== =================== + +The changelog for the provider package can be found in the +`changelog `_. 
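For orientation, the sketch below shows how the Flink-on-Kubernetes operator shipped in this package might be used inside a DAG. Treat it as an illustration only: the class name ``FlinkKubernetesOperator`` and the ``application_file`` and ``namespace`` parameters are assumptions made here and should be checked against the provider's operator guide; the manifest path and namespace are placeholders.

.. code-block:: python

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.apache.flink.operators.flink_kubernetes import FlinkKubernetesOperator

    with DAG(dag_id="flink_on_k8s_example", start_date=datetime(2024, 1, 1), schedule=None) as dag:
        # Submits a FlinkDeployment custom resource described by a YAML manifest.
        submit_flink_app = FlinkKubernetesOperator(
            task_id="submit_flink_app",
            application_file="flink-deployment.yaml",
            namespace="default",
        )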
diff --git a/providers/src/airflow/providers/apache/flink/.latest-doc-only-change.txt b/providers/apache/flink/docs/.latest-doc-only-change.txt similarity index 100% rename from providers/src/airflow/providers/apache/flink/.latest-doc-only-change.txt rename to providers/apache/flink/docs/.latest-doc-only-change.txt diff --git a/providers/src/airflow/providers/apache/flink/CHANGELOG.rst b/providers/apache/flink/docs/changelog.rst similarity index 100% rename from providers/src/airflow/providers/apache/flink/CHANGELOG.rst rename to providers/apache/flink/docs/changelog.rst diff --git a/docs/apache-airflow-providers-apache-flink/commits.rst b/providers/apache/flink/docs/commits.rst similarity index 100% rename from docs/apache-airflow-providers-apache-flink/commits.rst rename to providers/apache/flink/docs/commits.rst diff --git a/docs/apache-airflow-providers-apache-flink/index.rst b/providers/apache/flink/docs/index.rst similarity index 100% rename from docs/apache-airflow-providers-apache-flink/index.rst rename to providers/apache/flink/docs/index.rst diff --git a/docs/apache-airflow-providers-apache-flink/installing-providers-from-sources.rst b/providers/apache/flink/docs/installing-providers-from-sources.rst similarity index 100% rename from docs/apache-airflow-providers-apache-flink/installing-providers-from-sources.rst rename to providers/apache/flink/docs/installing-providers-from-sources.rst diff --git a/docs/integration-logos/kubernetes/FlinkOnK8s.png b/providers/apache/flink/docs/integration-logos/FlinkOnK8s.png similarity index 100% rename from docs/integration-logos/kubernetes/FlinkOnK8s.png rename to providers/apache/flink/docs/integration-logos/FlinkOnK8s.png diff --git a/docs/apache-airflow-providers-apache-flink/operators.rst b/providers/apache/flink/docs/operators.rst similarity index 100% rename from docs/apache-airflow-providers-apache-flink/operators.rst rename to providers/apache/flink/docs/operators.rst diff --git a/docs/apache-airflow-providers-apache-flink/security.rst b/providers/apache/flink/docs/security.rst similarity index 100% rename from docs/apache-airflow-providers-apache-flink/security.rst rename to providers/apache/flink/docs/security.rst diff --git a/providers/src/airflow/providers/apache/flink/provider.yaml b/providers/apache/flink/provider.yaml similarity index 90% rename from providers/src/airflow/providers/apache/flink/provider.yaml rename to providers/apache/flink/provider.yaml index 8085f4cb9dac8..4b6085a8f2c50 100644 --- a/providers/src/airflow/providers/apache/flink/provider.yaml +++ b/providers/apache/flink/provider.yaml @@ -40,17 +40,12 @@ versions: - 1.0.1 - 1.0.0 -dependencies: - - apache-airflow>=2.9.0 - - cryptography>=41.0.0 - - apache-airflow-providers-cncf-kubernetes>=5.1.0 - integrations: - integration-name: Apache Flink external-doc-url: https://github.com/apache/flink-kubernetes-operator how-to-guide: - /docs/apache-airflow-providers-apache-flink/operators.rst - logo: /integration-logos/kubernetes/FlinkOnK8s.png + logo: /docs/integration-logos/FlinkOnK8s.png tags: [apache] operators: diff --git a/providers/apache/flink/pyproject.toml b/providers/apache/flink/pyproject.toml new file mode 100644 index 0000000000000..9aa8e089c7f71 --- /dev/null +++ b/providers/apache/flink/pyproject.toml @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + +# IF YOU WANT TO MODIFY THIS FILE EXCEPT DEPENDENCIES, YOU SHOULD MODIFY THE TEMPLATE +# `pyproject_TEMPLATE.toml.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +[build-system] +requires = ["flit_core==3.10.1"] +build-backend = "flit_core.buildapi" + +[project] +name = "apache-airflow-providers-apache-flink" +version = "1.6.0" +description = "Provider package apache-airflow-providers-apache-flink for Apache Airflow" +readme = "README.rst" +authors = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +maintainers = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +keywords = [ "airflow-provider", "apache.flink", "airflow", "integration" ] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Framework :: Apache Airflow", + "Framework :: Apache Airflow :: Provider", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: System :: Monitoring", +] +requires-python = "~=3.9" + +# The dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +dependencies = [ + "apache-airflow>=2.9.0", + "cryptography>=41.0.0", + "apache-airflow-providers-cncf-kubernetes>=5.1.0", +] + +[project.urls] +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-apache-flink/1.6.0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-apache-flink/1.6.0/changelog.html" +"Bug Tracker" = "https://github.com/apache/airflow/issues" +"Source Code" = "https://github.com/apache/airflow" +"Slack Chat" = "https://s.apache.org/airflow-slack" +"Twitter" = "https://x.com/ApacheAirflow" +"YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/" + +[project.entry-points."apache_airflow_provider"] +provider_info = "airflow.providers.apache.flink.get_provider_info:get_provider_info" + +[tool.flit.module] +name = "airflow.providers.apache.flink" + +[tool.pytest.ini_options] +ignore = "tests/system/" diff --git a/providers/apache/flink/src/airflow/providers/apache/flink/LICENSE b/providers/apache/flink/src/airflow/providers/apache/flink/LICENSE new file mode 100644 index 0000000000000..11069edd79019 --- /dev/null +++ b/providers/apache/flink/src/airflow/providers/apache/flink/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/providers/src/airflow/providers/apache/flink/__init__.py b/providers/apache/flink/src/airflow/providers/apache/flink/__init__.py similarity index 100% rename from providers/src/airflow/providers/apache/flink/__init__.py rename to providers/apache/flink/src/airflow/providers/apache/flink/__init__.py diff --git a/providers/apache/flink/src/airflow/providers/apache/flink/get_provider_info.py b/providers/apache/flink/src/airflow/providers/apache/flink/get_provider_info.py new file mode 100644 index 0000000000000..141b206377b28 --- /dev/null +++ b/providers/apache/flink/src/airflow/providers/apache/flink/get_provider_info.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! +# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `get_provider_info_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +def get_provider_info(): + return { + "package-name": "apache-airflow-providers-apache-flink", + "name": "Apache Flink", + "description": "`Apache Flink `__\n", + "state": "ready", + "source-date-epoch": 1734527925, + "versions": [ + "1.6.0", + "1.5.1", + "1.5.0", + "1.4.2", + "1.4.1", + "1.4.0", + "1.3.0", + "1.2.0", + "1.1.3", + "1.1.2", + "1.1.1", + "1.1.0", + "1.0.1", + "1.0.0", + ], + "integrations": [ + { + "integration-name": "Apache Flink", + "external-doc-url": "https://github.com/apache/flink-kubernetes-operator", + "how-to-guide": ["/docs/apache-airflow-providers-apache-flink/operators.rst"], + "logo": "/docs/integration-logos/FlinkOnK8s.png", + "tags": ["apache"], + } + ], + "operators": [ + { + "integration-name": "Apache Flink", + "python-modules": ["airflow.providers.apache.flink.operators.flink_kubernetes"], + } + ], + "sensors": [ + { + "integration-name": "Apache Flink", + "python-modules": ["airflow.providers.apache.flink.sensors.flink_kubernetes"], + } + ], + "dependencies": [ + "apache-airflow>=2.9.0", + "cryptography>=41.0.0", + "apache-airflow-providers-cncf-kubernetes>=5.1.0", + ], + } diff --git a/providers/src/airflow/providers/apache/flink/hooks/__init__.py b/providers/apache/flink/src/airflow/providers/apache/flink/hooks/__init__.py similarity index 100% rename from providers/src/airflow/providers/apache/flink/hooks/__init__.py rename to providers/apache/flink/src/airflow/providers/apache/flink/hooks/__init__.py diff --git a/providers/src/airflow/providers/apache/flink/operators/__init__.py b/providers/apache/flink/src/airflow/providers/apache/flink/operators/__init__.py similarity index 100% rename from providers/src/airflow/providers/apache/flink/operators/__init__.py rename to providers/apache/flink/src/airflow/providers/apache/flink/operators/__init__.py diff --git 
a/providers/src/airflow/providers/apache/flink/operators/flink_kubernetes.py b/providers/apache/flink/src/airflow/providers/apache/flink/operators/flink_kubernetes.py similarity index 100% rename from providers/src/airflow/providers/apache/flink/operators/flink_kubernetes.py rename to providers/apache/flink/src/airflow/providers/apache/flink/operators/flink_kubernetes.py diff --git a/providers/src/airflow/providers/apache/flink/sensors/__init__.py b/providers/apache/flink/src/airflow/providers/apache/flink/sensors/__init__.py similarity index 100% rename from providers/src/airflow/providers/apache/flink/sensors/__init__.py rename to providers/apache/flink/src/airflow/providers/apache/flink/sensors/__init__.py diff --git a/providers/src/airflow/providers/apache/flink/sensors/flink_kubernetes.py b/providers/apache/flink/src/airflow/providers/apache/flink/sensors/flink_kubernetes.py similarity index 100% rename from providers/src/airflow/providers/apache/flink/sensors/flink_kubernetes.py rename to providers/apache/flink/src/airflow/providers/apache/flink/sensors/flink_kubernetes.py diff --git a/providers/apache/flink/tests/conftest.py b/providers/apache/flink/tests/conftest.py new file mode 100644 index 0000000000000..068fe6bbf5ae9 --- /dev/null +++ b/providers/apache/flink/tests/conftest.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pathlib + +import pytest + +pytest_plugins = "tests_common.pytest_plugin" + + +@pytest.hookimpl(tryfirst=True) +def pytest_configure(config: pytest.Config) -> None: + deprecations_ignore_path = pathlib.Path(__file__).parent.joinpath("deprecations_ignore.yml") + dep_path = [deprecations_ignore_path] if deprecations_ignore_path.exists() else [] + config.inicfg["airflow_deprecations_ignore"] = ( + config.inicfg.get("airflow_deprecations_ignore", []) + dep_path # type: ignore[assignment,operator] + ) diff --git a/providers/apache/flink/tests/provider_tests/__init__.py b/providers/apache/flink/tests/provider_tests/__init__.py new file mode 100644 index 0000000000000..e8fd22856438c --- /dev/null +++ b/providers/apache/flink/tests/provider_tests/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/apache/flink/tests/provider_tests/apache/__init__.py b/providers/apache/flink/tests/provider_tests/apache/__init__.py new file mode 100644 index 0000000000000..e8fd22856438c --- /dev/null +++ b/providers/apache/flink/tests/provider_tests/apache/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/src/airflow/providers/apache/impala/hooks/__init__.py b/providers/apache/flink/tests/provider_tests/apache/flink/__init__.py similarity index 100% rename from providers/src/airflow/providers/apache/impala/hooks/__init__.py rename to providers/apache/flink/tests/provider_tests/apache/flink/__init__.py diff --git a/providers/src/airflow/providers/cloudant/hooks/__init__.py b/providers/apache/flink/tests/provider_tests/apache/flink/operators/__init__.py similarity index 100% rename from providers/src/airflow/providers/cloudant/hooks/__init__.py rename to providers/apache/flink/tests/provider_tests/apache/flink/operators/__init__.py diff --git a/providers/tests/apache/flink/operators/test_flink_kubernetes.py b/providers/apache/flink/tests/provider_tests/apache/flink/operators/test_flink_kubernetes.py similarity index 100% rename from providers/tests/apache/flink/operators/test_flink_kubernetes.py rename to providers/apache/flink/tests/provider_tests/apache/flink/operators/test_flink_kubernetes.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/decorators/__init__.py b/providers/apache/flink/tests/provider_tests/apache/flink/sensors/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/decorators/__init__.py rename to providers/apache/flink/tests/provider_tests/apache/flink/sensors/__init__.py diff --git a/providers/tests/apache/flink/sensors/test_flink_kubernetes.py b/providers/apache/flink/tests/provider_tests/apache/flink/sensors/test_flink_kubernetes.py similarity index 100% rename from providers/tests/apache/flink/sensors/test_flink_kubernetes.py rename to providers/apache/flink/tests/provider_tests/apache/flink/sensors/test_flink_kubernetes.py diff --git a/providers/apache/impala/README.rst b/providers/apache/impala/README.rst new file mode 100644 index 
0000000000000..d8c915f479aa6 --- /dev/null +++ b/providers/apache/impala/README.rst @@ -0,0 +1,82 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + + .. IF YOU WANT TO MODIFY TEMPLATE FOR THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + `PROVIDER_README_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +Package ``apache-airflow-providers-apache-impala`` + +Release: ``1.6.0`` + + +`Apache Impala `__. + + +Provider package +---------------- + +This is a provider package for ``apache.impala`` provider. All classes for this provider package +are in ``airflow.providers.apache.impala`` python package. + +You can find package information and changelog for the provider +in the `documentation `_. + +Installation +------------ + +You can install this package on top of an existing Airflow 2 installation (see ``Requirements`` below +for the minimum Airflow version supported) via +``pip install apache-airflow-providers-apache-impala`` + +The package supports the following python versions: 3.9,3.10,3.11,3.12 + +Requirements +------------ + +======================================= ================== +PIP package Version required +======================================= ================== +``impyla`` ``>=0.18.0,<1.0`` +``apache-airflow-providers-common-sql`` ``>=1.20.0`` +``apache-airflow`` ``>=2.9.0`` +======================================= ================== + +Cross provider package dependencies +----------------------------------- + +Those are dependencies that might be needed in order to use all the features of the package. +You need to install the specified provider packages in order to use them. + +You can install such cross-provider dependencies when installing from PyPI. For example: + +.. code-block:: bash + + pip install apache-airflow-providers-apache-impala[common.sql] + + +============================================================================================================ ============== +Dependent package Extra +============================================================================================================ ============== +`apache-airflow-providers-common-sql `_ ``common.sql`` +============================================================================================================ ============== + +The changelog for the provider package can be found in the +`changelog `_. 
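The ``ImpalaHook`` shipped by this provider (registered under the ``impala`` connection type in the generated metadata below) is a ``DbApiHook`` subclass from ``apache-airflow-providers-common-sql``, so the generic SQL helpers apply to it. A minimal usage sketch, assuming an Impala connection has already been configured in Airflow; the function name and query are illustrative and not part of this patch:

.. code-block:: python

    from airflow.providers.apache.impala.hooks.impala import ImpalaHook


    def fetch_row_count() -> int:
        # ImpalaHook subclasses DbApiHook, so generic helpers such as
        # get_records() from apache-airflow-providers-common-sql are available.
        hook = ImpalaHook()  # assumed: uses the provider's default Impala connection
        rows = hook.get_records("SELECT COUNT(*) FROM my_table")  # illustrative query
        return rows[0][0]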
diff --git a/providers/src/airflow/providers/apache/impala/.latest-doc-only-change.txt b/providers/apache/impala/docs/.latest-doc-only-change.txt similarity index 100% rename from providers/src/airflow/providers/apache/impala/.latest-doc-only-change.txt rename to providers/apache/impala/docs/.latest-doc-only-change.txt diff --git a/providers/src/airflow/providers/apache/impala/CHANGELOG.rst b/providers/apache/impala/docs/changelog.rst similarity index 100% rename from providers/src/airflow/providers/apache/impala/CHANGELOG.rst rename to providers/apache/impala/docs/changelog.rst diff --git a/docs/apache-airflow-providers-apache-impala/commits.rst b/providers/apache/impala/docs/commits.rst similarity index 100% rename from docs/apache-airflow-providers-apache-impala/commits.rst rename to providers/apache/impala/docs/commits.rst diff --git a/docs/apache-airflow-providers-apache-impala/connections/impala.rst b/providers/apache/impala/docs/connections/impala.rst similarity index 100% rename from docs/apache-airflow-providers-apache-impala/connections/impala.rst rename to providers/apache/impala/docs/connections/impala.rst diff --git a/docs/apache-airflow-providers-apache-impala/index.rst b/providers/apache/impala/docs/index.rst similarity index 100% rename from docs/apache-airflow-providers-apache-impala/index.rst rename to providers/apache/impala/docs/index.rst diff --git a/docs/apache-airflow-providers-apache-impala/installing-providers-from-sources.rst b/providers/apache/impala/docs/installing-providers-from-sources.rst similarity index 100% rename from docs/apache-airflow-providers-apache-impala/installing-providers-from-sources.rst rename to providers/apache/impala/docs/installing-providers-from-sources.rst diff --git a/docs/apache-airflow-providers-apache-impala/security.rst b/providers/apache/impala/docs/security.rst similarity index 100% rename from docs/apache-airflow-providers-apache-impala/security.rst rename to providers/apache/impala/docs/security.rst diff --git a/providers/src/airflow/providers/apache/impala/provider.yaml b/providers/apache/impala/provider.yaml similarity index 87% rename from providers/src/airflow/providers/apache/impala/provider.yaml rename to providers/apache/impala/provider.yaml index 52f9e003972f5..1946b027155e0 100644 --- a/providers/src/airflow/providers/apache/impala/provider.yaml +++ b/providers/apache/impala/provider.yaml @@ -41,19 +41,6 @@ versions: - 1.1.0 - 1.0.0 -dependencies: - - impyla>=0.18.0,<1.0 - - apache-airflow-providers-common-sql>=1.20.0 - - apache-airflow>=2.9.0 - -additional-extras: - - name: kerberos - dependencies: - - kerberos>=1.3.0 - -devel-dependencies: - - kerberos>=1.3.0 - integrations: - integration-name: Apache Impala external-doc-url: https://impala.apache.org diff --git a/providers/apache/impala/pyproject.toml b/providers/apache/impala/pyproject.toml new file mode 100644 index 0000000000000..ef8f12e045a01 --- /dev/null +++ b/providers/apache/impala/pyproject.toml @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + +# IF YOU WANT TO MODIFY THIS FILE EXCEPT DEPENDENCIES, YOU SHOULD MODIFY THE TEMPLATE +# `pyproject_TEMPLATE.toml.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +[build-system] +requires = ["flit_core==3.10.1"] +build-backend = "flit_core.buildapi" + +[project] +name = "apache-airflow-providers-apache-impala" +version = "1.6.0" +description = "Provider package apache-airflow-providers-apache-impala for Apache Airflow" +readme = "README.rst" +authors = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +maintainers = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +keywords = [ "airflow-provider", "apache.impala", "airflow", "integration" ] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Framework :: Apache Airflow", + "Framework :: Apache Airflow :: Provider", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: System :: Monitoring", +] +requires-python = "~=3.9" + +# The dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +dependencies = [ + "impyla>=0.18.0,<1.0", + "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow>=2.9.0", +] + +# The optional dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +[project.optional-dependencies] +"kerberos" = [ + "kerberos>=1.3.0", +] + +# The dependency groups should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +[dependency-groups] +dev = [ + "kerberos>=1.3.0", +] + +[project.urls] +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-apache-impala/1.6.0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-apache-impala/1.6.0/changelog.html" +"Bug Tracker" = "https://github.com/apache/airflow/issues" +"Source Code" = "https://github.com/apache/airflow" +"Slack Chat" = "https://s.apache.org/airflow-slack" +"Twitter" = "https://x.com/ApacheAirflow" +"YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/" + +[project.entry-points."apache_airflow_provider"] +provider_info = "airflow.providers.apache.impala.get_provider_info:get_provider_info" + +[tool.flit.module] +name = "airflow.providers.apache.impala" + +[tool.pytest.ini_options] +ignore = "tests/system/" diff --git a/providers/apache/impala/src/airflow/providers/apache/impala/LICENSE b/providers/apache/impala/src/airflow/providers/apache/impala/LICENSE new file mode 100644 index 0000000000000..11069edd79019 --- /dev/null +++ 
b/providers/apache/impala/src/airflow/providers/apache/impala/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/providers/src/airflow/providers/apache/impala/__init__.py b/providers/apache/impala/src/airflow/providers/apache/impala/__init__.py similarity index 100% rename from providers/src/airflow/providers/apache/impala/__init__.py rename to providers/apache/impala/src/airflow/providers/apache/impala/__init__.py diff --git a/providers/apache/impala/src/airflow/providers/apache/impala/get_provider_info.py b/providers/apache/impala/src/airflow/providers/apache/impala/get_provider_info.py new file mode 100644 index 0000000000000..a39e1413ab0f7 --- /dev/null +++ b/providers/apache/impala/src/airflow/providers/apache/impala/get_provider_info.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! 
+# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `get_provider_info_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +def get_provider_info(): + return { + "package-name": "apache-airflow-providers-apache-impala", + "name": "Apache Impala", + "description": "`Apache Impala `__.\n", + "state": "ready", + "source-date-epoch": 1734528154, + "versions": [ + "1.6.0", + "1.5.2", + "1.5.1", + "1.5.0", + "1.4.2", + "1.4.1", + "1.4.0", + "1.3.0", + "1.2.1", + "1.2.0", + "1.1.3", + "1.1.2", + "1.1.1", + "1.1.0", + "1.0.0", + ], + "integrations": [ + { + "integration-name": "Apache Impala", + "external-doc-url": "https://impala.apache.org", + "tags": ["apache"], + } + ], + "hooks": [ + { + "integration-name": "Apache Impala", + "python-modules": ["airflow.providers.apache.impala.hooks.impala"], + } + ], + "connection-types": [ + { + "hook-class-name": "airflow.providers.apache.impala.hooks.impala.ImpalaHook", + "connection-type": "impala", + } + ], + "dependencies": [ + "impyla>=0.18.0,<1.0", + "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow>=2.9.0", + ], + "optional-dependencies": {"kerberos": ["kerberos>=1.3.0"]}, + "devel-dependencies": ["kerberos>=1.3.0"], + } diff --git a/providers/src/airflow/providers/cncf/kubernetes/executors/__init__.py b/providers/apache/impala/src/airflow/providers/apache/impala/hooks/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/executors/__init__.py rename to providers/apache/impala/src/airflow/providers/apache/impala/hooks/__init__.py diff --git a/providers/src/airflow/providers/apache/impala/hooks/impala.py b/providers/apache/impala/src/airflow/providers/apache/impala/hooks/impala.py similarity index 99% rename from providers/src/airflow/providers/apache/impala/hooks/impala.py rename to providers/apache/impala/src/airflow/providers/apache/impala/hooks/impala.py index aaa510945553b..e95eeac908c7c 100644 --- a/providers/src/airflow/providers/apache/impala/hooks/impala.py +++ b/providers/apache/impala/src/airflow/providers/apache/impala/hooks/impala.py @@ -18,9 +18,8 @@ from typing import TYPE_CHECKING -from impala.dbapi import connect - from airflow.providers.common.sql.hooks.sql import DbApiHook +from impala.dbapi import connect if TYPE_CHECKING: from impala.interface import Connection diff --git a/providers/apache/impala/tests/conftest.py b/providers/apache/impala/tests/conftest.py new file mode 100644 index 0000000000000..068fe6bbf5ae9 --- /dev/null +++ b/providers/apache/impala/tests/conftest.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pathlib + +import pytest + +pytest_plugins = "tests_common.pytest_plugin" + + +@pytest.hookimpl(tryfirst=True) +def pytest_configure(config: pytest.Config) -> None: + deprecations_ignore_path = pathlib.Path(__file__).parent.joinpath("deprecations_ignore.yml") + dep_path = [deprecations_ignore_path] if deprecations_ignore_path.exists() else [] + config.inicfg["airflow_deprecations_ignore"] = ( + config.inicfg.get("airflow_deprecations_ignore", []) + dep_path # type: ignore[assignment,operator] + ) diff --git a/providers/apache/impala/tests/provider_tests/__init__.py b/providers/apache/impala/tests/provider_tests/__init__.py new file mode 100644 index 0000000000000..e8fd22856438c --- /dev/null +++ b/providers/apache/impala/tests/provider_tests/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/apache/impala/tests/provider_tests/apache/__init__.py b/providers/apache/impala/tests/provider_tests/apache/__init__.py new file mode 100644 index 0000000000000..e8fd22856438c --- /dev/null +++ b/providers/apache/impala/tests/provider_tests/apache/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/src/airflow/providers/amazon/aws/auth_manager/views/__init__.py b/providers/apache/impala/tests/provider_tests/apache/impala/__init__.py similarity index 100% rename from providers/src/airflow/providers/amazon/aws/auth_manager/views/__init__.py rename to providers/apache/impala/tests/provider_tests/apache/impala/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/cli/__init__.py b/providers/apache/impala/tests/provider_tests/apache/impala/hooks/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/cli/__init__.py rename to providers/apache/impala/tests/provider_tests/apache/impala/hooks/__init__.py diff --git a/providers/tests/apache/impala/hooks/test_impala.py b/providers/apache/impala/tests/provider_tests/apache/impala/hooks/test_impala.py similarity index 100% rename from providers/tests/apache/impala/hooks/test_impala.py rename to providers/apache/impala/tests/provider_tests/apache/impala/hooks/test_impala.py diff --git a/providers/cloudant/README.rst b/providers/cloudant/README.rst new file mode 100644 index 0000000000000..84bef64ed2376 --- /dev/null +++ b/providers/cloudant/README.rst @@ -0,0 +1,62 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + + .. IF YOU WANT TO MODIFY TEMPLATE FOR THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + `PROVIDER_README_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +Package ``apache-airflow-providers-cloudant`` + +Release: ``4.1.0`` + + +`IBM Cloudant `__ + + +Provider package +---------------- + +This is a provider package for ``cloudant`` provider. All classes for this provider package +are in ``airflow.providers.cloudant`` python package. + +You can find package information and changelog for the provider +in the `documentation `_. + +Installation +------------ + +You can install this package on top of an existing Airflow 2 installation (see ``Requirements`` below +for the minimum Airflow version supported) via +``pip install apache-airflow-providers-cloudant`` + +The package supports the following python versions: 3.10,3.11,3.12 + +Requirements +------------ + +================== ===================================== +PIP package Version required +================== ===================================== +``apache-airflow`` ``>=2.9.0`` +``ibmcloudant`` ``==0.9.1; python_version >= "3.10"`` +================== ===================================== + +The changelog for the provider package can be found in the +`changelog `_. 
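Each generated ``pyproject.toml`` in this patch (the Apache Impala one above, and the Cloudant one that follows) exposes the provider's ``get_provider_info`` callable through the ``apache_airflow_provider`` entry-point group, which is how the metadata dictionaries in the ``get_provider_info.py`` files are located at runtime. A small sketch of that discovery using only the standard library, assuming Python 3.10+ and at least one provider package installed:

.. code-block:: python

    from importlib.metadata import entry_points

    # Every installed provider distribution declares an entry point in the
    # "apache_airflow_provider" group pointing at its get_provider_info().
    for entry_point in entry_points(group="apache_airflow_provider"):
        get_provider_info = entry_point.load()
        info = get_provider_info()  # the dictionary shown in get_provider_info.py
        print(info["package-name"], info["versions"][0])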
diff --git a/providers/src/airflow/providers/cloudant/.latest-doc-only-change.txt b/providers/cloudant/docs/.latest-doc-only-change.txt similarity index 100% rename from providers/src/airflow/providers/cloudant/.latest-doc-only-change.txt rename to providers/cloudant/docs/.latest-doc-only-change.txt diff --git a/providers/src/airflow/providers/cloudant/CHANGELOG.rst b/providers/cloudant/docs/changelog.rst similarity index 100% rename from providers/src/airflow/providers/cloudant/CHANGELOG.rst rename to providers/cloudant/docs/changelog.rst diff --git a/docs/apache-airflow-providers-cloudant/commits.rst b/providers/cloudant/docs/commits.rst similarity index 100% rename from docs/apache-airflow-providers-cloudant/commits.rst rename to providers/cloudant/docs/commits.rst diff --git a/docs/apache-airflow-providers-cloudant/index.rst b/providers/cloudant/docs/index.rst similarity index 100% rename from docs/apache-airflow-providers-cloudant/index.rst rename to providers/cloudant/docs/index.rst diff --git a/docs/apache-airflow-providers-cloudant/installing-providers-from-sources.rst b/providers/cloudant/docs/installing-providers-from-sources.rst similarity index 100% rename from docs/apache-airflow-providers-cloudant/installing-providers-from-sources.rst rename to providers/cloudant/docs/installing-providers-from-sources.rst diff --git a/docs/integration-logos/cloudant/Cloudant.png b/providers/cloudant/docs/integration-logos/Cloudant.png similarity index 100% rename from docs/integration-logos/cloudant/Cloudant.png rename to providers/cloudant/docs/integration-logos/Cloudant.png diff --git a/docs/apache-airflow-providers-cloudant/security.rst b/providers/cloudant/docs/security.rst similarity index 100% rename from docs/apache-airflow-providers-cloudant/security.rst rename to providers/cloudant/docs/security.rst diff --git a/providers/src/airflow/providers/cloudant/provider.yaml b/providers/cloudant/provider.yaml similarity index 86% rename from providers/src/airflow/providers/cloudant/provider.yaml rename to providers/cloudant/provider.yaml index dbbabf157a88a..8a4944a2f6278 100644 --- a/providers/src/airflow/providers/cloudant/provider.yaml +++ b/providers/cloudant/provider.yaml @@ -49,12 +49,6 @@ versions: - 1.0.1 - 1.0.0 -dependencies: - - apache-airflow>=2.9.0 - # Even though 3.9 is excluded below, we need to make this python_version aware so that `uv` can generate a - # full lock file when building lock file from provider sources - - 'ibmcloudant==0.9.1 ; python_version >= "3.10"' - excluded-python-versions: # ibmcloudant transitively brings in urllib3 2.x, but the snowflake provider has a dependency that pins # urllib3 to 1.x on Python 3.9; thus we exclude those Python versions from taking the update @@ -65,7 +59,7 @@ excluded-python-versions: integrations: - integration-name: IBM Cloudant external-doc-url: https://www.ibm.com/cloud/cloudant - logo: /integration-logos/cloudant/Cloudant.png + logo: /docs/integration-logos/Cloudant.png tags: [service] hooks: diff --git a/providers/cloudant/pyproject.toml b/providers/cloudant/pyproject.toml new file mode 100644 index 0000000000000..61f2ce3c0ead7 --- /dev/null +++ b/providers/cloudant/pyproject.toml @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + +# IF YOU WANT TO MODIFY THIS FILE EXCEPT DEPENDENCIES, YOU SHOULD MODIFY THE TEMPLATE +# `pyproject_TEMPLATE.toml.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +[build-system] +requires = ["flit_core==3.10.1"] +build-backend = "flit_core.buildapi" + +[project] +name = "apache-airflow-providers-cloudant" +version = "4.1.0" +description = "Provider package apache-airflow-providers-cloudant for Apache Airflow" +readme = "README.rst" +authors = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +maintainers = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +keywords = [ "airflow-provider", "cloudant", "airflow", "integration" ] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Framework :: Apache Airflow", + "Framework :: Apache Airflow :: Provider", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: System :: Monitoring", +] +requires-python = "~=3.9,!=3.9" + +# The dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +dependencies = [ + "apache-airflow>=2.9.0", + # Even though 3.9 is excluded below, we need to make this python_version aware so that `uv` can generate a + # full lock file when building lock file from provider sources + "ibmcloudant==0.9.1;python_version>=\"3.10\"", +] + +[project.urls] +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-cloudant/4.1.0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-cloudant/4.1.0/changelog.html" +"Bug Tracker" = "https://github.com/apache/airflow/issues" +"Source Code" = "https://github.com/apache/airflow" +"Slack Chat" = "https://s.apache.org/airflow-slack" +"Twitter" = "https://x.com/ApacheAirflow" +"YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/" + +[project.entry-points."apache_airflow_provider"] +provider_info = "airflow.providers.cloudant.get_provider_info:get_provider_info" + +[tool.flit.module] +name = "airflow.providers.cloudant" + +[tool.pytest.ini_options] +ignore = "tests/system/" diff --git a/providers/cloudant/src/airflow/providers/cloudant/LICENSE b/providers/cloudant/src/airflow/providers/cloudant/LICENSE new file mode 100644 index 0000000000000..11069edd79019 --- /dev/null +++ b/providers/cloudant/src/airflow/providers/cloudant/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/providers/src/airflow/providers/cloudant/__init__.py b/providers/cloudant/src/airflow/providers/cloudant/__init__.py
similarity index 100%
rename from providers/src/airflow/providers/cloudant/__init__.py
rename to providers/cloudant/src/airflow/providers/cloudant/__init__.py
diff --git a/providers/src/airflow/providers/cloudant/cloudant_fake.py b/providers/cloudant/src/airflow/providers/cloudant/cloudant_fake.py
similarity index 100%
rename from providers/src/airflow/providers/cloudant/cloudant_fake.py
rename to providers/cloudant/src/airflow/providers/cloudant/cloudant_fake.py
diff --git a/providers/cloudant/src/airflow/providers/cloudant/get_provider_info.py b/providers/cloudant/src/airflow/providers/cloudant/get_provider_info.py
new file mode 100644
index 0000000000000..b893532368780
--- /dev/null
+++ b/providers/cloudant/src/airflow/providers/cloudant/get_provider_info.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN!
+#
+# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE
+# `get_provider_info_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY
+
+
+def get_provider_info():
+    return {
+        "package-name": "apache-airflow-providers-cloudant",
+        "name": "IBM Cloudant",
+        "description": "`IBM Cloudant <https://www.ibm.com/cloud/cloudant>`__\n",
+        "state": "ready",
+        "source-date-epoch": 1734529058,
+        "versions": [
+            "4.1.0",
+            "4.0.3",
+            "4.0.2",
+            "4.0.1",
+            "4.0.0",
+            "3.6.0",
+            "3.5.2",
+            "3.5.1",
+            "3.5.0",
+            "3.4.1",
+            "3.4.0",
+            "3.3.0",
+            "3.2.1",
+            "3.2.0",
+            "3.1.0",
+            "3.0.0",
+            "2.0.4",
+            "2.0.3",
+            "2.0.2",
+            "2.0.1",
+            "2.0.0",
+            "1.0.1",
+            "1.0.0",
+        ],
+        "excluded-python-versions": ["3.9"],
+        "integrations": [
+            {
+                "integration-name": "IBM Cloudant",
+                "external-doc-url": "https://www.ibm.com/cloud/cloudant",
+                "logo": "/docs/integration-logos/Cloudant.png",
+                "tags": ["service"],
+            }
+        ],
+        "hooks": [
+            {
+                "integration-name": "IBM Cloudant",
+                "python-modules": ["airflow.providers.cloudant.hooks.cloudant"],
+            }
+        ],
+        "connection-types": [
+            {
+                "hook-class-name": "airflow.providers.cloudant.hooks.cloudant.CloudantHook",
+                "connection-type": "cloudant",
+            }
+        ],
+        "dependencies": ["apache-airflow>=2.9.0", 'ibmcloudant==0.9.1;python_version>="3.10"'],
+    }
diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/__init__.py b/providers/cloudant/src/airflow/providers/cloudant/hooks/__init__.py
similarity index 100%
rename from providers/src/airflow/providers/cncf/kubernetes/operators/__init__.py
rename to providers/cloudant/src/airflow/providers/cloudant/hooks/__init__.py
diff --git a/providers/src/airflow/providers/cloudant/hooks/cloudant.py b/providers/cloudant/src/airflow/providers/cloudant/hooks/cloudant.py
similarity index 100%
rename from providers/src/airflow/providers/cloudant/hooks/cloudant.py
rename to providers/cloudant/src/airflow/providers/cloudant/hooks/cloudant.py
diff --git a/providers/cloudant/tests/conftest.py b/providers/cloudant/tests/conftest.py
new file mode 100644
index 0000000000000..068fe6bbf5ae9
--- /dev/null
+++ b/providers/cloudant/tests/conftest.py
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pathlib
+
+import pytest
+
+pytest_plugins = "tests_common.pytest_plugin"
+
+
+@pytest.hookimpl(tryfirst=True)
+def pytest_configure(config: pytest.Config) -> None:
+    deprecations_ignore_path = pathlib.Path(__file__).parent.joinpath("deprecations_ignore.yml")
+    dep_path = [deprecations_ignore_path] if deprecations_ignore_path.exists() else []
+    config.inicfg["airflow_deprecations_ignore"] = (
+        config.inicfg.get("airflow_deprecations_ignore", []) + dep_path  # type: ignore[assignment,operator]
+    )
diff --git a/providers/cloudant/tests/provider_tests/__init__.py b/providers/cloudant/tests/provider_tests/__init__.py
new file mode 100644
index 0000000000000..e8fd22856438c
--- /dev/null
+++ b/providers/cloudant/tests/provider_tests/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore
diff --git a/providers/src/airflow/providers/microsoft/winrm/hooks/__init__.py b/providers/cloudant/tests/provider_tests/cloudant/__init__.py
similarity index 100%
rename from providers/src/airflow/providers/microsoft/winrm/hooks/__init__.py
rename to providers/cloudant/tests/provider_tests/cloudant/__init__.py
diff --git a/providers/src/airflow/providers/microsoft/winrm/operators/__init__.py b/providers/cloudant/tests/provider_tests/cloudant/hooks/__init__.py
similarity index 100%
rename from providers/src/airflow/providers/microsoft/winrm/operators/__init__.py
rename to providers/cloudant/tests/provider_tests/cloudant/hooks/__init__.py
diff --git a/providers/tests/cloudant/hooks/test_cloudant.py b/providers/cloudant/tests/provider_tests/cloudant/hooks/test_cloudant.py
similarity index 100%
rename from providers/tests/cloudant/hooks/test_cloudant.py
rename to providers/cloudant/tests/provider_tests/cloudant/hooks/test_cloudant.py
diff --git a/providers/cncf/kubernetes/README.rst b/providers/cncf/kubernetes/README.rst
new file mode 100644
index 0000000000000..8cdf33e1696e7
--- /dev/null
+++ b/providers/cncf/kubernetes/README.rst
@@ -0,0 +1,67 @@
+
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements. See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership. The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied. See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+ .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN!
+
+ .. IF YOU WANT TO MODIFY TEMPLATE FOR THIS FILE, YOU SHOULD MODIFY THE TEMPLATE
+    `PROVIDER_README_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY
+
+
+Package ``apache-airflow-providers-cncf-kubernetes``
+
+Release: ``10.1.0``
+
+
+`Kubernetes <https://kubernetes.io/>`__
+
+
+Provider package
+----------------
+
+This is a provider package for ``cncf.kubernetes`` provider. All classes for this provider package
+are in ``airflow.providers.cncf.kubernetes`` python package.
+
+You can find package information and changelog for the provider
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-cncf-kubernetes/10.1.0/>`_.
+
+Installation
+------------
+
+You can install this package on top of an existing Airflow 2 installation (see ``Requirements`` below
+for the minimum Airflow version supported) via
+``pip install apache-airflow-providers-cncf-kubernetes``
+
+The package supports the following python versions: 3.9,3.10,3.11,3.12
+
+Requirements
+------------
+
+====================== =====================
+PIP package            Version required
+====================== =====================
+``aiofiles``           ``>=23.2.0``
+``apache-airflow``     ``>=2.9.0``
+``asgiref``            ``>=3.5.2``
+``cryptography``       ``>=41.0.0``
+``kubernetes``         ``>=29.0.0,<=31.0.0``
+``kubernetes_asyncio`` ``>=29.0.0,<=31.0.0``
+``google-re2``         ``>=1.0``
+====================== =====================
+
+The changelog for the provider package can be found in the
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-cncf-kubernetes/10.1.0/changelog.html>`_.
diff --git a/providers/src/airflow/providers/cncf/kubernetes/.latest-doc-only-change.txt b/providers/cncf/kubernetes/docs/.latest-doc-only-change.txt
similarity index 100%
rename from providers/src/airflow/providers/cncf/kubernetes/.latest-doc-only-change.txt
rename to providers/cncf/kubernetes/docs/.latest-doc-only-change.txt
diff --git a/providers/src/airflow/providers/cncf/kubernetes/CHANGELOG.rst b/providers/cncf/kubernetes/docs/changelog.rst
similarity index 100%
rename from providers/src/airflow/providers/cncf/kubernetes/CHANGELOG.rst
rename to providers/cncf/kubernetes/docs/changelog.rst
diff --git a/docs/apache-airflow-providers-cncf-kubernetes/cli-ref.rst b/providers/cncf/kubernetes/docs/cli-ref.rst
similarity index 100%
rename from docs/apache-airflow-providers-cncf-kubernetes/cli-ref.rst
rename to providers/cncf/kubernetes/docs/cli-ref.rst
diff --git a/docs/apache-airflow-providers-cncf-kubernetes/commits.rst b/providers/cncf/kubernetes/docs/commits.rst
similarity index 100%
rename from docs/apache-airflow-providers-cncf-kubernetes/commits.rst
rename to providers/cncf/kubernetes/docs/commits.rst
diff --git a/docs/apache-airflow-providers-cncf-kubernetes/configurations-ref.rst b/providers/cncf/kubernetes/docs/configurations-ref.rst
similarity index 100%
rename from docs/apache-airflow-providers-cncf-kubernetes/configurations-ref.rst
rename to providers/cncf/kubernetes/docs/configurations-ref.rst
diff --git a/docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst b/providers/cncf/kubernetes/docs/connections/kubernetes.rst
similarity index 100%
rename from docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst
rename to providers/cncf/kubernetes/docs/connections/kubernetes.rst
diff --git a/docs/apache-airflow-providers-cncf-kubernetes/img/arch-diag-kubernetes.png b/providers/cncf/kubernetes/docs/img/arch-diag-kubernetes.png
similarity index 100%
rename from
docs/apache-airflow-providers-cncf-kubernetes/img/arch-diag-kubernetes.png rename to providers/cncf/kubernetes/docs/img/arch-diag-kubernetes.png diff --git a/docs/apache-airflow-providers-cncf-kubernetes/img/arch-diag-kubernetes2.png b/providers/cncf/kubernetes/docs/img/arch-diag-kubernetes2.png similarity index 100% rename from docs/apache-airflow-providers-cncf-kubernetes/img/arch-diag-kubernetes2.png rename to providers/cncf/kubernetes/docs/img/arch-diag-kubernetes2.png diff --git a/docs/apache-airflow-providers-cncf-kubernetes/img/k8s-failed-pod.png b/providers/cncf/kubernetes/docs/img/k8s-failed-pod.png similarity index 100% rename from docs/apache-airflow-providers-cncf-kubernetes/img/k8s-failed-pod.png rename to providers/cncf/kubernetes/docs/img/k8s-failed-pod.png diff --git a/docs/apache-airflow-providers-cncf-kubernetes/img/k8s-happy-path.png b/providers/cncf/kubernetes/docs/img/k8s-happy-path.png similarity index 100% rename from docs/apache-airflow-providers-cncf-kubernetes/img/k8s-happy-path.png rename to providers/cncf/kubernetes/docs/img/k8s-happy-path.png diff --git a/docs/apache-airflow-providers-cncf-kubernetes/index.rst b/providers/cncf/kubernetes/docs/index.rst similarity index 100% rename from docs/apache-airflow-providers-cncf-kubernetes/index.rst rename to providers/cncf/kubernetes/docs/index.rst diff --git a/docs/apache-airflow-providers-cncf-kubernetes/installing-providers-from-sources.rst b/providers/cncf/kubernetes/docs/installing-providers-from-sources.rst similarity index 100% rename from docs/apache-airflow-providers-cncf-kubernetes/installing-providers-from-sources.rst rename to providers/cncf/kubernetes/docs/installing-providers-from-sources.rst diff --git a/docs/integration-logos/kubernetes/Kubernetes.png b/providers/cncf/kubernetes/docs/integration-logos/Kubernetes.png similarity index 100% rename from docs/integration-logos/kubernetes/Kubernetes.png rename to providers/cncf/kubernetes/docs/integration-logos/Kubernetes.png diff --git a/docs/integration-logos/kubernetes/Spark-On-Kubernetes.png b/providers/cncf/kubernetes/docs/integration-logos/Spark-On-Kubernetes.png similarity index 100% rename from docs/integration-logos/kubernetes/Spark-On-Kubernetes.png rename to providers/cncf/kubernetes/docs/integration-logos/Spark-On-Kubernetes.png diff --git a/docs/apache-airflow-providers-cncf-kubernetes/kubernetes_executor.rst b/providers/cncf/kubernetes/docs/kubernetes_executor.rst similarity index 96% rename from docs/apache-airflow-providers-cncf-kubernetes/kubernetes_executor.rst rename to providers/cncf/kubernetes/docs/kubernetes_executor.rst index a85a793712885..91f1747e97ab2 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/kubernetes_executor.rst +++ b/providers/cncf/kubernetes/docs/kubernetes_executor.rst @@ -108,21 +108,21 @@ With these requirements in mind, here are some examples of basic ``pod_template_ Storing DAGs in the image: -.. literalinclude:: /../../providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_image_template.yaml +.. literalinclude:: /../../providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_image_template.yaml :language: yaml :start-after: [START template_with_dags_in_image] :end-before: [END template_with_dags_in_image] Storing DAGs in a ``persistentVolume``: -.. literalinclude:: /../../providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_volume_template.yaml +.. 
literalinclude:: /../../providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_volume_template.yaml :language: yaml :start-after: [START template_with_dags_in_volume] :end-before: [END template_with_dags_in_volume] Pulling DAGs from ``git``: -.. literalinclude:: /../../providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/git_sync_template.yaml +.. literalinclude:: /../../providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/git_sync_template.yaml :language: yaml :start-after: [START git_sync_template] :end-before: [END git_sync_template] diff --git a/docs/apache-airflow-providers-cncf-kubernetes/local_kubernetes_executor.rst b/providers/cncf/kubernetes/docs/local_kubernetes_executor.rst similarity index 100% rename from docs/apache-airflow-providers-cncf-kubernetes/local_kubernetes_executor.rst rename to providers/cncf/kubernetes/docs/local_kubernetes_executor.rst diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst b/providers/cncf/kubernetes/docs/operators.rst similarity index 95% rename from docs/apache-airflow-providers-cncf-kubernetes/operators.rst rename to providers/cncf/kubernetes/docs/operators.rst index b9876ea60d0cb..160f9cfccffd9 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst +++ b/providers/cncf/kubernetes/docs/operators.rst @@ -102,7 +102,7 @@ Using this method will ensure correctness and type safety. While we have removed almost all Kubernetes convenience classes, we have kept the :class:`~airflow.providers.cncf.kubernetes.secret.Secret` class to simplify the process of generating secret volumes/env variables. -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes.py :language: python :start-after: [START howto_operator_k8s_cluster_resources] :end-before: [END howto_operator_k8s_cluster_resources] @@ -135,21 +135,21 @@ Create the Secret using ``kubectl``: Then use it in your pod like so: -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes.py :language: python :start-after: [START howto_operator_k8s_private_image] :end-before: [END howto_operator_k8s_private_image] Also for this action you can use operator in the deferrable mode: -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_async.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_async.py :language: python :start-after: [START howto_operator_k8s_private_image_async] :end-before: [END howto_operator_k8s_private_image_async] Example to fetch and display container log periodically -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_async.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_async.py :language: python :start-after: [START howto_operator_async_log] :end-before: [END howto_operator_async_log] @@ -168,7 +168,7 @@ alongside the Pod. The Pod must write the XCom value into this location at the ` See the following example on how this occurs: -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes.py +.. 
exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes.py :language: python :start-after: [START howto_operator_k8s_write_xcom] :end-before: [END howto_operator_k8s_write_xcom] @@ -177,7 +177,7 @@ See the following example on how this occurs: Also for this action you can use operator in the deferrable mode: -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_async.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_async.py :language: python :start-after: [START howto_operator_k8s_write_xcom_async] :end-before: [END howto_operator_k8s_write_xcom_async] @@ -621,7 +621,7 @@ request that dynamically launches this Job. Users can specify a kubeconfig file using the ``config_file`` parameter, otherwise the operator will default to ``~/.kube/config``. It also allows users to supply a template YAML file using the ``job_template_file`` parameter. -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_job.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_job.py :language: python :dedent: 4 :start-after: [START howto_operator_k8s_job] @@ -629,7 +629,7 @@ to ``~/.kube/config``. It also allows users to supply a template YAML file using The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator` also supports deferrable mode: -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_job.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_job.py :language: python :dedent: 4 :start-after: [START howto_operator_k8s_job_deferrable] @@ -656,7 +656,7 @@ KubernetesDeleteJobOperator The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesDeleteJobOperator` allows you to delete Jobs on a Kubernetes cluster. -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_job.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_job.py :language: python :dedent: 4 :start-after: [START howto_operator_delete_k8s_job] @@ -671,7 +671,7 @@ KubernetesPatchJobOperator The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesPatchJobOperator` allows you to update Jobs on a Kubernetes cluster. -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_job.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_job.py :language: python :dedent: 4 :start-after: [START howto_operator_update_job] @@ -686,7 +686,7 @@ KubernetesInstallKueueOperator The :class:`~airflow.providers.cncf.kubernetes.operators.kueue.KubernetesInstallKueueOperator` allows you to install the Kueue component in a Kubernetes cluster -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_kueue.py +.. exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_kueue.py :language: python :dedent: 4 :start-after: [START howto_operator_k8s_kueue_install] @@ -709,7 +709,7 @@ KubernetesStartKueueJobOperator The :class:`~airflow.providers.cncf.kubernetes.operators.kueue.KubernetesStartKueueJobOperator` allows you to start a Kueue job in a Kubernetes cluster -.. exampleinclude:: /../../providers/tests/system/cncf/kubernetes/example_kubernetes_kueue.py +.. 
exampleinclude:: /../../providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_kueue.py :language: python :dedent: 4 :start-after: [START howto_operator_k8s_install_kueue] diff --git a/docs/apache-airflow-providers-cncf-kubernetes/redirects.txt b/providers/cncf/kubernetes/docs/redirects.txt similarity index 100% rename from docs/apache-airflow-providers-cncf-kubernetes/redirects.txt rename to providers/cncf/kubernetes/docs/redirects.txt diff --git a/docs/apache-airflow-providers-cncf-kubernetes/security.rst b/providers/cncf/kubernetes/docs/security.rst similarity index 100% rename from docs/apache-airflow-providers-cncf-kubernetes/security.rst rename to providers/cncf/kubernetes/docs/security.rst diff --git a/providers/src/airflow/providers/cncf/kubernetes/provider.yaml b/providers/cncf/kubernetes/provider.yaml similarity index 88% rename from providers/src/airflow/providers/cncf/kubernetes/provider.yaml rename to providers/cncf/kubernetes/provider.yaml index 9d38a70aa14b9..c9d213373fce7 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/provider.yaml +++ b/providers/cncf/kubernetes/provider.yaml @@ -95,38 +95,16 @@ versions: - 1.0.1 - 1.0.0 -dependencies: - - aiofiles>=23.2.0 - - apache-airflow>=2.9.0 - - asgiref>=3.5.2 - - cryptography>=41.0.0 - # The Kubernetes API is known to introduce problems when upgraded to a MAJOR version. Airflow Core - # Uses Kubernetes for Kubernetes executor, and we also know that Kubernetes Python client follows SemVer - # (https://github.com/kubernetes-client/python#compatibility). This is a crucial component of Airflow - # So we should limit it to the next MAJOR version and only deliberately bump the version when we - # tested it, and we know it can be bumped. Bumping this version should also be connected with - # limiting minimum airflow version supported in cncf.kubernetes provider, due to the - # potential breaking changes in Airflow Core as well (kubernetes is added as extra, so Airflow - # core is not hard-limited via install-requirements, only by extra). - - kubernetes>=29.0.0,<=31.0.0 - # The Kubernetes_asyncio package is used for providing Asynchronous (AsyncIO) client library for - # standard Kubernetes API. The version is limited by minimum 18.20.1 because of introducing the ability to - # load kubernetes config file from dictionary in that release and is limited to the next MAJOR version - # (started from current 24.2.2 version) to prevent introducing some problems that could be due to some - # major changes in the package. - - kubernetes_asyncio>=29.0.0,<=31.0.0 - - google-re2>=1.0 - integrations: - integration-name: Kubernetes external-doc-url: https://kubernetes.io/ how-to-guide: - /docs/apache-airflow-providers-cncf-kubernetes/operators.rst - logo: /integration-logos/kubernetes/Kubernetes.png + logo: /docs/integration-logos/Kubernetes.png tags: [software] - integration-name: Spark on Kubernetes external-doc-url: https://github.com/GoogleCloudPlatform/spark-on-k8s-operator - logo: /integration-logos/kubernetes/Spark-On-Kubernetes.png + logo: /docs/integration-logos/Spark-On-Kubernetes.png tags: [software] operators: diff --git a/providers/cncf/kubernetes/pyproject.toml b/providers/cncf/kubernetes/pyproject.toml new file mode 100644 index 0000000000000..32728f7c81e9a --- /dev/null +++ b/providers/cncf/kubernetes/pyproject.toml @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + +# IF YOU WANT TO MODIFY THIS FILE EXCEPT DEPENDENCIES, YOU SHOULD MODIFY THE TEMPLATE +# `pyproject_TEMPLATE.toml.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +[build-system] +requires = ["flit_core==3.10.1"] +build-backend = "flit_core.buildapi" + +[project] +name = "apache-airflow-providers-cncf-kubernetes" +version = "10.1.0" +description = "Provider package apache-airflow-providers-cncf-kubernetes for Apache Airflow" +readme = "README.rst" +authors = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +maintainers = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +keywords = [ "airflow-provider", "cncf.kubernetes", "airflow", "integration" ] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Framework :: Apache Airflow", + "Framework :: Apache Airflow :: Provider", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: System :: Monitoring", +] +requires-python = "~=3.9" + +# The dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +dependencies = [ + "aiofiles>=23.2.0", + "apache-airflow>=2.9.0", + "asgiref>=3.5.2", + "cryptography>=41.0.0", + # The Kubernetes API is known to introduce problems when upgraded to a MAJOR version. Airflow Core + # Uses Kubernetes for Kubernetes executor, and we also know that Kubernetes Python client follows SemVer + # (https://github.com/kubernetes-client/python#compatibility). This is a crucial component of Airflow + # So we should limit it to the next MAJOR version and only deliberately bump the version when we + # tested it, and we know it can be bumped. Bumping this version should also be connected with + # limiting minimum airflow version supported in cncf.kubernetes provider, due to the + # potential breaking changes in Airflow Core as well (kubernetes is added as extra, so Airflow + # core is not hard-limited via install-requirements, only by extra). + "kubernetes>=29.0.0,<=31.0.0", + # The Kubernetes_asyncio package is used for providing Asynchronous (AsyncIO) client library for + # standard Kubernetes API. 
The version is limited by minimum 18.20.1 because of introducing the ability to + # load kubernetes config file from dictionary in that release and is limited to the next MAJOR version + # (started from current 24.2.2 version) to prevent introducing some problems that could be due to some + # major changes in the package. + "kubernetes_asyncio>=29.0.0,<=31.0.0", + "google-re2>=1.0", +] + +[project.urls] +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-cncf-kubernetes/10.1.0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-cncf-kubernetes/10.1.0/changelog.html" +"Bug Tracker" = "https://github.com/apache/airflow/issues" +"Source Code" = "https://github.com/apache/airflow" +"Slack Chat" = "https://s.apache.org/airflow-slack" +"Twitter" = "https://x.com/ApacheAirflow" +"YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/" + +[project.entry-points."apache_airflow_provider"] +provider_info = "airflow.providers.cncf.kubernetes.get_provider_info:get_provider_info" + +[tool.flit.module] +name = "airflow.providers.cncf.kubernetes" + +[tool.pytest.ini_options] +ignore = "tests/system/" diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/LICENSE b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/LICENSE new file mode 100644 index 0000000000000..11069edd79019 --- /dev/null +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/providers/src/airflow/providers/cncf/kubernetes/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/backcompat/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/backcompat/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/backcompat/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/backcompat/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py index 9bc821a816d99..2d6f94d6ea018 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py @@ -18,9 +18,8 @@ from __future__ import annotations -from kubernetes.client import ApiClient, models as k8s - from airflow.exceptions import AirflowException +from kubernetes.client import ApiClient, models as k8s def _convert_kube_model_object(obj, new_class): diff --git a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/callbacks.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/callbacks.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/callbacks.py index d87e8065dbd1a..723de2f94f990 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/callbacks.py @@ -19,9 +19,10 @@ from enum import Enum from typing import TYPE_CHECKING, Union -import kubernetes.client as k8s import kubernetes_asyncio.client as async_k8s +import kubernetes.client as k8s + if TYPE_CHECKING: from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.utils.context import Context diff --git a/providers/src/airflow/providers/cncf/kubernetes/hooks/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/hooks/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py index ee629cc16e9a5..d138fd9c77c00 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py @@ -22,10 +22,6 @@ import sys from datetime import datetime, 
timedelta -from kubernetes import client -from kubernetes.client.api_client import ApiClient -from kubernetes.client.rest import ApiException - from airflow.models import DagRun, TaskInstance from airflow.providers.cncf.kubernetes import pod_generator from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubeConfig @@ -36,6 +32,9 @@ from airflow.utils import cli as cli_utils, yaml from airflow.utils.cli import get_dag from airflow.utils.providers_configuration_loader import providers_configuration_loaded +from kubernetes import client +from kubernetes.client.api_client import ApiClient +from kubernetes.client.rest import ApiException @cli_utils.action_cli diff --git a/providers/tests/apache/flink/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/decorators/__init__.py similarity index 100% rename from providers/tests/apache/flink/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/decorators/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/decorators/kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/decorators/kubernetes.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/decorators/kubernetes.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/decorators/kubernetes.py index cce1ffa8d6c19..c35b492dcabe0 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/decorators/kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/decorators/kubernetes.py @@ -26,13 +26,13 @@ from typing import TYPE_CHECKING, Callable import dill -from kubernetes.client import models as k8s from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.python_kubernetes_script import ( write_python_script, ) +from kubernetes.client import models as k8s if TYPE_CHECKING: from airflow.utils.context import Context @@ -119,7 +119,7 @@ def execute(self, context: Context): } write_python_script(jinja_context=jinja_context, filename=script_filename) - self.env_vars = [ + self.env_vars: list[k8s.V1EnvVar] = [ *self.env_vars, k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)), k8s.V1EnvVar(name=_PYTHON_INPUT_ENV, value=_read_file_contents(input_filename)), diff --git a/providers/src/airflow/providers/cncf/kubernetes/exceptions.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/exceptions.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/exceptions.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/exceptions.py diff --git a/providers/tests/apache/flink/operators/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/__init__.py similarity index 100% rename from providers/tests/apache/flink/operators/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py 
index 0f98f2f7dcb46..845c91e0db37d 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -37,9 +37,10 @@ from typing import TYPE_CHECKING, Any from deprecated import deprecated -from kubernetes.dynamic import DynamicClient from sqlalchemy import select +from kubernetes.dynamic import DynamicClient + try: from airflow.cli.cli_config import ARG_LOGICAL_DATE except ImportError: # 2.x compatibility. @@ -75,8 +76,6 @@ if TYPE_CHECKING: import argparse - from kubernetes import client - from kubernetes.client import models as k8s from sqlalchemy.orm import Session from airflow.executors.base_executor import CommandType @@ -89,6 +88,8 @@ from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils import ( AirflowKubernetesScheduler, ) + from kubernetes import client + from kubernetes.client import models as k8s # CLI Args ARG_NAMESPACE = Arg( diff --git a/providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py index 15fa954439a9d..1b7917502f0b6 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py @@ -23,8 +23,6 @@ from queue import Empty, Queue from typing import TYPE_CHECKING, Any -from kubernetes import client, watch -from kubernetes.client.rest import ApiException from urllib3.exceptions import ReadTimeoutError from airflow.exceptions import AirflowException @@ -45,15 +43,16 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.singleton import Singleton from airflow.utils.state import TaskInstanceState +from kubernetes import client, watch +from kubernetes.client.rest import ApiException if TYPE_CHECKING: - from kubernetes.client import Configuration, models as k8s - from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import ( KubernetesJobType, KubernetesResultsType, KubernetesWatchType, ) + from kubernetes.client import Configuration, models as k8s class ResourceVersion(metaclass=Singleton): diff --git a/providers/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py new file mode 100644 index 0000000000000..53e3d1f831c22 --- /dev/null +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py @@ -0,0 +1,361 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! +# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `get_provider_info_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +def get_provider_info(): + return { + "package-name": "apache-airflow-providers-cncf-kubernetes", + "name": "Kubernetes", + "description": "`Kubernetes `__\n", + "state": "ready", + "source-date-epoch": 1734537609, + "versions": [ + "10.1.0", + "10.0.1", + "10.0.0", + "9.0.1", + "9.0.0", + "8.4.2", + "8.4.1", + "8.4.0", + "8.3.4", + "8.3.3", + "8.3.2", + "8.3.1", + "8.3.0", + "8.2.0", + "8.1.1", + "8.1.0", + "8.0.1", + "8.0.0", + "7.14.0", + "7.13.0", + "7.12.0", + "7.11.0", + "7.10.0", + "7.9.0", + "7.8.0", + "7.7.0", + "7.6.0", + "7.5.1", + "7.5.0", + "7.4.2", + "7.4.1", + "7.4.0", + "7.3.0", + "7.2.0", + "7.1.0", + "7.0.0", + "6.1.0", + "6.0.0", + "5.3.0", + "5.2.2", + "5.2.1", + "5.2.0", + "5.1.1", + "5.1.0", + "5.0.0", + "4.4.0", + "4.3.0", + "4.2.0", + "4.1.0", + "4.0.2", + "4.0.1", + "4.0.0", + "3.1.2", + "3.1.1", + "3.1.0", + "3.0.2", + "3.0.1", + "3.0.0", + "2.2.0", + "2.1.0", + "2.0.3", + "2.0.2", + "2.0.1", + "2.0.0", + "1.2.0", + "1.1.0", + "1.0.2", + "1.0.1", + "1.0.0", + ], + "integrations": [ + { + "integration-name": "Kubernetes", + "external-doc-url": "https://kubernetes.io/", + "how-to-guide": ["/docs/apache-airflow-providers-cncf-kubernetes/operators.rst"], + "logo": "/docs/integration-logos/Kubernetes.png", + "tags": ["software"], + }, + { + "integration-name": "Spark on Kubernetes", + "external-doc-url": "https://github.com/GoogleCloudPlatform/spark-on-k8s-operator", + "logo": "/docs/integration-logos/Spark-On-Kubernetes.png", + "tags": ["software"], + }, + ], + "operators": [ + { + "integration-name": "Kubernetes", + "python-modules": [ + "airflow.providers.cncf.kubernetes.operators.custom_object_launcher", + "airflow.providers.cncf.kubernetes.operators.kueue", + "airflow.providers.cncf.kubernetes.operators.pod", + "airflow.providers.cncf.kubernetes.operators.spark_kubernetes", + "airflow.providers.cncf.kubernetes.operators.resource", + "airflow.providers.cncf.kubernetes.operators.job", + ], + } + ], + "sensors": [ + { + "integration-name": "Kubernetes", + "python-modules": ["airflow.providers.cncf.kubernetes.sensors.spark_kubernetes"], + } + ], + "hooks": [ + { + "integration-name": "Kubernetes", + "python-modules": ["airflow.providers.cncf.kubernetes.hooks.kubernetes"], + } + ], 
+ "triggers": [ + { + "integration-name": "Kubernetes", + "python-modules": [ + "airflow.providers.cncf.kubernetes.triggers.pod", + "airflow.providers.cncf.kubernetes.triggers.job", + ], + } + ], + "connection-types": [ + { + "hook-class-name": "airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook", + "connection-type": "kubernetes", + } + ], + "task-decorators": [ + { + "class-name": "airflow.providers.cncf.kubernetes.decorators.kubernetes.kubernetes_task", + "name": "kubernetes", + } + ], + "config": { + "local_kubernetes_executor": { + "description": "This section only applies if you are using the ``LocalKubernetesExecutor`` in\n``[core]`` section above\n", + "options": { + "kubernetes_queue": { + "description": "Define when to send a task to ``KubernetesExecutor`` when using ``LocalKubernetesExecutor``.\nWhen the queue of a task is the value of ``kubernetes_queue`` (default ``kubernetes``),\nthe task is executed via ``KubernetesExecutor``,\notherwise via ``LocalExecutor``\n", + "version_added": None, + "type": "string", + "example": None, + "default": "kubernetes", + } + }, + }, + "kubernetes_executor": { + "description": None, + "options": { + "api_client_retry_configuration": { + "description": "Kwargs to override the default urllib3 Retry used in the kubernetes API client\n", + "version_added": None, + "type": "string", + "example": '{ "total": 3, "backoff_factor": 0.5 }', + "default": "", + }, + "logs_task_metadata": { + "description": "Flag to control the information added to kubernetes executor logs for better traceability\n", + "version_added": None, + "type": "boolean", + "example": None, + "default": "False", + }, + "pod_template_file": { + "description": "Path to the YAML pod file that forms the basis for KubernetesExecutor workers.\n", + "version_added": None, + "type": "string", + "example": None, + "default": "", + "see_also": ":ref:`concepts:pod_template_file`", + }, + "worker_container_repository": { + "description": "The repository of the Kubernetes Image for the Worker to Run\n", + "version_added": None, + "type": "string", + "example": None, + "default": "", + }, + "worker_container_tag": { + "description": "The tag of the Kubernetes Image for the Worker to Run\n", + "version_added": None, + "type": "string", + "example": None, + "default": "", + }, + "namespace": { + "description": "The Kubernetes namespace where airflow workers should be created. 
Defaults to ``default``\n", + "version_added": None, + "type": "string", + "example": None, + "default": "default", + }, + "delete_worker_pods": { + "description": "If True, all worker pods will be deleted upon termination\n", + "version_added": None, + "type": "string", + "example": None, + "default": "True", + }, + "delete_worker_pods_on_failure": { + "description": "If False (and delete_worker_pods is True),\nfailed worker pods will not be deleted so users can investigate them.\nThis only prevents removal of worker pods where the worker itself failed,\nnot when the task it ran failed.\n", + "version_added": None, + "type": "string", + "example": None, + "default": "False", + }, + "worker_pod_pending_fatal_container_state_reasons": { + "description": "If the worker pods are in a pending state due to a fatal container\nstate reasons, then fail the task and delete the worker pod\nif delete_worker_pods is True and delete_worker_pods_on_failure is True.\n", + "version_added": "8.1.0", + "type": "string", + "example": None, + "default": "CreateContainerConfigError,ErrImagePull,CreateContainerError,ImageInspectError, InvalidImageName", + }, + "worker_pods_creation_batch_size": { + "description": 'Number of Kubernetes Worker Pod creation calls per scheduler loop.\nNote that the current default of "1" will only launch a single pod\nper-heartbeat. It is HIGHLY recommended that users increase this\nnumber to match the tolerance of their kubernetes cluster for\nbetter performance.\n', + "version_added": None, + "type": "string", + "example": None, + "default": "1", + }, + "multi_namespace_mode": { + "description": "Allows users to launch pods in multiple namespaces.\nWill require creating a cluster-role for the scheduler,\nor use multi_namespace_mode_namespace_list configuration.\n", + "version_added": None, + "type": "boolean", + "example": None, + "default": "False", + }, + "multi_namespace_mode_namespace_list": { + "description": "If multi_namespace_mode is True while scheduler does not have a cluster-role,\ngive the list of namespaces where the scheduler will schedule jobs\nScheduler needs to have the necessary permissions in these namespaces.\n", + "version_added": None, + "type": "string", + "example": None, + "default": "", + }, + "in_cluster": { + "description": "Use the service account kubernetes gives to pods to connect to kubernetes cluster.\nIt's intended for clients that expect to be running inside a pod running on kubernetes.\nIt will raise an exception if called from a process not running in a kubernetes environment.\n", + "version_added": None, + "type": "string", + "example": None, + "default": "True", + }, + "cluster_context": { + "description": "When running with in_cluster=False change the default cluster_context or config_file\noptions to Kubernetes client. Leave blank these to use default behaviour like ``kubectl`` has.\n", + "version_added": None, + "type": "string", + "example": None, + "default": None, + }, + "config_file": { + "description": "Path to the kubernetes configfile to be used when ``in_cluster`` is set to False\n", + "version_added": None, + "type": "string", + "example": None, + "default": None, + }, + "kube_client_request_args": { + "description": "Keyword parameters to pass while calling a kubernetes client core_v1_api methods\nfrom Kubernetes Executor provided as a single line formatted JSON dictionary string.\nList of supported params are similar for all core_v1_apis, hence a single config\nvariable for all apis. 
See:\nhttps://raw.githubusercontent.com/kubernetes-client/python/41f11a09995efcd0142e25946adc7591431bfb2f/kubernetes/client/api/core_v1_api.py\n", + "version_added": None, + "type": "string", + "example": None, + "default": "", + }, + "delete_option_kwargs": { + "description": "Optional keyword arguments to pass to the ``delete_namespaced_pod`` kubernetes client\n``core_v1_api`` method when using the Kubernetes Executor.\nThis should be an object and can contain any of the options listed in the ``v1DeleteOptions``\nclass defined here:\nhttps://github.com/kubernetes-client/python/blob/41f11a09995efcd0142e25946adc7591431bfb2f/kubernetes/client/models/v1_delete_options.py#L19\n", + "version_added": None, + "type": "string", + "example": '{"grace_period_seconds": 10}', + "default": "", + }, + "enable_tcp_keepalive": { + "description": "Enables TCP keepalive mechanism. This prevents Kubernetes API requests to hang indefinitely\nwhen idle connection is time-outed on services like cloud load balancers or firewalls.\n", + "version_added": None, + "type": "boolean", + "example": None, + "default": "True", + }, + "tcp_keep_idle": { + "description": "When the `enable_tcp_keepalive` option is enabled, TCP probes a connection that has\nbeen idle for `tcp_keep_idle` seconds.\n", + "version_added": None, + "type": "integer", + "example": None, + "default": "120", + }, + "tcp_keep_intvl": { + "description": "When the `enable_tcp_keepalive` option is enabled, if Kubernetes API does not respond\nto a keepalive probe, TCP retransmits the probe after `tcp_keep_intvl` seconds.\n", + "version_added": None, + "type": "integer", + "example": None, + "default": "30", + }, + "tcp_keep_cnt": { + "description": "When the `enable_tcp_keepalive` option is enabled, if Kubernetes API does not respond\nto a keepalive probe, TCP retransmits the probe `tcp_keep_cnt number` of times before\na connection is considered to be broken.\n", + "version_added": None, + "type": "integer", + "example": None, + "default": "6", + }, + "verify_ssl": { + "description": "Set this to false to skip verifying SSL certificate of Kubernetes python client.\n", + "version_added": None, + "type": "boolean", + "example": None, + "default": "True", + }, + "ssl_ca_cert": { + "description": "Path to a CA certificate to be used by the Kubernetes client to verify the server's SSL certificate.\n", + "version_added": None, + "type": "string", + "example": None, + "default": "", + }, + "task_publish_max_retries": { + "description": "The Maximum number of retries for queuing the task to the kubernetes scheduler when\nfailing due to Kube API exceeded quota errors before giving up and marking task as failed.\n-1 for unlimited times.\n", + "version_added": None, + "type": "integer", + "example": None, + "default": "0", + }, + }, + }, + }, + "executors": ["airflow.providers.cncf.kubernetes.kubernetes_executor.KubernetesExecutor"], + "dependencies": [ + "aiofiles>=23.2.0", + "apache-airflow>=2.9.0", + "asgiref>=3.5.2", + "cryptography>=41.0.0", + "kubernetes>=29.0.0,<=31.0.0", + "kubernetes_asyncio>=29.0.0,<=31.0.0", + "google-re2>=1.0", + ], + } diff --git a/providers/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/__init__.py diff --git 
a/providers/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 6a377be3eb27c..114a933e50e35 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -29,9 +29,6 @@ import requests import tenacity from asgiref.sync import sync_to_async -from kubernetes import client, config, utils, watch -from kubernetes.client.models import V1Deployment -from kubernetes.config import ConfigException from kubernetes_asyncio import client as async_client, config as async_config from urllib3.exceptions import HTTPError @@ -46,6 +43,9 @@ container_is_running, ) from airflow.utils import yaml +from kubernetes import client, config, utils, watch +from kubernetes.client.models import V1Deployment +from kubernetes.config import ConfigException if TYPE_CHECKING: from kubernetes.client import V1JobList diff --git a/providers/src/airflow/providers/cncf/kubernetes/k8s_model.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/k8s_model.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/k8s_model.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/k8s_model.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/kube_client.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kube_client.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/kube_client.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kube_client.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/kube_config.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kube_config.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/kube_config.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kube_config.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/basic_template.yaml b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/basic_template.yaml similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/basic_template.yaml rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_executor_templates/basic_template.yaml diff --git a/providers/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py 
index a353b7a5c8026..2732907886970 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py @@ -23,11 +23,11 @@ from typing import TYPE_CHECKING import pendulum -from kubernetes.client.rest import ApiException from slugify import slugify from airflow.configuration import conf from airflow.providers.cncf.kubernetes.backcompat import get_logical_date_key +from kubernetes.client.rest import ApiException if TYPE_CHECKING: from airflow.models.taskinstancekey import TaskInstanceKey diff --git a/providers/tests/apache/flink/sensors/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/__init__.py similarity index 100% rename from providers/tests/apache/flink/sensors/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py index 8e2edc260670b..33c204b8d3f19 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py @@ -24,8 +24,6 @@ from functools import cached_property import tenacity -from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s -from kubernetes.client.rest import ApiException from airflow.exceptions import AirflowException from airflow.providers.cncf.kubernetes.resource_convert.configmap import ( @@ -39,6 +37,8 @@ ) from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager from airflow.utils.log.logging_mixin import LoggingMixin +from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s +from kubernetes.client.rest import ApiException def should_retry_start_spark_job(exception: BaseException) -> bool: @@ -291,7 +291,7 @@ def start_spark_job(self, image=None, code_path=None, startup_timeout: int = 600 # Wait for the driver pod to come alive self.pod_spec = k8s.V1Pod( metadata=k8s.V1ObjectMeta( - labels=self.spark_obj_spec["spec"]["driver"]["labels"], + labels=self.spark_obj_spec["spec"]["driver"].get("labels"), name=self.spark_obj_spec["metadata"]["name"] + "-driver", namespace=self.namespace, ) diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/job.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/operators/job.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py index 36fa4b92a91be..eb6dfae43162a 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/job.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py @@ -26,10 +26,6 @@ from functools import cached_property from typing import TYPE_CHECKING -from kubernetes.client import BatchV1Api, models as k8s -from kubernetes.client.api_client import ApiClient -from kubernetes.client.rest import ApiException - from airflow.configuration import conf from airflow.exceptions import 
AirflowException from airflow.models import BaseOperator @@ -44,6 +40,9 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import EMPTY_XCOM_RESULT, PodNotFoundException from airflow.utils import yaml from airflow.utils.context import Context +from kubernetes.client import BatchV1Api, models as k8s +from kubernetes.client.api_client import ApiClient +from kubernetes.client.rest import ApiException if TYPE_CHECKING: from airflow.utils.context import Context @@ -167,6 +166,7 @@ def execute(self, context: Context): ti.xcom_push(key="job_name", value=self.job.metadata.name) ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace) + self.pod: k8s.V1Pod | None if self.pod is None: self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill` pod_request_obj=self.pod_request_obj, diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/kueue.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/kueue.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/operators/kueue.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/kueue.py index 64965b34a9526..aa6ebb32ad5c5 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/kueue.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/kueue.py @@ -22,12 +22,11 @@ from collections.abc import Sequence from functools import cached_property -from kubernetes.utils import FailToCreateError - from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator +from kubernetes.utils import FailToCreateError class KubernetesInstallKueueOperator(BaseOperator): @@ -95,6 +94,7 @@ def __init__(self, queue_name: str, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.queue_name = queue_name + self.suspend: bool if self.suspend is False: raise AirflowException( "The `suspend` parameter can't be False. 
If you want to use Kueue for running Job" diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/operators/pod.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py index 5d1bea11922e8..d4e2e76a66af6 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -32,13 +32,10 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, Literal -import kubernetes import tenacity -from kubernetes.client import CoreV1Api, V1Pod, models as k8s -from kubernetes.client.exceptions import ApiException -from kubernetes.stream import stream from urllib3.exceptions import HTTPError +import kubernetes from airflow.configuration import conf from airflow.exceptions import ( AirflowException, @@ -84,6 +81,9 @@ from airflow.utils import yaml from airflow.utils.helpers import prune_dict, validate_key from airflow.version import version as airflow_version +from kubernetes.client import CoreV1Api, V1Pod, models as k8s +from kubernetes.client.exceptions import ApiException +from kubernetes.stream import stream if TYPE_CHECKING: import jinja2 @@ -489,7 +489,7 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool } map_index = ti.map_index - if map_index >= 0: + if map_index is not None and map_index >= 0: labels["map_index"] = str(map_index) if include_try_number: diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/resource.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/operators/resource.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py index aef972faf26e8..ca2767d48e520 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/resource.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py @@ -25,7 +25,6 @@ import tenacity import yaml -from kubernetes.utils import create_from_yaml from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -33,6 +32,7 @@ from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import should_retry_creation from airflow.providers.cncf.kubernetes.utils.delete_from import delete_from_yaml from airflow.providers.cncf.kubernetes.utils.k8s_resource_iterator import k8s_resource_iterator +from kubernetes.utils import create_from_yaml if TYPE_CHECKING: from kubernetes.client import ApiClient, CustomObjectsApi diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py similarity index 94% rename from providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index c0b90ebacb9a6..d90085e68c44b 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -19,9 +19,7 @@ from functools import cached_property from pathlib import Path -from typing 
import TYPE_CHECKING, Any - -from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s +from typing import TYPE_CHECKING, Any, cast from airflow.exceptions import AirflowException from airflow.providers.cncf.kubernetes import pod_generator @@ -32,6 +30,7 @@ from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager from airflow.utils.helpers import prune_dict +from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s if TYPE_CHECKING: import jinja2 @@ -117,6 +116,10 @@ def __init__( self.success_run_history_limit = success_run_history_limit self.random_name_suffix = random_name_suffix + # fix mypy typing + self.base_container_name: str + self.container_logs: list[str] + if self.base_container_name != self.BASE_CONTAINER_NAME: self.log.warning( "base_container_name is not supported and will be overridden to %s", self.BASE_CONTAINER_NAME @@ -177,12 +180,7 @@ def create_job_name(self): return self._set_name(updated_name) @staticmethod - def _get_pod_identifying_label_string(labels) -> str: - filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != "try_number"} - return ",".join([label_id + "=" + label for label_id, label in sorted(filtered_labels.items())]) - - @staticmethod - def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict: + def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]: """ Generate labels for the pod to track the pod in case of Operator crash. @@ -193,8 +191,9 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool if not context: return {} - ti = context["ti"] - run_id = context["run_id"] + context_dict = cast(dict, context) + ti = context_dict["ti"] + run_id = context_dict["run_id"] labels = { "dag_id": ti.dag_id, @@ -213,8 +212,8 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool # In the case of sub dags this is just useful # TODO: Remove this when the minimum version of Airflow is bumped to 3.0 - if getattr(context["dag"], "is_subdag", False): - labels["parent_dag_id"] = context["dag"].parent_dag.dag_id + if getattr(context_dict["dag"], "is_subdag", False): + labels["parent_dag_id"] = context_dict["dag"].parent_dag.dag_id # Ensure that label is valid for Kube, # and if not truncate/remove invalid chars and replace with short hash. 
for label_id, label in labels.items(): @@ -235,9 +234,11 @@ def template_body(self): """Templated body for CustomObjectLauncher.""" return self.manage_template_specs() - def find_spark_job(self, context): - labels = self.create_labels_for_pod(context, include_try_number=False) - label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver" + def find_spark_job(self, context, exclude_checked: bool = True): + label_selector = ( + self._build_find_pod_label_selector(context, exclude_checked=exclude_checked) + + ",spark-role=driver" + ) pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items pod = None diff --git a/providers/src/airflow/providers/cncf/kubernetes/pod_generator.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/pod_generator.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py index b90fa715333bf..5fab194963e4c 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/pod_generator.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py @@ -34,8 +34,6 @@ import re2 from dateutil import parser -from kubernetes.client import models as k8s -from kubernetes.client.api_client import ApiClient from airflow.exceptions import ( AirflowConfigException, @@ -49,6 +47,8 @@ from airflow.utils import yaml from airflow.utils.hashlib_wrapper import md5 from airflow.version import version as airflow_version +from kubernetes.client import models as k8s +from kubernetes.client.api_client import ApiClient if TYPE_CHECKING: import datetime diff --git a/providers/src/airflow/providers/cncf/kubernetes/resource_convert/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/resource_convert/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_image_template.yaml b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_image_template.yaml similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_image_template.yaml rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_image_template.yaml diff --git a/providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_volume_template.yaml b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_volume_template.yaml similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_volume_template.yaml rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/dags_in_volume_template.yaml diff --git a/providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/git_sync_template.yaml b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/git_sync_template.yaml similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/git_sync_template.yaml rename to 
providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_template_file_examples/git_sync_template.yaml diff --git a/providers/src/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 diff --git a/providers/src/airflow/providers/cncf/kubernetes/python_kubernetes_script.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/python_kubernetes_script.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/python_kubernetes_script.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/python_kubernetes_script.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/sensors/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/sensors/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/resource_convert/configmap.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/configmap.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/resource_convert/configmap.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/configmap.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py index db8c5301cb051..d950acf5ac42b 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py @@ -16,9 +16,8 @@ # under the License. 
from __future__ import annotations -from kubernetes.client import models as k8s - from airflow.exceptions import AirflowException +from kubernetes.client import models as k8s def convert_env_vars(env_vars) -> list[k8s.V1EnvVar]: diff --git a/providers/src/airflow/providers/cncf/kubernetes/resource_convert/secret.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/secret.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/resource_convert/secret.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/secret.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/secret.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/secret.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/secret.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/secret.py index 692777de447cd..2fc6c87757313 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/secret.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/secret.py @@ -21,10 +21,9 @@ import copy import uuid -from kubernetes.client import models as k8s - from airflow.exceptions import AirflowConfigException from airflow.providers.cncf.kubernetes.k8s_model import K8SModel +from kubernetes.client import models as k8s class Secret(K8SModel): diff --git a/providers/src/airflow/providers/cncf/kubernetes/triggers/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/sensors/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/triggers/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/sensors/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py index 08968d245fde7..7a42c21d7184c 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py @@ -21,11 +21,10 @@ from functools import cached_property from typing import TYPE_CHECKING -from kubernetes import client - from airflow.exceptions import AirflowException from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.sensors.base import BaseSensorOperator +from kubernetes import client if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/src/airflow/providers/cncf/kubernetes/template_rendering.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/template_rendering.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py index c499dad24038c..6abeeddde7884 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/template_rendering.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py @@ -20,13 +20,13 @@ from typing import TYPE_CHECKING from jinja2 import TemplateAssertionError, UndefinedError -from kubernetes.client.api_client import ApiClient from airflow.exceptions import 
AirflowException from airflow.providers.cncf.kubernetes.kube_config import KubeConfig from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_unique_id from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator from airflow.utils.session import NEW_SESSION, provide_session +from kubernetes.client.api_client import ApiClient if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance diff --git a/providers/src/airflow/providers/yandex/hooks/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/__init__.py similarity index 100% rename from providers/src/airflow/providers/yandex/hooks/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/triggers/job.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/triggers/job.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/triggers/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/triggers/pod.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/utils/__init__.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/__init__.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/utils/__init__.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/__init__.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/utils/delete_from.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/delete_from.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/utils/delete_from.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/delete_from.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/utils/k8s_resource_iterator.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/k8s_resource_iterator.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/utils/k8s_resource_iterator.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/k8s_resource_iterator.py index d66eb9fc80bc0..6433ae465a40b 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/utils/k8s_resource_iterator.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/k8s_resource_iterator.py @@ -19,9 +19,8 @@ from collections.abc import Iterator from typing import Callable -from kubernetes.utils import FailToCreateError - from airflow.providers.cncf.kubernetes.utils.delete_from import FailToDeleteError +from kubernetes.utils import FailToCreateError def k8s_resource_iterator(callback: Callable[[dict], None], resources: Iterator) -> None: diff --git a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py similarity index 99% rename from providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py index 
199d6a6d35dd3..9ea5228dac040 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -31,9 +31,6 @@ import pendulum import tenacity -from kubernetes import client, watch -from kubernetes.client.rest import ApiException -from kubernetes.stream import stream as kubernetes_stream from pendulum import DateTime from pendulum.parsing.exceptions import ParserError from typing_extensions import Literal @@ -44,12 +41,16 @@ from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timezone import utcnow +from kubernetes import client, watch +from kubernetes.client.rest import ApiException +from kubernetes.stream import stream as kubernetes_stream if TYPE_CHECKING: + from urllib3.response import HTTPResponse + from kubernetes.client.models.core_v1_event_list import CoreV1EventList from kubernetes.client.models.v1_container_status import V1ContainerStatus from kubernetes.client.models.v1_pod import V1Pod - from urllib3.response import HTTPResponse EMPTY_XCOM_RESULT = "__airflow_xcom_result_empty__" diff --git a/providers/src/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py diff --git a/providers/src/airflow/providers/cncf/kubernetes/version_compat.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py similarity index 100% rename from providers/src/airflow/providers/cncf/kubernetes/version_compat.py rename to providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py diff --git a/providers/cncf/kubernetes/tests/conftest.py b/providers/cncf/kubernetes/tests/conftest.py new file mode 100644 index 0000000000000..068fe6bbf5ae9 --- /dev/null +++ b/providers/cncf/kubernetes/tests/conftest.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pathlib + +import pytest + +pytest_plugins = "tests_common.pytest_plugin" + + +@pytest.hookimpl(tryfirst=True) +def pytest_configure(config: pytest.Config) -> None: + deprecations_ignore_path = pathlib.Path(__file__).parent.joinpath("deprecations_ignore.yml") + dep_path = [deprecations_ignore_path] if deprecations_ignore_path.exists() else [] + config.inicfg["airflow_deprecations_ignore"] = ( + config.inicfg.get("airflow_deprecations_ignore", []) + dep_path # type: ignore[assignment,operator] + ) diff --git a/providers/cncf/kubernetes/tests/provider_tests/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/__init__.py new file mode 100644 index 0000000000000..e8fd22856438c --- /dev/null +++ b/providers/cncf/kubernetes/tests/provider_tests/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/cncf/kubernetes/tests/provider_tests/cncf/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/__init__.py new file mode 100644 index 0000000000000..e8fd22856438c --- /dev/null +++ b/providers/cncf/kubernetes/tests/provider_tests/cncf/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/tests/cloudant/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/__init__.py similarity index 100% rename from providers/tests/cloudant/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/__init__.py diff --git a/providers/src/airflow/providers/yandex/links/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/backcompat/__init__.py similarity index 100% rename from providers/src/airflow/providers/yandex/links/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/backcompat/__init__.py diff --git a/providers/tests/cncf/kubernetes/backcompat/test_backwards_compat_converters.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/backcompat/test_backwards_compat_converters.py similarity index 100% rename from providers/tests/cncf/kubernetes/backcompat/test_backwards_compat_converters.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/backcompat/test_backwards_compat_converters.py diff --git a/providers/src/airflow/providers/yandex/operators/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/cli/__init__.py similarity index 100% rename from providers/src/airflow/providers/yandex/operators/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/cli/__init__.py diff --git a/providers/tests/cncf/kubernetes/cli/test_kubernetes_command.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/cli/test_kubernetes_command.py similarity index 100% rename from providers/tests/cncf/kubernetes/cli/test_kubernetes_command.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/cli/test_kubernetes_command.py diff --git a/providers/tests/cncf/kubernetes/conftest.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/conftest.py similarity index 100% rename from providers/tests/cncf/kubernetes/conftest.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/conftest.py diff --git a/providers/src/airflow/providers/yandex/secrets/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/__init__.py similarity index 100% rename from providers/src/airflow/providers/yandex/secrets/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/__init__.py diff --git a/providers/src/airflow/providers/yandex/utils/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/executor/__init__.py similarity index 100% rename from providers/src/airflow/providers/yandex/utils/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/executor/__init__.py diff --git a/providers/tests/cncf/kubernetes/data_files/executor/basic_template.yaml b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/executor/basic_template.yaml similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/executor/basic_template.yaml rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/executor/basic_template.yaml diff --git a/providers/tests/cncf/kubernetes/data_files/kube_config b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/kube_config similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/kube_config rename to 
providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/kube_config diff --git a/providers/tests/amazon/aws/auth_manager/views/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/pods/__init__.py similarity index 100% rename from providers/tests/amazon/aws/auth_manager/views/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/pods/__init__.py diff --git a/providers/tests/cncf/kubernetes/data_files/pods/generator_base.yaml b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/pods/generator_base.yaml similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/pods/generator_base.yaml rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/pods/generator_base.yaml diff --git a/providers/tests/cncf/kubernetes/data_files/pods/generator_base_with_secrets.yaml b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/pods/generator_base_with_secrets.yaml similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/pods/generator_base_with_secrets.yaml rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/pods/generator_base_with_secrets.yaml diff --git a/providers/tests/cncf/kubernetes/data_files/pods/template.yaml b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/pods/template.yaml similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/pods/template.yaml rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/pods/template.yaml diff --git a/providers/tests/apache/impala/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/__init__.py similarity index 100% rename from providers/tests/apache/impala/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/__init__.py diff --git a/providers/tests/cncf/kubernetes/data_files/spark/application_template.yaml b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_template.yaml similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/spark/application_template.yaml rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_template.yaml diff --git a/providers/tests/cncf/kubernetes/data_files/spark/application_test.json b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_test.json similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/spark/application_test.json rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_test.json diff --git a/providers/tests/cncf/kubernetes/data_files/spark/application_test.yaml b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_test.yaml similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/spark/application_test.yaml rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_test.yaml diff --git a/providers/tests/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.json b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.json similarity index 100% rename from 
providers/tests/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.json rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.json diff --git a/providers/tests/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.yaml b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.yaml similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.yaml rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.yaml diff --git a/providers/tests/apache/impala/hooks/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/decorators/__init__.py similarity index 100% rename from providers/tests/apache/impala/hooks/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/decorators/__init__.py diff --git a/providers/tests/cncf/kubernetes/decorators/test_kubernetes.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/decorators/test_kubernetes.py similarity index 100% rename from providers/tests/cncf/kubernetes/decorators/test_kubernetes.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/decorators/test_kubernetes.py diff --git a/providers/tests/cncf/kubernetes/backcompat/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/executors/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/backcompat/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/executors/__init__.py diff --git a/providers/tests/cncf/kubernetes/executors/test_kubernetes_executor.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/executors/test_kubernetes_executor.py similarity index 100% rename from providers/tests/cncf/kubernetes/executors/test_kubernetes_executor.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/executors/test_kubernetes_executor.py diff --git a/providers/tests/cncf/kubernetes/executors/test_local_kubernetes_executor.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/executors/test_local_kubernetes_executor.py similarity index 100% rename from providers/tests/cncf/kubernetes/executors/test_local_kubernetes_executor.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/executors/test_local_kubernetes_executor.py diff --git a/providers/tests/cncf/kubernetes/cli/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/hooks/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/cli/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/hooks/__init__.py diff --git a/providers/tests/cncf/kubernetes/hooks/test_kubernetes.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/hooks/test_kubernetes.py similarity index 100% rename from providers/tests/cncf/kubernetes/hooks/test_kubernetes.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/hooks/test_kubernetes.py diff --git a/providers/tests/cncf/kubernetes/data_files/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/log_handlers/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/__init__.py rename to 
providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/log_handlers/__init__.py diff --git a/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/log_handlers/test_log_handlers.py similarity index 100% rename from providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/log_handlers/test_log_handlers.py diff --git a/providers/tests/cloudant/hooks/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/models/__init__.py similarity index 100% rename from providers/tests/cloudant/hooks/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/models/__init__.py diff --git a/providers/tests/cncf/kubernetes/models/test_secret.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/models/test_secret.py similarity index 100% rename from providers/tests/cncf/kubernetes/models/test_secret.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/models/test_secret.py diff --git a/providers/tests/cncf/kubernetes/data_files/executor/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/executor/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/__init__.py diff --git a/providers/tests/cncf/kubernetes/operators/test_custom_object_launcher.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_custom_object_launcher.py similarity index 91% rename from providers/tests/cncf/kubernetes/operators/test_custom_object_launcher.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_custom_object_launcher.py index 244fcf6fd2757..eab85e5a87339 100644 --- a/providers/tests/cncf/kubernetes/operators/test_custom_object_launcher.py +++ b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_custom_object_launcher.py @@ -20,6 +20,7 @@ import pytest from kubernetes.client import ( + CustomObjectsApi, V1ContainerState, V1ContainerStateWaiting, V1ContainerStatus, @@ -37,24 +38,33 @@ @pytest.fixture def mock_launcher(): + name = "test-spark-job" + spec = { + "image": "gcr.io/spark-operator/spark-py:v3.0.0", + "driver": {}, + "executor": {}, + } + + custom_object_api = CustomObjectsApi() + custom_object_api.create_namespaced_custom_object = MagicMock( + return_value={"spec": spec, "metadata": {"name": name}} + ) + launcher = CustomObjectLauncher( - name="test-spark-job", + name=name, namespace="default", kube_client=MagicMock(), - custom_obj_api=MagicMock(), + custom_obj_api=custom_object_api, template_body={ "spark": { - "spec": { - "image": "gcr.io/spark-operator/spark-py:v3.0.0", - "driver": {}, - "executor": {}, - }, + "spec": spec, "apiVersion": "sparkoperator.k8s.io/v1beta2", "kind": "SparkApplication", }, }, ) launcher.pod_spec = V1Pod() + launcher.spark_job_not_running = MagicMock(return_value=False) return launcher @@ -203,6 +213,10 @@ def get_pod_status(self, reason: str, message: str | None = None): ] ) + @patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.PodManager") + def test_start_spark_job_no_error(self, mock_pod_manager, mock_launcher): + mock_launcher.start_spark_job() + @patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.PodManager") def 
test_check_pod_start_failure_no_error(self, mock_pod_manager, mock_launcher): mock_pod_manager.return_value.read_pod.return_value.status = self.get_pod_status("ContainerCreating") diff --git a/providers/tests/cncf/kubernetes/operators/test_job.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_job.py similarity index 100% rename from providers/tests/cncf/kubernetes/operators/test_job.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_job.py diff --git a/providers/tests/cncf/kubernetes/operators/test_kueue.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_kueue.py similarity index 100% rename from providers/tests/cncf/kubernetes/operators/test_kueue.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_kueue.py diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_pod.py similarity index 99% rename from providers/tests/cncf/kubernetes/operators/test_pod.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_pod.py index 0aacc4b782fd0..309d2abe642e1 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_pod.py @@ -1543,8 +1543,7 @@ def test_get_logs_but_not_for_base_container( @patch(KUB_OP_PATH.format("find_pod")) def test_execute_sync_callbacks(self, find_pod_mock): from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode - - from providers.tests.cncf.kubernetes.test_callbacks import ( + from provider_tests.cncf.kubernetes.test_callbacks import ( MockKubernetesPodOperatorCallback, MockWrapper, ) @@ -1630,8 +1629,7 @@ def test_execute_sync_callbacks(self, find_pod_mock): @patch(KUB_OP_PATH.format("find_pod")) def test_execute_sync_multiple_callbacks(self, find_pod_mock): from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode - - from providers.tests.cncf.kubernetes.test_callbacks import ( + from provider_tests.cncf.kubernetes.test_callbacks import ( MockKubernetesPodOperatorCallback, MockWrapper, ) @@ -1716,8 +1714,7 @@ def test_execute_sync_multiple_callbacks(self, find_pod_mock): @patch(HOOK_CLASS, new=MagicMock) def test_execute_async_callbacks(self): from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode - - from providers.tests.cncf.kubernetes.test_callbacks import ( + from provider_tests.cncf.kubernetes.test_callbacks import ( MockKubernetesPodOperatorCallback, MockWrapper, ) diff --git a/providers/tests/cncf/kubernetes/operators/test_resource.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_resource.py similarity index 100% rename from providers/tests/cncf/kubernetes/operators/test_resource.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_resource.py diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_spark_kubernetes.py similarity index 96% rename from providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_spark_kubernetes.py index e1e7e85bcdab6..c6814222c4775 100644 --- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py +++ 
b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -701,6 +701,35 @@ def test_get_logs_from_driver( follow_logs=True, ) + def test_find_custom_pod_labels( + self, + mock_create_namespaced_crd, + mock_get_namespaced_custom_object_status, + mock_cleanup, + mock_create_job_name, + mock_get_kube_client, + mock_create_pod, + mock_await_pod_start, + mock_await_pod_completion, + mock_fetch_requested_container_logs, + data_file, + ): + task_name = "test_find_custom_pod_labels" + job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text()) + + mock_create_job_name.return_value = task_name + op = SparkKubernetesOperator( + template_spec=job_spec, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id=task_name, + get_logs=True, + ) + context = create_context(op) + op.execute(context) + label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver" + op.find_spark_job(context) + mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector) + @pytest.mark.db_test def test_template_body_templating(create_task_instance_of_operator, session): diff --git a/providers/tests/cncf/kubernetes/data_files/pods/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/resource_convert/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/pods/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/resource_convert/__init__.py diff --git a/providers/tests/cncf/kubernetes/resource_convert/test_configmap.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/resource_convert/test_configmap.py similarity index 100% rename from providers/tests/cncf/kubernetes/resource_convert/test_configmap.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/resource_convert/test_configmap.py diff --git a/providers/tests/cncf/kubernetes/resource_convert/test_env_variable.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/resource_convert/test_env_variable.py similarity index 100% rename from providers/tests/cncf/kubernetes/resource_convert/test_env_variable.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/resource_convert/test_env_variable.py diff --git a/providers/tests/cncf/kubernetes/resource_convert/test_secret.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/resource_convert/test_secret.py similarity index 100% rename from providers/tests/cncf/kubernetes/resource_convert/test_secret.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/resource_convert/test_secret.py diff --git a/providers/tests/cncf/kubernetes/data_files/spark/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/sensors/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/data_files/spark/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/sensors/__init__.py diff --git a/providers/tests/cncf/kubernetes/sensors/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/sensors/test_spark_kubernetes.py similarity index 100% rename from providers/tests/cncf/kubernetes/sensors/test_spark_kubernetes.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/sensors/test_spark_kubernetes.py diff --git a/providers/tests/cncf/kubernetes/test_callbacks.py 
b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_callbacks.py similarity index 100% rename from providers/tests/cncf/kubernetes/test_callbacks.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_callbacks.py diff --git a/providers/tests/cncf/kubernetes/test_client.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_client.py similarity index 100% rename from providers/tests/cncf/kubernetes/test_client.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_client.py diff --git a/providers/tests/cncf/kubernetes/test_kubernetes_helper_functions.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_kubernetes_helper_functions.py similarity index 100% rename from providers/tests/cncf/kubernetes/test_kubernetes_helper_functions.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_kubernetes_helper_functions.py diff --git a/providers/tests/cncf/kubernetes/test_pod_generator.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_pod_generator.py similarity index 100% rename from providers/tests/cncf/kubernetes/test_pod_generator.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_pod_generator.py diff --git a/providers/tests/cncf/kubernetes/test_template_rendering.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_template_rendering.py similarity index 73% rename from providers/tests/cncf/kubernetes/test_template_rendering.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_template_rendering.py index 7c587de7b9280..6f512cdfe805e 100644 --- a/providers/tests/cncf/kubernetes/test_template_rendering.py +++ b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_template_rendering.py @@ -118,45 +118,52 @@ def test_get_rendered_k8s_spec(render_k8s_pod_yaml, rtif_get_k8s_pod_yaml, creat @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"}) -@mock.patch("airflow.utils.log.secrets_masker.redact", autospec=True, side_effect=lambda d, _=None: d) @mock.patch("airflow.providers.cncf.kubernetes.template_rendering.render_k8s_pod_yaml") -def test_get_k8s_pod_yaml(render_k8s_pod_yaml, redact, dag_maker, session): +def test_get_k8s_pod_yaml(render_k8s_pod_yaml, dag_maker, session): """ Test that k8s_pod_yaml is rendered correctly, stored in the Database, and are correctly fetched using RTIF.get_k8s_pod_yaml """ - with dag_maker("test_get_k8s_pod_yaml") as dag: - task = BashOperator(task_id="test", bash_command="echo hi") - dr = dag_maker.create_dagrun() - dag.fileloc = "/test_get_k8s_pod_yaml.py" + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - ti = dr.task_instances[0] - ti.task = task + target = ( + "airflow.sdk.execution_time.secrets_masker.redact" + if AIRFLOW_V_3_0_PLUS + else "airflow.utils.log.secrets_masker.redact" + ) + with mock.patch(target, autospec=True, side_effect=lambda d, _=None: d) as redact: + with dag_maker("test_get_k8s_pod_yaml") as dag: + task = BashOperator(task_id="test", bash_command="echo hi") + dr = dag_maker.create_dagrun() + dag.fileloc = "/test_get_k8s_pod_yaml.py" + + ti = dr.task_instances[0] + ti.task = task - render_k8s_pod_yaml.return_value = {"I'm a": "pod"} + render_k8s_pod_yaml.return_value = {"I'm a": "pod"} - rtif = RTIF(ti=ti) + rtif = RTIF(ti=ti) - assert ti.dag_id == rtif.dag_id - assert ti.task_id == rtif.task_id - assert ti.run_id == rtif.run_id + assert ti.dag_id == 
rtif.dag_id + assert ti.task_id == rtif.task_id + assert ti.run_id == rtif.run_id - expected_pod_yaml = {"I'm a": "pod"} + expected_pod_yaml = {"I'm a": "pod"} - assert rtif.k8s_pod_yaml == render_k8s_pod_yaml.return_value - # K8s pod spec dict was passed to redact - redact.assert_any_call(rtif.k8s_pod_yaml) + assert rtif.k8s_pod_yaml == render_k8s_pod_yaml.return_value + # K8s pod spec dict was passed to redact + redact.assert_any_call(rtif.k8s_pod_yaml) - with create_session() as session: - session.add(rtif) - session.flush() + with create_session() as session: + session.add(rtif) + session.flush() - assert expected_pod_yaml == RTIF.get_k8s_pod_yaml(ti=ti, session=session) - make_transient(ti) - # "Delete" it from the DB - session.rollback() + assert expected_pod_yaml == RTIF.get_k8s_pod_yaml(ti=ti, session=session) + make_transient(ti) + # "Delete" it from the DB + session.rollback() - # Test the else part of get_k8s_pod_yaml - # i.e. for the TIs that are not stored in RTIF table - # Fetching them will return None - assert RTIF.get_k8s_pod_yaml(ti=ti, session=session) is None + # Test the else part of get_k8s_pod_yaml + # i.e. for the TIs that are not stored in RTIF table + # Fetching them will return None + assert RTIF.get_k8s_pod_yaml(ti=ti, session=session) is None diff --git a/providers/tests/cncf/kubernetes/decorators/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/triggers/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/decorators/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/triggers/__init__.py diff --git a/providers/tests/cncf/kubernetes/triggers/test_job.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/triggers/test_job.py similarity index 100% rename from providers/tests/cncf/kubernetes/triggers/test_job.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/triggers/test_job.py diff --git a/providers/tests/cncf/kubernetes/triggers/test_pod.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/triggers/test_pod.py similarity index 100% rename from providers/tests/cncf/kubernetes/triggers/test_pod.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/triggers/test_pod.py diff --git a/providers/tests/cncf/kubernetes/executors/__init__.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/executors/__init__.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/__init__.py diff --git a/providers/tests/cncf/kubernetes/utils/test_k8s_resource_iterator.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/test_k8s_resource_iterator.py similarity index 100% rename from providers/tests/cncf/kubernetes/utils/test_k8s_resource_iterator.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/test_k8s_resource_iterator.py diff --git a/providers/tests/cncf/kubernetes/utils/test_pod_manager.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/test_pod_manager.py similarity index 99% rename from providers/tests/cncf/kubernetes/utils/test_pod_manager.py rename to providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/test_pod_manager.py index 24ea794c2bffd..bc11e84dfb582 100644 --- a/providers/tests/cncf/kubernetes/utils/test_pod_manager.py +++ 
b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/test_pod_manager.py @@ -40,8 +40,7 @@ container_is_terminated, ) from airflow.utils.timezone import utc - -from providers.tests.cncf.kubernetes.test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper +from provider_tests.cncf.kubernetes.test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper if TYPE_CHECKING: from pendulum import DateTime diff --git a/providers/tests/cncf/kubernetes/hooks/__init__.py b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/hooks/__init__.py rename to providers/cncf/kubernetes/tests/system/cncf/kubernetes/__init__.py diff --git a/providers/tests/system/cncf/kubernetes/example_kubernetes.py b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes.py similarity index 99% rename from providers/tests/system/cncf/kubernetes/example_kubernetes.py rename to providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes.py index dc74895a3e0fd..08a85a4a71da6 100644 --- a/providers/tests/system/cncf/kubernetes/example_kubernetes.py +++ b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes.py @@ -24,12 +24,11 @@ import os from datetime import datetime -from kubernetes.client import models as k8s - from airflow import DAG from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.secret import Secret from airflow.providers.standard.operators.bash import BashOperator +from kubernetes.client import models as k8s # [START howto_operator_k8s_cluster_resources] secret_file = Secret("volume", "/etc/sql_conn", "airflow-secrets", "sql_alchemy_conn") diff --git a/providers/tests/system/cncf/kubernetes/example_kubernetes_async.py b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_async.py similarity index 99% rename from providers/tests/system/cncf/kubernetes/example_kubernetes_async.py rename to providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_async.py index 7eb08442be3ba..2b8b7387e09f5 100644 --- a/providers/tests/system/cncf/kubernetes/example_kubernetes_async.py +++ b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_async.py @@ -24,12 +24,11 @@ import os from datetime import datetime -from kubernetes.client import models as k8s - from airflow import DAG from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.secret import Secret from airflow.providers.standard.operators.bash import BashOperator +from kubernetes.client import models as k8s # [START howto_operator_k8s_cluster_resources] secret_file = Secret("volume", "/etc/sql_conn", "airflow-secrets", "sql_alchemy_conn") diff --git a/providers/tests/system/cncf/kubernetes/example_kubernetes_decorator.py b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_decorator.py similarity index 100% rename from providers/tests/system/cncf/kubernetes/example_kubernetes_decorator.py rename to providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_decorator.py diff --git a/providers/tests/system/cncf/kubernetes/example_kubernetes_job.py b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_job.py similarity index 100% rename from providers/tests/system/cncf/kubernetes/example_kubernetes_job.py rename to 
providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_job.py diff --git a/providers/tests/system/cncf/kubernetes/example_kubernetes_kueue.py b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_kueue.py similarity index 99% rename from providers/tests/system/cncf/kubernetes/example_kubernetes_kueue.py rename to providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_kueue.py index 9aac715fdbed9..063be244b7aa7 100644 --- a/providers/tests/system/cncf/kubernetes/example_kubernetes_kueue.py +++ b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_kueue.py @@ -24,14 +24,13 @@ import os from datetime import datetime -from kubernetes.client import models as k8s - from airflow import DAG from airflow.providers.cncf.kubernetes.operators.kueue import ( KubernetesInstallKueueOperator, KubernetesStartKueueJobOperator, ) from airflow.providers.cncf.kubernetes.operators.resource import KubernetesCreateResourceOperator +from kubernetes.client import models as k8s ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") DAG_ID = "example_kubernetes_kueue_operators" diff --git a/providers/tests/system/cncf/kubernetes/example_kubernetes_resource.py b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_resource.py similarity index 100% rename from providers/tests/system/cncf/kubernetes/example_kubernetes_resource.py rename to providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_resource.py diff --git a/providers/tests/system/cncf/kubernetes/example_spark_kubernetes.py b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_spark_kubernetes.py similarity index 100% rename from providers/tests/system/cncf/kubernetes/example_spark_kubernetes.py rename to providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_spark_kubernetes.py diff --git a/providers/tests/system/cncf/kubernetes/example_spark_kubernetes_spark_pi.yaml b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_spark_kubernetes_spark_pi.yaml similarity index 100% rename from providers/tests/system/cncf/kubernetes/example_spark_kubernetes_spark_pi.yaml rename to providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_spark_kubernetes_spark_pi.yaml diff --git a/providers/tests/system/cncf/kubernetes/spark_job_template.yaml b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/spark_job_template.yaml similarity index 100% rename from providers/tests/system/cncf/kubernetes/spark_job_template.yaml rename to providers/cncf/kubernetes/tests/system/cncf/kubernetes/spark_job_template.yaml diff --git a/providers/edge/README.rst b/providers/edge/README.rst index 4b1f8d835b853..625ad0a865fec 100644 --- a/providers/edge/README.rst +++ b/providers/edge/README.rst @@ -24,7 +24,7 @@ Package ``apache-airflow-providers-edge`` -Release: ``0.12.0pre0`` +Release: ``0.13.1pre0`` Handle edge workers on remote sites via HTTP(s) connection and orchestrates work over distributed sites @@ -37,7 +37,7 @@ This is a provider package for ``edge`` provider. All classes for this provider are in ``airflow.providers.edge`` python package. You can find package information and changelog for the provider -in the `documentation `_. +in the `documentation `_. 
Installation ------------ @@ -51,13 +51,13 @@ The package supports the following python versions: 3.9,3.10,3.11,3.12 Requirements ------------ -================== ================== +================== =================== PIP package Version required -================== ================== +================== =================== ``apache-airflow`` ``>=2.10.0`` ``pydantic`` ``>=2.10.2`` -``retryhttp`` ``>=1.2.0`` -================== ================== +``retryhttp`` ``>=1.2.0,!=1.3.0`` +================== =================== The changelog for the provider package can be found in the -`changelog `_. +`changelog `_. diff --git a/providers/edge/docs/changelog.rst b/providers/edge/docs/changelog.rst index adfbcee0a89a6..393239fb395f8 100644 --- a/providers/edge/docs/changelog.rst +++ b/providers/edge/docs/changelog.rst @@ -27,6 +27,27 @@ Changelog --------- +0.13.1pre0 +.......... + +Fix +~~~ + +* ``EdgeWorkerVersionException is raised if http 400 is responded on set_state.`` + +0.13.0pre0 +.......... + +Misc +~~~~ + +* ``Allow removing an Edge worker that is offline.`` + +Fixes +~~~~~ + +* ``Implement proper CSRF protection on plugin form.`` + 0.12.0pre0 .......... diff --git a/providers/edge/provider.yaml b/providers/edge/provider.yaml index 85a672ac2f9e9..7bb5ce88ea967 100644 --- a/providers/edge/provider.yaml +++ b/providers/edge/provider.yaml @@ -25,7 +25,7 @@ source-date-epoch: 1737371680 # note that those versions are maintained by release manager - do not update them manually versions: - - 0.12.0pre0 + - 0.13.1pre0 plugins: - name: edge_executor diff --git a/providers/edge/pyproject.toml b/providers/edge/pyproject.toml index 5769283f3392b..33a5ef3ecb921 100644 --- a/providers/edge/pyproject.toml +++ b/providers/edge/pyproject.toml @@ -25,7 +25,7 @@ build-backend = "flit_core.buildapi" [project] name = "apache-airflow-providers-edge" -version = "0.12.0pre0" +version = "0.13.1pre0" description = "Provider package apache-airflow-providers-edge for Apache Airflow" readme = "README.rst" authors = [ @@ -57,12 +57,12 @@ requires-python = "~=3.9" dependencies = [ "apache-airflow>=2.10.0", "pydantic>=2.10.2", - "retryhttp>=1.2.0", + "retryhttp>=1.2.0,!=1.3.0", ] [project.urls] -"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.12.0pre0" -"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.12.0pre0/changelog.html" +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.13.1pre0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.13.1pre0/changelog.html" "Bug Tracker" = "https://github.com/apache/airflow/issues" "Source Code" = "https://github.com/apache/airflow" "Slack Chat" = "https://s.apache.org/airflow-slack" diff --git a/providers/edge/src/airflow/providers/edge/__init__.py b/providers/edge/src/airflow/providers/edge/__init__.py index d4005adbb662a..749d124f95886 100644 --- a/providers/edge/src/airflow/providers/edge/__init__.py +++ b/providers/edge/src/airflow/providers/edge/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "0.12.0pre0" +__version__ = "0.13.1pre0" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.10.0" diff --git a/providers/edge/src/airflow/providers/edge/cli/api_client.py b/providers/edge/src/airflow/providers/edge/cli/api_client.py index 60c230dada1ca..c19504787d75c 100644 --- a/providers/edge/src/airflow/providers/edge/cli/api_client.py +++ 
b/providers/edge/src/airflow/providers/edge/cli/api_client.py @@ -30,6 +30,7 @@ from tenacity import before_log, wait_random_exponential from airflow.configuration import conf +from airflow.providers.edge.models.edge_worker import EdgeWorkerVersionException from airflow.providers.edge.worker_api.auth import jwt_signer from airflow.providers.edge.worker_api.datamodels import ( EdgeJobFetched, @@ -92,13 +93,18 @@ def worker_register( hostname: str, state: EdgeWorkerState, queues: list[str] | None, sysinfo: dict ) -> datetime: """Register worker with the Edge API.""" - result = _make_generic_request( - "POST", - f"worker/{quote(hostname)}", - WorkerStateBody(state=state, jobs_active=0, queues=queues, sysinfo=sysinfo).model_dump_json( - exclude_unset=True - ), - ) + try: + result = _make_generic_request( + "POST", + f"worker/{quote(hostname)}", + WorkerStateBody(state=state, jobs_active=0, queues=queues, sysinfo=sysinfo).model_dump_json( + exclude_unset=True + ), + ) + except requests.HTTPError as e: + if e.response.status_code == 400: + raise EdgeWorkerVersionException(str(e)) + raise e return datetime.fromisoformat(result) @@ -106,13 +112,18 @@ def worker_set_state( hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str] | None, sysinfo: dict ) -> WorkerSetStateReturn: """Update the state of the worker in the central site and thereby implicitly heartbeat.""" - result = _make_generic_request( - "PATCH", - f"worker/{quote(hostname)}", - WorkerStateBody(state=state, jobs_active=jobs_active, queues=queues, sysinfo=sysinfo).model_dump_json( - exclude_unset=True - ), - ) + try: + result = _make_generic_request( + "PATCH", + f"worker/{quote(hostname)}", + WorkerStateBody( + state=state, jobs_active=jobs_active, queues=queues, sysinfo=sysinfo + ).model_dump_json(exclude_unset=True), + ) + except requests.HTTPError as e: + if e.response.status_code == 400: + raise EdgeWorkerVersionException(str(e)) + raise e return WorkerSetStateReturn(**result) diff --git a/providers/edge/src/airflow/providers/edge/get_provider_info.py b/providers/edge/src/airflow/providers/edge/get_provider_info.py index 63316fee3504b..1e9229e7f64b3 100644 --- a/providers/edge/src/airflow/providers/edge/get_provider_info.py +++ b/providers/edge/src/airflow/providers/edge/get_provider_info.py @@ -28,7 +28,7 @@ def get_provider_info(): "description": "Handle edge workers on remote sites via HTTP(s) connection and orchestrates work over distributed sites\n", "state": "not-ready", "source-date-epoch": 1737371680, - "versions": ["0.12.0pre0"], + "versions": ["0.13.1pre0"], "plugins": [ { "name": "edge_executor", @@ -99,5 +99,5 @@ def get_provider_info(): }, } }, - "dependencies": ["apache-airflow>=2.10.0", "pydantic>=2.10.2", "retryhttp>=1.2.0"], + "dependencies": ["apache-airflow>=2.10.0", "pydantic>=2.10.2", "retryhttp>=1.2.0,!=1.3.0"], } diff --git a/providers/edge/src/airflow/providers/edge/models/edge_worker.py b/providers/edge/src/airflow/providers/edge/models/edge_worker.py index 8c03bd16ed9ea..ec2f90c85c426 100644 --- a/providers/edge/src/airflow/providers/edge/models/edge_worker.py +++ b/providers/edge/src/airflow/providers/edge/models/edge_worker.py @@ -22,7 +22,7 @@ from enum import Enum from typing import TYPE_CHECKING -from sqlalchemy import Column, Integer, String, select +from sqlalchemy import Column, Integer, String, delete, select from airflow.exceptions import AirflowException from airflow.models.base import Base @@ -186,18 +186,20 @@ def reset_metrics(worker_name: str) -> None: @provide_session def 
request_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None: """Writes maintenance request to the db""" - query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) worker: EdgeWorkerModel = session.scalar(query) worker.state = EdgeWorkerState.MAINTENANCE_REQUEST - session.commit() @provide_session def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None: """Writes maintenance exit to the db""" - query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) worker: EdgeWorkerModel = session.scalar(query) worker.state = EdgeWorkerState.MAINTENANCE_EXIT - session.commit() + + +@provide_session +def remove_worker(worker_name: str, session: Session = NEW_SESSION) -> None: + """Remove a worker that is offline or just gone from DB""" + session.execute(delete(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)) diff --git a/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py b/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py index 9274a412f65c5..34cc0d34a2bb6 100644 --- a/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py +++ b/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py @@ -34,7 +34,6 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.yaml import safe_load from airflow.www import utils as wwwutils -from airflow.www.app import csrf from airflow.www.auth import has_access_view from airflow.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED from airflow.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver @@ -118,22 +117,26 @@ def status(self, session: Session = NEW_SESSION): @expose("/status/maintenance//on", methods=["POST"]) @has_access_view(AccessView.JOBS) - @provide_session - @csrf.exempt - def worker_to_maintenance(self, worker_name: str, session: Session = NEW_SESSION): + def worker_to_maintenance(self, worker_name: str): from airflow.providers.edge.models.edge_worker import request_maintenance - request_maintenance(worker_name, session) + request_maintenance(worker_name) return redirect(url_for("EdgeWorkerHosts.status")) @expose("/status/maintenance//off", methods=["POST"]) @has_access_view(AccessView.JOBS) - @provide_session - @csrf.exempt - def remove_worker_from_maintenance(self, worker_name: str, session: Session = NEW_SESSION): + def remove_worker_from_maintenance(self, worker_name: str): from airflow.providers.edge.models.edge_worker import exit_maintenance - exit_maintenance(worker_name, session) + exit_maintenance(worker_name) + return redirect(url_for("EdgeWorkerHosts.status")) + + @expose("/status/maintenance//remove", methods=["POST"]) + @has_access_view(AccessView.JOBS) + def remove_worker(self, worker_name: str): + from airflow.providers.edge.models.edge_worker import remove_worker + + remove_worker(worker_name) return redirect(url_for("EdgeWorkerHosts.status")) diff --git a/providers/edge/src/airflow/providers/edge/plugins/templates/edge_worker_hosts.html b/providers/edge/src/airflow/providers/edge/plugins/templates/edge_worker_hosts.html index 8ba99836dc78c..0ebea69b64180 100644 --- a/providers/edge/src/airflow/providers/edge/plugins/templates/edge_worker_hosts.html +++ b/providers/edge/src/airflow/providers/edge/plugins/templates/edge_worker_hosts.html @@ -103,16 +103,25 @@

 Edge Worker Hosts
 {%- if host.state in ["idle", "running"] -%}
+    [maintenance-request button markup not recoverable]
 {%- elif host.state in ["maintenance pending", "maintenance mode", "maintenance request"] -%}
+    [exit-maintenance button markup not recoverable]
+{%- elif host.state in ["offline", "unknown", "offline maintenance"] -%}
+    [added remove-worker button markup not recoverable]
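Taken together, the Edge hunks above (api_client, the edge_worker model, the plugin views and this template) add a worker-removal action alongside the existing maintenance toggles and move session handling into the model layer. A rough sketch of how the helpers are meant to be called, assuming an initialized Airflow metadata database; the hostname is made up:

    from airflow.providers.edge.models.edge_worker import (
        exit_maintenance,
        remove_worker,
        request_maintenance,
    )

    # @provide_session injects the SQLAlchemy session, which is why the plugin
    # views above no longer accept or pass an explicit `session` argument.
    request_maintenance("edge-host-1")  # flag the worker for maintenance
    exit_maintenance("edge-host-1")     # bring it back from maintenance
    remove_worker("edge-host-1")        # delete the row of an offline/gone worker
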
{% endif %} diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index ab1bc433d94a4..582e4abdb9e12 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from collections.abc import Iterable, Mapping from functools import cached_property from typing import TYPE_CHECKING, Any from urllib import parse @@ -42,6 +43,73 @@ def connect( return ESConnection(host, port, user, password, scheme, **kwargs) +class ElasticsearchSQLCursor: + """A PEP 249-like Cursor class for Elasticsearch SQL API""" + + def __init__(self, es: Elasticsearch, **kwargs): + self.es = es + self.body = { + "fetch_size": kwargs.get("fetch_size", 1000), + "field_multi_value_leniency": kwargs.get("field_multi_value_leniency", False), + } + self._response: ObjectApiResponse | None = None + + @property + def response(self) -> ObjectApiResponse: + return self._response or {} # type: ignore + + @response.setter + def response(self, value): + self._response = value + + @property + def cursor(self): + return self.response.get("cursor") + + @property + def rows(self): + return self.response.get("rows", []) + + @property + def rowcount(self) -> int: + return len(self.rows) + + @property + def description(self) -> list[tuple]: + return [(column["name"], column["type"]) for column in self.response.get("columns", [])] + + def execute( + self, statement: str, params: Iterable | Mapping[str, Any] | None = None + ) -> ObjectApiResponse: + self.body["query"] = statement + if params: + self.body["params"] = params + self.response = self.es.sql.query(body=self.body) + if self.cursor: + self.body["cursor"] = self.cursor + else: + self.body.pop("cursor", None) + return self.response + + def fetchone(self): + if self.rows: + return self.rows[0] + return None + + def fetchmany(self, size: int | None = None): + raise NotImplementedError() + + def fetchall(self): + results = self.rows + while self.cursor: + self.execute(statement=self.body["query"]) + results.extend(self.rows) + return results + + def close(self): + self._response = None + + class ESConnection: """wrapper class for elasticsearch.Elasticsearch.""" @@ -67,9 +135,19 @@ def __init__( else: self.es = Elasticsearch(self.url, **self.kwargs) - def execute_sql(self, query: str) -> ObjectApiResponse: - sql_query = {"query": query} - return self.es.sql.query(body=sql_query) + def cursor(self) -> ElasticsearchSQLCursor: + return ElasticsearchSQLCursor(self.es, **self.kwargs) + + def close(self): + self.es.close() + + def commit(self): + pass + + def execute_sql( + self, query: str, params: Iterable | Mapping[str, Any] | None = None + ) -> ObjectApiResponse: + return self.cursor().execute(query, params) class ElasticsearchSQLHook(DbApiHook): @@ -84,13 +162,13 @@ class ElasticsearchSQLHook(DbApiHook): conn_name_attr = "elasticsearch_conn_id" default_conn_name = "elasticsearch_default" + connector = ESConnection conn_type = "elasticsearch" hook_name = "Elasticsearch" def __init__(self, schema: str = "http", connection: AirflowConnection | None = None, *args, **kwargs): super().__init__(*args, **kwargs) self.schema = schema - self._connection = connection def get_conn(self) -> ESConnection: """Return an elasticsearch connection object.""" @@ -104,11 +182,10 @@ def 
get_conn(self) -> ESConnection: "scheme": conn.schema or "http", } - if conn.extra_dejson.get("http_compress", False): - conn_args["http_compress"] = bool(["http_compress"]) + conn_args.update(conn.extra_dejson) - if conn.extra_dejson.get("timeout", False): - conn_args["timeout"] = conn.extra_dejson["timeout"] + if conn_args.get("http_compress", False): + conn_args["http_compress"] = bool(conn_args["http_compress"]) return connect(**conn_args) diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py index ea34f2532de41..953e7dd50ef72 100644 --- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py @@ -20,15 +20,42 @@ from unittest import mock from unittest.mock import MagicMock +import pytest from elasticsearch import Elasticsearch +from elasticsearch._sync.client import SqlClient +from kgb import SpyAgency from airflow.models import Connection +from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from airflow.providers.elasticsearch.hooks.elasticsearch import ( ElasticsearchPythonHook, + ElasticsearchSQLCursor, ElasticsearchSQLHook, ESConnection, ) +ROWS = [ + [1, "Stallone", "Sylvester", "78"], + [2, "Statham", "Jason", "57"], + [3, "Li", "Jet", "61"], + [4, "Lundgren", "Dolph", "66"], + [5, "Norris", "Chuck", "84"], +] +RESPONSE_WITHOUT_CURSOR = { + "columns": [ + {"name": "index", "type": "long"}, + {"name": "name", "type": "text"}, + {"name": "firstname", "type": "text"}, + {"name": "age", "type": "long"}, + ], + "rows": ROWS, +} +RESPONSE = {**RESPONSE_WITHOUT_CURSOR, **{"cursor": "e7f8QwXUruW2mIebzudH4BwAA//8DAA=="}} +RESPONSES = [ + RESPONSE, + RESPONSE_WITHOUT_CURSOR, +] + class TestElasticsearchSQLHookConn: def setup_method(self): @@ -48,10 +75,68 @@ def test_get_conn(self, mock_connect): mock_connect.assert_called_with(host="localhost", port=9200, scheme="http", user=None, password=None) +class TestElasticsearchSQLCursor: + def setup_method(self): + sql = MagicMock(spec=SqlClient) + sql.query.side_effect = RESPONSES + self.es = MagicMock(sql=sql, spec=Elasticsearch) + + def test_execute(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + + assert cursor.execute("SELECT * FROM hollywood.actors") == RESPONSE + + def test_rowcount(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.rowcount == len(ROWS) + + def test_description(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.description == [ + ("index", "long"), + ("name", "text"), + ("firstname", "text"), + ("age", "long"), + ] + + def test_fetchone(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.fetchone() == ROWS[0] + + def test_fetchmany(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + with pytest.raises(NotImplementedError): + cursor.fetchmany() + + def test_fetchall(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + records = cursor.fetchall() + + assert len(records) == 10 + assert records == ROWS + + class TestElasticsearchSQLHook: def 
setup_method(self): - self.cur = mock.MagicMock(rowcount=0) - self.conn = mock.MagicMock() + sql = MagicMock(spec=SqlClient) + sql.query.side_effect = RESPONSES + es = MagicMock(sql=sql, spec=Elasticsearch) + self.cur = ElasticsearchSQLCursor(es=es, options={}) + self.spy_agency = SpyAgency() + self.spy_agency.spy_on(self.cur.close, call_original=True) + self.spy_agency.spy_on(self.cur.execute, call_original=True) + self.spy_agency.spy_on(self.cur.fetchall, call_original=True) + self.conn = MagicMock(spec=ESConnection) self.conn.cursor.return_value = self.cur conn = self.conn @@ -64,55 +149,60 @@ def get_conn(self): self.db_hook = UnitTestElasticsearchSQLHook() def test_get_first_record(self): - statement = "SQL" - result_sets = [("row1",), ("row2",)] - self.cur.fetchone.return_value = result_sets[0] + statement = "SELECT * FROM hollywood.actors" + + assert self.db_hook.get_first(statement) == ROWS[0] - assert result_sets[0] == self.db_hook.get_first(statement) self.conn.close.assert_called_once_with() - self.cur.close.assert_called_once_with() - self.cur.execute.assert_called_once_with(statement) + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) def test_get_records(self): - statement = "SQL" - result_sets = [("row1",), ("row2",)] - self.cur.fetchall.return_value = result_sets + statement = "SELECT * FROM hollywood.actors" + + assert self.db_hook.get_records(statement) == ROWS - assert result_sets == self.db_hook.get_records(statement) self.conn.close.assert_called_once_with() - self.cur.close.assert_called_once_with() - self.cur.execute.assert_called_once_with(statement) + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) def test_get_pandas_df(self): - statement = "SQL" - column = "col" - result_sets = [("row1",), ("row2",)] - self.cur.description = [(column,)] - self.cur.fetchall.return_value = result_sets + statement = "SELECT * FROM hollywood.actors" df = self.db_hook.get_pandas_df(statement) - assert column == df.columns[0] + assert list(df.columns) == ["index", "name", "firstname", "age"] + assert df.values.tolist() == ROWS + + self.conn.close.assert_called_once_with() + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) + + def test_run(self): + statement = "SELECT * FROM hollywood.actors" - assert result_sets[0][0] == df.values.tolist()[0][0] - assert result_sets[1][0] == df.values.tolist()[1][0] + assert self.db_hook.run(statement, handler=fetch_all_handler) == ROWS - self.cur.execute.assert_called_once_with(statement) + self.conn.close.assert_called_once_with() + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) @mock.patch("airflow.providers.elasticsearch.hooks.elasticsearch.Elasticsearch") def test_execute_sql_query(self, mock_es): mock_es_sql_client = MagicMock() - mock_es_sql_client.query.return_value = { - "columns": [{"name": "id"}, {"name": "first_name"}], - "rows": [[1, "John"], [2, "Jane"]], - } + mock_es_sql_client.query.return_value = RESPONSE_WITHOUT_CURSOR mock_es.return_value.sql = mock_es_sql_client es_connection = ESConnection(host="localhost", port=9200) - response = es_connection.execute_sql("SELECT * FROM index1") - mock_es_sql_client.query.assert_called_once_with(body={"query": "SELECT * FROM index1"}) - - assert response["rows"] == [[1, "John"], [2, "Jane"]] - assert response["columns"] == [{"name": "id"}, {"name": "first_name"}] + response = 
es_connection.execute_sql("SELECT * FROM hollywood.actors") + mock_es_sql_client.query.assert_called_once_with( + body={ + "fetch_size": 1000, + "field_multi_value_leniency": False, + "query": "SELECT * FROM hollywood.actors", + } + ) + + assert response == RESPONSE_WITHOUT_CURSOR class MockElasticsearch: diff --git a/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/dialects/mssql.py b/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/dialects/mssql.py index edad1a11515d5..0c0ba72309a9a 100644 --- a/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/dialects/mssql.py +++ b/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/dialects/mssql.py @@ -55,10 +55,15 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: self.log.debug("primary_keys: %s", primary_keys) self.log.debug("columns: %s", columns) - return f"""MERGE INTO {table} WITH (ROWLOCK) AS target + sql = f"""MERGE INTO {table} WITH (ROWLOCK) AS target USING (SELECT {', '.join(map(lambda column: f'{self.placeholder} AS {self.escape_word(column)}', target_fields))}) AS source - ON {' AND '.join(map(lambda column: f'target.{self.escape_word(column)} = source.{self.escape_word(column)}', primary_keys))} + ON {' AND '.join(map(lambda column: f'target.{self.escape_word(column)} = source.{self.escape_word(column)}', primary_keys))}""" + + if columns: + sql = f"""{sql} WHEN MATCHED THEN - UPDATE SET {', '.join(map(lambda column: f'target.{column} = source.{column}', columns))} + UPDATE SET {', '.join(map(lambda column: f'target.{column} = source.{column}', columns))}""" + + return f"""{sql} WHEN NOT MATCHED THEN INSERT ({', '.join(map(self.escape_word, target_fields))}) VALUES ({', '.join(map(lambda column: f'source.{self.escape_word(column)}', target_fields))});""" diff --git a/providers/microsoft/mssql/tests/provider_tests/microsoft/mssql/dialects/test_mssql.py b/providers/microsoft/mssql/tests/provider_tests/microsoft/mssql/dialects/test_mssql.py index 749a79c13fcd1..c584a15ba3b85 100644 --- a/providers/microsoft/mssql/tests/provider_tests/microsoft/mssql/dialects/test_mssql.py +++ b/providers/microsoft/mssql/tests/provider_tests/microsoft/mssql/dialects/test_mssql.py @@ -17,51 +17,119 @@ # under the License. from __future__ import annotations -from unittest.mock import MagicMock +import pytest -from sqlalchemy.engine import Inspector - -from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect class TestMsSqlDialect: - def setup_method(self): - inspector = MagicMock(spc=Inspector) - inspector.get_columns.side_effect = lambda table_name, schema: [ - {"name": "index", "identity": True}, - {"name": "name"}, - {"name": "firstname"}, - {"name": "age"}, - ] - self.test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook) - self.test_db_hook.run.side_effect = lambda *args: [("index",)] - self.test_db_hook.reserved_words = {"index", "user"} - self.test_db_hook.escape_word_format = "[{}]" - self.test_db_hook.escape_column_names = False - - def test_placeholder(self): - assert MsSqlDialect(self.test_db_hook).placeholder == "?" 
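Stepping back to the ElasticsearchSQLCursor added earlier in this patch: a minimal usage sketch, assuming an Elasticsearch cluster reachable through the elasticsearch_default connection and a made-up index name.

    from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
    from airflow.providers.elasticsearch.hooks.elasticsearch import ElasticsearchSQLHook

    hook = ElasticsearchSQLHook(elasticsearch_conn_id="elasticsearch_default")

    # DbApiHook.run() now goes through ESConnection.cursor(), i.e. the new
    # ElasticsearchSQLCursor: execute() sends one SQL API request, and fetchall()
    # keeps re-issuing the query with the returned cursor token until the response
    # no longer carries one, so paginated result sets come back in full.
    rows = hook.run("SELECT * FROM hollywood.actors", handler=fetch_all_handler)
    print(rows)

The same cursor-token loop is what lets get_records() and get_pandas_df() in the tests above return every row at once.
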
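For the MsSqlDialect.generate_replace_sql change a little further up: when every target field is part of the primary key there is nothing left to UPDATE, so the generated MERGE now omits the WHEN MATCHED branch. A rough, self-contained sketch that mirrors the mocked hook wiring from the removed setup_method above (table, columns and values are made up):

    from unittest.mock import MagicMock

    from airflow.providers.common.sql.hooks.sql import DbApiHook
    from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect

    inspector = MagicMock()
    # Report every column as an identity column that is part of the primary key.
    inspector.get_columns.side_effect = lambda table_name, schema: [
        {"name": "index", "identity": True},
        {"name": "name", "identity": True},
    ]
    hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook)
    hook.run.side_effect = lambda *args: [("index",), ("name",)]
    hook.reserved_words = {"index", "user"}
    hook.escape_word_format = "[{}]"
    hook.escape_column_names = False

    sql = MsSqlDialect(hook).generate_replace_sql(
        "hollywood.actors",
        [{"index": 1, "name": "Stallone"}],
        target_fields=["index", "name"],
    )
    # With no non-key columns left to update, the statement contains only the
    # WHEN NOT MATCHED THEN INSERT branch.
    print(sql)
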
+ @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_placeholder(self, create_db_api_hook): + assert MsSqlDialect(create_db_api_hook).placeholder == "?" - def test_get_column_names(self): - assert MsSqlDialect(self.test_db_hook).get_column_names("hollywood.actors") == [ + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_get_column_names(self, create_db_api_hook): + assert MsSqlDialect(create_db_api_hook).get_column_names("hollywood.actors") == [ "index", "name", "firstname", "age", ] - def test_get_target_fields(self): - assert MsSqlDialect(self.test_db_hook).get_target_fields("hollywood.actors") == [ + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_get_target_fields(self, create_db_api_hook): + assert MsSqlDialect(create_db_api_hook).get_target_fields("hollywood.actors") == [ "name", "firstname", "age", ] - def test_get_primary_keys(self): - assert MsSqlDialect(self.test_db_hook).get_primary_keys("hollywood.actors") == ["index"] + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_get_primary_keys(self, create_db_api_hook): + assert MsSqlDialect(create_db_api_hook).get_primary_keys("hollywood.actors") == ["index"] - def test_generate_replace_sql(self): + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_generate_replace_sql(self, create_db_api_hook): values = [ {"index": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"}, {"index": 2, "name": "Statham", "firstname": "Jason", "age": "57"}, @@ -70,7 +138,7 @@ def test_generate_replace_sql(self): {"index": 5, "name": "Norris", "firstname": "Chuck", "age": "84"}, ] target_fields = ["index", "name", "firstname", "age"] - sql = MsSqlDialect(self.test_db_hook).generate_replace_sql("hollywood.actors", values, target_fields) + sql = MsSqlDialect(create_db_api_hook).generate_replace_sql("hollywood.actors", values, target_fields) assert ( sql == """ @@ -84,7 +152,62 @@ def test_generate_replace_sql(self): """.strip() ) - def test_generate_replace_sql_when_escape_column_names_is_enabled(self): + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name", "identity": True}, + {"name": "firstname", "identity": True}, + {"name": "age", "identity": 
True}, + ], # columns + [("index",), ("name",), ("firstname",), ("age",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_generate_replace_sql_when_all_columns_are_part_of_primary_key(self, create_db_api_hook): + values = [ + {"index": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"}, + {"index": 2, "name": "Statham", "firstname": "Jason", "age": "57"}, + {"index": 3, "name": "Li", "firstname": "Jet", "age": "61"}, + {"index": 4, "name": "Lundgren", "firstname": "Dolph", "age": "66"}, + {"index": 5, "name": "Norris", "firstname": "Chuck", "age": "84"}, + ] + target_fields = ["index", "name", "firstname", "age"] + sql = MsSqlDialect(create_db_api_hook).generate_replace_sql("hollywood.actors", values, target_fields) + assert ( + sql + == """ + MERGE INTO hollywood.actors WITH (ROWLOCK) AS target + USING (SELECT ? AS [index], ? AS name, ? AS firstname, ? AS age) AS source + ON target.[index] = source.[index] AND target.name = source.name AND target.firstname = source.firstname AND target.age = source.age + WHEN NOT MATCHED THEN + INSERT ([index], name, firstname, age) VALUES (source.[index], source.name, source.firstname, source.age); + """.strip() + ) + + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + True, # escape_column_names + ), + ], + indirect=True, + ) + def test_generate_replace_sql_when_escape_column_names_is_enabled(self, create_db_api_hook): values = [ {"index": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"}, {"index": 2, "name": "Statham", "firstname": "Jason", "age": "57"}, @@ -93,8 +216,7 @@ def test_generate_replace_sql_when_escape_column_names_is_enabled(self): {"index": 5, "name": "Norris", "firstname": "Chuck", "age": "84"}, ] target_fields = ["index", "name", "firstname", "age"] - self.test_db_hook.escape_column_names = True - sql = MsSqlDialect(self.test_db_hook).generate_replace_sql("hollywood.actors", values, target_fields) + sql = MsSqlDialect(create_db_api_hook).generate_replace_sql("hollywood.actors", values, target_fields) assert ( sql == """ diff --git a/providers/microsoft/winrm/README.rst b/providers/microsoft/winrm/README.rst new file mode 100644 index 0000000000000..1cd1a2e28cffa --- /dev/null +++ b/providers/microsoft/winrm/README.rst @@ -0,0 +1,62 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + + .. 
IF YOU WANT TO MODIFY TEMPLATE FOR THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + `PROVIDER_README_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +Package ``apache-airflow-providers-microsoft-winrm`` + +Release: ``3.7.0`` + + +`Windows Remote Management (WinRM) `__ + + +Provider package +---------------- + +This is a provider package for ``microsoft.winrm`` provider. All classes for this provider package +are in ``airflow.providers.microsoft.winrm`` python package. + +You can find package information and changelog for the provider +in the `documentation `_. + +Installation +------------ + +You can install this package on top of an existing Airflow 2 installation (see ``Requirements`` below +for the minimum Airflow version supported) via +``pip install apache-airflow-providers-microsoft-winrm`` + +The package supports the following python versions: 3.9,3.10,3.11,3.12 + +Requirements +------------ + +================== ================== +PIP package Version required +================== ================== +``apache-airflow`` ``>=2.9.0`` +``pywinrm`` ``>=0.4`` +================== ================== + +The changelog for the provider package can be found in the +`changelog `_. diff --git a/providers/src/airflow/providers/microsoft/winrm/.latest-doc-only-change.txt b/providers/microsoft/winrm/docs/.latest-doc-only-change.txt similarity index 100% rename from providers/src/airflow/providers/microsoft/winrm/.latest-doc-only-change.txt rename to providers/microsoft/winrm/docs/.latest-doc-only-change.txt diff --git a/providers/src/airflow/providers/microsoft/winrm/CHANGELOG.rst b/providers/microsoft/winrm/docs/changelog.rst similarity index 100% rename from providers/src/airflow/providers/microsoft/winrm/CHANGELOG.rst rename to providers/microsoft/winrm/docs/changelog.rst diff --git a/docs/apache-airflow-providers-microsoft-winrm/commits.rst b/providers/microsoft/winrm/docs/commits.rst similarity index 100% rename from docs/apache-airflow-providers-microsoft-winrm/commits.rst rename to providers/microsoft/winrm/docs/commits.rst diff --git a/docs/apache-airflow-providers-microsoft-winrm/index.rst b/providers/microsoft/winrm/docs/index.rst similarity index 100% rename from docs/apache-airflow-providers-microsoft-winrm/index.rst rename to providers/microsoft/winrm/docs/index.rst diff --git a/docs/apache-airflow-providers-microsoft-winrm/installing-providers-from-sources.rst b/providers/microsoft/winrm/docs/installing-providers-from-sources.rst similarity index 100% rename from docs/apache-airflow-providers-microsoft-winrm/installing-providers-from-sources.rst rename to providers/microsoft/winrm/docs/installing-providers-from-sources.rst diff --git a/docs/integration-logos/winrm/WinRM.png b/providers/microsoft/winrm/docs/integration-logos/WinRM.png similarity index 100% rename from docs/integration-logos/winrm/WinRM.png rename to providers/microsoft/winrm/docs/integration-logos/WinRM.png diff --git a/docs/apache-airflow-providers-microsoft-winrm/operators.rst b/providers/microsoft/winrm/docs/operators.rst similarity index 86% rename from docs/apache-airflow-providers-microsoft-winrm/operators.rst rename to providers/microsoft/winrm/docs/operators.rst index 2e7fdc6633486..4035379f880f0 100644 --- a/docs/apache-airflow-providers-microsoft-winrm/operators.rst +++ b/providers/microsoft/winrm/docs/operators.rst @@ -22,7 +22,7 @@ use the WinRMOperator to execute commands on a given remote host using the winrm create a hook -.. 
exampleinclude:: /../../providers/tests/system/microsoft/winrm/example_winrm.py +.. exampleinclude:: /../../providers/microsoft/winrm/tests/system/microsoft/winrm/example_winrm.py :language: python :dedent: 4 :start-after: [START create_hook] @@ -30,7 +30,7 @@ create a hook Run the operator, pass the hook, and pass a command to do something -.. exampleinclude:: /../../providers/tests/system/microsoft/winrm/example_winrm.py +.. exampleinclude:: /../../providers/microsoft/winrm/tests/system/microsoft/winrm/example_winrm.py :language: python :dedent: 4 :start-after: [START run_operator] diff --git a/docs/apache-airflow-providers-microsoft-winrm/security.rst b/providers/microsoft/winrm/docs/security.rst similarity index 100% rename from docs/apache-airflow-providers-microsoft-winrm/security.rst rename to providers/microsoft/winrm/docs/security.rst diff --git a/providers/src/airflow/providers/microsoft/winrm/provider.yaml b/providers/microsoft/winrm/provider.yaml similarity index 94% rename from providers/src/airflow/providers/microsoft/winrm/provider.yaml rename to providers/microsoft/winrm/provider.yaml index 033e85bac8ad0..674f0ac4de1bd 100644 --- a/providers/src/airflow/providers/microsoft/winrm/provider.yaml +++ b/providers/microsoft/winrm/provider.yaml @@ -49,14 +49,10 @@ versions: - 1.0.1 - 1.0.0 -dependencies: - - apache-airflow>=2.9.0 - - pywinrm>=0.4 - integrations: - integration-name: Windows Remote Management (WinRM) external-doc-url: https://docs.microsoft.com/en-us/windows/win32/winrm/portal - logo: /integration-logos/winrm/WinRM.png + logo: /docs/integration-logos/WinRM.png how-to-guide: - /docs/apache-airflow-providers-microsoft-winrm/operators.rst tags: [protocol] diff --git a/providers/microsoft/winrm/pyproject.toml b/providers/microsoft/winrm/pyproject.toml new file mode 100644 index 0000000000000..f143e40f8b185 --- /dev/null +++ b/providers/microsoft/winrm/pyproject.toml @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! 
+ +# IF YOU WANT TO MODIFY THIS FILE EXCEPT DEPENDENCIES, YOU SHOULD MODIFY THE TEMPLATE +# `pyproject_TEMPLATE.toml.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +[build-system] +requires = ["flit_core==3.10.1"] +build-backend = "flit_core.buildapi" + +[project] +name = "apache-airflow-providers-microsoft-winrm" +version = "3.7.0" +description = "Provider package apache-airflow-providers-microsoft-winrm for Apache Airflow" +readme = "README.rst" +authors = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +maintainers = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +keywords = [ "airflow-provider", "microsoft.winrm", "airflow", "integration" ] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Framework :: Apache Airflow", + "Framework :: Apache Airflow :: Provider", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: System :: Monitoring", +] +requires-python = "~=3.9" + +# The dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +dependencies = [ + "apache-airflow>=2.9.0", + "pywinrm>=0.4", +] + +[project.urls] +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-microsoft-winrm/3.7.0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-microsoft-winrm/3.7.0/changelog.html" +"Bug Tracker" = "https://github.com/apache/airflow/issues" +"Source Code" = "https://github.com/apache/airflow" +"Slack Chat" = "https://s.apache.org/airflow-slack" +"Twitter" = "https://x.com/ApacheAirflow" +"YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/" + +[project.entry-points."apache_airflow_provider"] +provider_info = "airflow.providers.microsoft.winrm.get_provider_info:get_provider_info" + +[tool.flit.module] +name = "airflow.providers.microsoft.winrm" + +[tool.pytest.ini_options] +ignore = "tests/system/" diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/LICENSE b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/LICENSE new file mode 100644 index 0000000000000..11069edd79019 --- /dev/null +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
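The new pyproject.toml above wires the package into Airflow through the "apache_airflow_provider" entry point, and the generated get_provider_info.py in the next file is what that entry point resolves to. As a rough illustration only (this is not Airflow's actual ProvidersManager code, and it assumes Python 3.10+ importlib.metadata), discovering such providers boils down to:

from importlib.metadata import entry_points

# Sketch only: Airflow's ProvidersManager does much more (caching, schema
# validation, lazy imports). This just shows how the entry point declared in
# pyproject.toml resolves to the generated get_provider_info() function.
for ep in entry_points(group="apache_airflow_provider"):
    provider_info = ep.load()()  # load the function, then call it
    print(provider_info["package-name"], provider_info["versions"][0])
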
diff --git a/providers/src/airflow/providers/microsoft/winrm/__init__.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/__init__.py similarity index 100% rename from providers/src/airflow/providers/microsoft/winrm/__init__.py rename to providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/__init__.py diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/get_provider_info.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/get_provider_info.py new file mode 100644 index 0000000000000..b35967335f517 --- /dev/null +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/get_provider_info.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! +# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `get_provider_info_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +def get_provider_info(): + return { + "package-name": "apache-airflow-providers-microsoft-winrm", + "name": "Windows Remote Management (WinRM)", + "description": "`Windows Remote Management (WinRM) `__\n", + "state": "ready", + "source-date-epoch": 1734535461, + "versions": [ + "3.7.0", + "3.6.1", + "3.6.0", + "3.5.1", + "3.5.0", + "3.4.0", + "3.3.0", + "3.2.2", + "3.2.1", + "3.2.0", + "3.1.1", + "3.1.0", + "3.0.0", + "2.0.5", + "2.0.4", + "2.0.3", + "2.0.2", + "2.0.1", + "2.0.0", + "1.2.0", + "1.1.0", + "1.0.1", + "1.0.0", + ], + "integrations": [ + { + "integration-name": "Windows Remote Management (WinRM)", + "external-doc-url": "https://docs.microsoft.com/en-us/windows/win32/winrm/portal", + "logo": "/docs/integration-logos/WinRM.png", + "how-to-guide": ["/docs/apache-airflow-providers-microsoft-winrm/operators.rst"], + "tags": ["protocol"], + } + ], + "operators": [ + { + "integration-name": "Windows Remote Management (WinRM)", + "python-modules": ["airflow.providers.microsoft.winrm.operators.winrm"], + } + ], + "hooks": [ + { + "integration-name": "Windows Remote Management (WinRM)", + "python-modules": ["airflow.providers.microsoft.winrm.hooks.winrm"], + } + ], + "dependencies": ["apache-airflow>=2.9.0", "pywinrm>=0.4"], + } diff --git a/providers/tests/cncf/kubernetes/__init__.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/__init__.py rename to providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/__init__.py diff --git a/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py similarity index 99% rename from providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py 
rename to providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py index 961e37ba3fe40..4ded4cdffacf2 100644 --- a/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py @@ -22,12 +22,11 @@ from base64 import b64encode from contextlib import suppress -from winrm.exceptions import WinRMOperationTimeoutError -from winrm.protocol import Protocol - from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.utils.platform import getuser +from winrm.exceptions import WinRMOperationTimeoutError +from winrm.protocol import Protocol # TODO: FIXME please - I have too complex implementation diff --git a/providers/tests/cncf/kubernetes/models/__init__.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/models/__init__.py rename to providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/__init__.py diff --git a/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py similarity index 100% rename from providers/src/airflow/providers/microsoft/winrm/operators/winrm.py rename to providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py diff --git a/providers/microsoft/winrm/tests/conftest.py b/providers/microsoft/winrm/tests/conftest.py new file mode 100644 index 0000000000000..068fe6bbf5ae9 --- /dev/null +++ b/providers/microsoft/winrm/tests/conftest.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pathlib + +import pytest + +pytest_plugins = "tests_common.pytest_plugin" + + +@pytest.hookimpl(tryfirst=True) +def pytest_configure(config: pytest.Config) -> None: + deprecations_ignore_path = pathlib.Path(__file__).parent.joinpath("deprecations_ignore.yml") + dep_path = [deprecations_ignore_path] if deprecations_ignore_path.exists() else [] + config.inicfg["airflow_deprecations_ignore"] = ( + config.inicfg.get("airflow_deprecations_ignore", []) + dep_path # type: ignore[assignment,operator] + ) diff --git a/providers/microsoft/winrm/tests/provider_tests/__init__.py b/providers/microsoft/winrm/tests/provider_tests/__init__.py new file mode 100644 index 0000000000000..e8fd22856438c --- /dev/null +++ b/providers/microsoft/winrm/tests/provider_tests/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/microsoft/winrm/tests/provider_tests/microsoft/__init__.py b/providers/microsoft/winrm/tests/provider_tests/microsoft/__init__.py new file mode 100644 index 0000000000000..e8fd22856438c --- /dev/null +++ b/providers/microsoft/winrm/tests/provider_tests/microsoft/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/tests/microsoft/winrm/__init__.py b/providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/__init__.py similarity index 100% rename from providers/tests/microsoft/winrm/__init__.py rename to providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/__init__.py diff --git a/providers/tests/microsoft/winrm/hooks/__init__.py b/providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/hooks/__init__.py similarity index 100% rename from providers/tests/microsoft/winrm/hooks/__init__.py rename to providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/hooks/__init__.py diff --git a/providers/tests/microsoft/winrm/hooks/test_winrm.py b/providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/hooks/test_winrm.py similarity index 100% rename from providers/tests/microsoft/winrm/hooks/test_winrm.py rename to providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/hooks/test_winrm.py diff --git a/providers/tests/microsoft/winrm/operators/__init__.py b/providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/operators/__init__.py similarity index 100% rename from providers/tests/microsoft/winrm/operators/__init__.py rename to providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/operators/__init__.py diff --git a/providers/tests/microsoft/winrm/operators/test_winrm.py b/providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/operators/test_winrm.py similarity index 100% rename from providers/tests/microsoft/winrm/operators/test_winrm.py rename to providers/microsoft/winrm/tests/provider_tests/microsoft/winrm/operators/test_winrm.py diff --git a/providers/tests/cncf/kubernetes/log_handlers/__init__.py 
b/providers/microsoft/winrm/tests/system/microsoft/winrm/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/log_handlers/__init__.py rename to providers/microsoft/winrm/tests/system/microsoft/winrm/__init__.py diff --git a/providers/tests/system/microsoft/winrm/example_winrm.py b/providers/microsoft/winrm/tests/system/microsoft/winrm/example_winrm.py similarity index 100% rename from providers/tests/system/microsoft/winrm/example_winrm.py rename to providers/microsoft/winrm/tests/system/microsoft/winrm/example_winrm.py diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py index b0d0259e3d867..0e6a519271f07 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py @@ -50,8 +50,13 @@ from datetime import datetime from airflow.providers.openlineage.extractors import OperatorLineage - from airflow.utils.log.secrets_masker import SecretsMasker + from airflow.sdk.execution_time.secrets_masker import SecretsMasker, _secrets_masker from airflow.utils.state import DagRunState +else: + try: + from airflow.sdk.execution_time.secrets_masker import SecretsMasker, _secrets_masker + except ImportError: + from airflow.utils.log.secrets_masker import SecretsMasker, _secrets_masker _PRODUCER = f"https://github.com/apache/airflow/tree/providers-openlineage/{OPENLINEAGE_PROVIDER_VERSION}" @@ -71,8 +76,6 @@ def __init__(self, client: OpenLineageClient | None = None, secrets_masker: Secr super().__init__() self._client = client if not secrets_masker: - from airflow.utils.log.secrets_masker import _secrets_masker - secrets_masker = _secrets_masker() self._redacter = OpenLineageRedactor.from_masker(secrets_masker) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index 4f3d904415157..9476cf828e04b 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -54,12 +54,6 @@ from airflow.sensors.base import BaseSensorOperator from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.context import AirflowContextDeprecationWarning -from airflow.utils.log.secrets_masker import ( - Redactable, - Redacted, - SecretsMasker, - should_hide_value_for_key, -) from airflow.utils.module_loading import import_string from airflow.utils.session import NEW_SESSION, provide_session from openlineage.client.utils import RedactMixin @@ -68,6 +62,12 @@ from airflow.models import TaskInstance from airflow.providers.common.compat.assets import Asset from airflow.sdk import DAG, BaseOperator, MappedOperator + from airflow.sdk.execution_time.secrets_masker import ( + Redactable, + Redacted, + SecretsMasker, + should_hide_value_for_key, + ) from airflow.utils.state import DagRunState, TaskInstanceState from openlineage.client.event_v2 import Dataset as OpenLineageDataset from openlineage.client.facet_v2 import RunFacet, processing_engine_run @@ -86,6 +86,21 @@ # dataset is renamed to asset since Airflow 3.0 from airflow.datasets import Dataset as Asset + try: + from airflow.sdk.execution_time.secrets_masker import ( + Redactable, + Redacted, + SecretsMasker, + should_hide_value_for_key, + ) + except ImportError: + from 
airflow.utils.log.secrets_masker import ( + Redactable, + Redacted, + SecretsMasker, + should_hide_value_for_key, + ) + log = logging.getLogger(__name__) _NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" diff --git a/providers/openlineage/tests/provider_tests/openlineage/plugins/test_utils.py b/providers/openlineage/tests/provider_tests/openlineage/plugins/test_utils.py index 59fbc8605ad6e..3d30f6a6fd942 100644 --- a/providers/openlineage/tests/provider_tests/openlineage/plugins/test_utils.py +++ b/providers/openlineage/tests/provider_tests/openlineage/plugins/test_utils.py @@ -20,7 +20,7 @@ import json import uuid from json import JSONEncoder -from typing import Any +from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock, patch import pytest @@ -45,7 +45,6 @@ ) from airflow.serialization.enums import DagAttributeTypes from airflow.utils import timezone -from airflow.utils.log.secrets_masker import _secrets_masker from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -57,6 +56,14 @@ if AIRFLOW_V_3_0_PLUS: from airflow.utils.types import DagRunTriggeredByType +if TYPE_CHECKING: + from airflow.sdk.execution_time.secrets_masker import _secrets_masker +else: + try: + from airflow.sdk.execution_time.secrets_masker import _secrets_masker + except ImportError: + from airflow.utils.log.secrets_masker import _secrets_masker + class SafeStrDict(dict): def __str__(self): @@ -236,7 +243,7 @@ def __init__(self): @pytest.mark.enable_redact def test_redact_with_exclusions(monkeypatch): - redactor = OpenLineageRedactor.from_masker(_secrets_masker()) + redactor = OpenLineageRedactor.from_masker(_secrets_masker()) # type: ignore[assignment] class NotMixin: def __init__(self): diff --git a/providers/openlineage/tests/system/openlineage/example_openlineage.json b/providers/openlineage/tests/system/openlineage/example_openlineage.json index 0db8bc53e2288..3332d6650d406 100644 --- a/providers/openlineage/tests/system/openlineage/example_openlineage.json +++ b/providers/openlineage/tests/system/openlineage/example_openlineage.json @@ -2,17 +2,93 @@ { "eventType": "START", "eventTime": "{{ is_datetime(result) }}", + "producer": "{{ result.startswith('https://github.com/apache/airflow/tree/providers-openlineage') }}", + "schemaURL": "{{ result.startswith('https://openlineage.io/spec') }}", + "inputs": [], + "outputs": [], "run": { - "runId": "{{ is_uuid(result) }}" + "runId": "{{ is_uuid(result) }}", + "facets": { + "parent": { + "job": { + "namespace": "{{ result is string }}", + "name": "openlineage_basic_dag" + }, + "run": { + "runId": "{{ is_uuid(result) }}" + } + }, + "airflow": { + "dag": { + "dag_id": "openlineage_basic_dag", + "fileloc": "{{ result.endswith('openlineage/example_openlineage.py') }}", + "owner": "airflow", + "start_date": "{{ is_datetime(result) }}" + }, + "dagRun": { + "conf": {}, + "dag_id": "openlineage_basic_dag", + "data_interval_end": "{{ is_datetime(result) }}", + "data_interval_start": "{{ is_datetime(result) }}", + "start_date": "{{ is_datetime(result) }}" + }, + "taskInstance": { + "try_number": "{{ result is number }}", + "queued_dttm": "{{ is_datetime(result) }}", + "log_url": "{{ result is string }}" + }, + "task": { + "inlets": "[]", + "mapped": false, + "outlets": "[]", + "task_id": "do_nothing_task", + "trigger_rule": "all_success", + "operator_class": "PythonOperator", + "retries": "{{ result is number }}", + "depends_on_past": false, + "executor_config": {}, + "priority_weight": 1, + "multiple_outputs": false, + 
"upstream_task_ids": "[]", + "downstream_task_ids": "['check_events']", + "operator_class_path": "{{ result.endswith('.PythonOperator') }}", + "wait_for_downstream": false, + "retry_exponential_backoff": false, + "ignore_first_depends_on_past": false, + "wait_for_past_depends_before_skipping": false + }, + "taskUuid": "{{ is_uuid(result) }}" + }, + "nominalTime": { + "nominalEndTime": "{{ is_datetime(result) }}", + "nominalStartTime": "{{ is_datetime(result) }}" + }, + "processing_engine": { + "name": "Airflow", + "openlineageAdapterVersion": "{{ result is string }}", + "version": "{{ result is string }}" + } + } }, "job": { - "namespace": "default", + "namespace": "{{ result is string }}", "name": "openlineage_basic_dag.do_nothing_task", "facets": { "jobType": { "integration": "AIRFLOW", "jobType": "TASK", "processingType": "BATCH" + }, + "ownership": { + "owners": [ + { + "name": "{{ result is string }}" + } + ] + }, + "sourceCode": { + "language": "python", + "sourceCode": "def do_nothing():\n pass\n" } } } @@ -21,16 +97,77 @@ "eventType": "COMPLETE", "eventTime": "{{ is_datetime(result) }}", "run": { - "runId": "{{ is_uuid(result) }}" + "runId": "{{ is_uuid(result) }}", + "facets": { + "parent": { + "job": { + "namespace": "{{ result is string }}", + "name": "openlineage_basic_dag" + }, + "run": { + "runId": "{{ is_uuid(result) }}" + } + }, + "airflow": { + "dag": { + "dag_id": "openlineage_basic_dag", + "fileloc": "{{ result.endswith('openlineage/example_openlineage.py') }}", + "owner": "airflow", + "start_date": "{{ is_datetime(result) }}" + }, + "dagRun": { + "conf": {}, + "dag_id": "openlineage_basic_dag", + "data_interval_end": "{{ is_datetime(result) }}", + "data_interval_start": "{{ is_datetime(result) }}", + "start_date": "{{ is_datetime(result) }}" + }, + "taskInstance": { + "try_number": "{{ result is number }}", + "queued_dttm": "{{ is_datetime(result) }}", + "log_url": "{{ result is string }}" + }, + "task": { + "inlets": "[]", + "mapped": false, + "outlets": "[]", + "task_id": "do_nothing_task", + "trigger_rule": "all_success", + "operator_class": "PythonOperator", + "retries": "{{ result is number }}", + "depends_on_past": false, + "executor_config": {}, + "priority_weight": 1, + "multiple_outputs": false, + "upstream_task_ids": "[]", + "downstream_task_ids": "['check_events']", + "operator_class_path": "{{ result.endswith('.PythonOperator') }}", + "wait_for_downstream": false, + "retry_exponential_backoff": false, + "ignore_first_depends_on_past": false, + "wait_for_past_depends_before_skipping": false + }, + "taskUuid": "{{ is_uuid(result) }}" + }, + "processing_engine": { + "name": "Airflow", + "openlineageAdapterVersion": "{{ result is string }}", + "version": "{{ result is string }}" + } + } }, "job": { - "namespace": "default", + "namespace": "{{ result is string }}", "name": "openlineage_basic_dag.do_nothing_task", "facets": { "jobType": { "integration": "AIRFLOW", "jobType": "TASK", "processingType": "BATCH" + }, + "sourceCode": { + "language": "python", + "sourceCode": "def do_nothing():\n pass\n" } } } diff --git a/providers/openlineage/tests/system/openlineage/example_openlineage.py b/providers/openlineage/tests/system/openlineage/example_openlineage.py index 8e632d2c2dcbc..28b92540ef4b3 100644 --- a/providers/openlineage/tests/system/openlineage/example_openlineage.py +++ b/providers/openlineage/tests/system/openlineage/example_openlineage.py @@ -16,8 +16,8 @@ # under the License. 
from __future__ import annotations -import os from datetime import datetime +from pathlib import Path from providers.openlineage.tests.system.openlineage.operator import OpenLineageTestOperator @@ -43,7 +43,7 @@ def do_nothing(): check_events = OpenLineageTestOperator( task_id="check_events", - file_path=f"{os.getenv('AIRFLOW_HOME')}/dags/providers/tests/system/openlineage/example_openlineage.json", + file_path=str(Path(__file__).parent / "example_openlineage.json"), ) nothing_task >> check_events diff --git a/providers/openlineage/tests/system/openlineage/example_openlineage_mapped_sensor.py b/providers/openlineage/tests/system/openlineage/example_openlineage_mapped_sensor.py index 44fecc85a35e5..f49b6e591b325 100644 --- a/providers/openlineage/tests/system/openlineage/example_openlineage_mapped_sensor.py +++ b/providers/openlineage/tests/system/openlineage/example_openlineage_mapped_sensor.py @@ -18,6 +18,7 @@ import os from datetime import datetime, timedelta +from pathlib import Path from providers.openlineage.tests.system.openlineage.operator import OpenLineageTestOperator @@ -68,7 +69,7 @@ def check_start_amount_func(): check_events = OpenLineageTestOperator( task_id="check_events", - file_path=f"{os.getenv('AIRFLOW_HOME')}/dags/providers/tests/system/openlineage/example_openlineage_mapped_sensor.json", + file_path=str(Path(__file__).parent / "example_openlineage_mapped_sensor.json"), allow_duplicate_events=True, ) diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index e7ca8d94fc359..1301fd9102cfe 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -22,20 +22,24 @@ import datetime import os import stat -from collections.abc import Sequence +import warnings +from collections.abc import Generator, Sequence +from contextlib import closing, contextmanager from fnmatch import fnmatch +from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any, Callable import asyncssh from asgiref.sync import sync_to_async -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook from airflow.providers.ssh.hooks.ssh import SSHHook if TYPE_CHECKING: - import paramiko + from paramiko.sftp_attr import SFTPAttributes + from paramiko.sftp_client import SFTPClient from airflow.models.connection import Connection @@ -52,8 +56,6 @@ class SFTPHook(SSHHook): - In contrast with FTPHook describe_directory only returns size, type and modify. It doesn't return unix.owner, unix.mode, perm, unix.group and unique. - - retrieve_file and store_file only take a local full path and not a - buffer. - If no mode is passed to create_directory it will be created with 777 permissions. @@ -85,7 +87,22 @@ def __init__( *args, **kwargs, ) -> None: - self.conn: paramiko.SFTPClient | None = None + # TODO: remove support for ssh_hook when it is removed from SFTPOperator + if kwargs.get("ssh_hook") is not None: + warnings.warn( + "Parameter `ssh_hook` is deprecated and will be ignored.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + + ftp_conn_id = kwargs.pop("ftp_conn_id", None) + if ftp_conn_id: + warnings.warn( + "Parameter `ftp_conn_id` is deprecated. 
Please use `ssh_conn_id` instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + ssh_conn_id = ftp_conn_id kwargs["ssh_conn_id"] = ssh_conn_id kwargs["host_proxy_cmd"] = host_proxy_cmd @@ -93,17 +110,11 @@ def __init__( super().__init__(*args, **kwargs) - def get_conn(self) -> paramiko.SFTPClient: # type: ignore[override] - """Open an SFTP connection to the remote host.""" - if self.conn is None: - self.conn = super().get_conn().open_sftp() - return self.conn - - def close_conn(self) -> None: - """Close the SFTP connection.""" - if self.conn is not None: - self.conn.close() - self.conn = None + @contextmanager + def get_conn(self) -> Generator[SFTPClient, None, None]: + """Context manager that closes the connection after use.""" + with closing(super().get_conn().open_sftp()) as conn: + yield conn def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None]]: """ @@ -114,17 +125,17 @@ def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None] :param path: full path to the remote directory """ - conn = self.get_conn() - flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename) - files = {} - for f in flist: - modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime("%Y%m%d%H%M%S") # type: ignore - files[f.filename] = { - "size": f.st_size, - "type": "dir" if stat.S_ISDIR(f.st_mode) else "file", # type: ignore - "modify": modify, - } - return files + with self.get_conn() as conn: # type: SFTPClient + flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename) + files = {} + for f in flist: + modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime("%Y%m%d%H%M%S") # type: ignore + files[f.filename] = { + "size": f.st_size, + "type": "dir" if stat.S_ISDIR(f.st_mode) else "file", # type: ignore + "modify": modify, + } + return files def list_directory(self, path: str) -> list[str]: """ @@ -132,18 +143,17 @@ def list_directory(self, path: str) -> list[str]: :param path: full path to the remote directory to list """ - conn = self.get_conn() - files = sorted(conn.listdir(path)) - return files + with self.get_conn() as conn: + return sorted(conn.listdir(path)) - def list_directory_with_attr(self, path: str) -> list[paramiko.SFTPAttributes]: + def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]: """ List files in a directory on the remote system including their SFTPAttributes. 
:param path: full path to the remote directory to list """ - conn = self.get_conn() - return [file for file in conn.listdir_attr(path)] + with self.get_conn() as conn: + return [file for file in conn.listdir_attr(path)] def mkdir(self, path: str, mode: int = 0o777) -> None: """ @@ -155,8 +165,8 @@ def mkdir(self, path: str, mode: int = 0o777) -> None: :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ - conn = self.get_conn() - conn.mkdir(path, mode=mode) + with self.get_conn() as conn: + conn.mkdir(path, mode=mode) def isdir(self, path: str) -> bool: """ @@ -164,12 +174,11 @@ def isdir(self, path: str) -> bool: :param path: full path to the remote directory to check """ - conn = self.get_conn() - try: - result = stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore - except OSError: - result = False - return result + with self.get_conn() as conn: + try: + return stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore + except OSError: + return False def isfile(self, path: str) -> bool: """ @@ -177,12 +186,11 @@ def isfile(self, path: str) -> bool: :param path: full path to the remote file to check """ - conn = self.get_conn() - try: - result = stat.S_ISREG(conn.stat(path).st_mode) # type: ignore - except OSError: - result = False - return result + with self.get_conn() as conn: + try: + return stat.S_ISREG(conn.stat(path).st_mode) # type: ignore + except OSError: + return False def create_directory(self, path: str, mode: int = 0o777) -> None: """ @@ -196,28 +204,35 @@ def create_directory(self, path: str, mode: int = 0o777) -> None: :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ - conn = self.get_conn() - if self.isdir(path): - self.log.info("%s already exists", path) - return - elif self.isfile(path): - raise AirflowException(f"{path} already exists and is a file") - else: - dirname, basename = os.path.split(path) - if dirname and not self.isdir(dirname): - self.create_directory(dirname, mode) - if basename: - self.log.info("Creating %s", path) - conn.mkdir(path, mode=mode) - - def delete_directory(self, path: str) -> None: + with self.get_conn() as conn: + if self.isdir(path): + self.log.info("%s already exists", path) + return + elif self.isfile(path): + raise AirflowException(f"{path} already exists and is a file") + else: + dirname, basename = os.path.split(path) + if dirname and not self.isdir(dirname): + self.create_directory(dirname, mode) + if basename: + self.log.info("Creating %s", path) + conn.mkdir(path, mode=mode) + + def delete_directory(self, path: str, include_files: bool = False) -> None: """ Delete a directory on the remote system. :param path: full path to the remote directory to delete """ - conn = self.get_conn() - conn.rmdir(path) + with self.get_conn() as conn: + if include_files is True: + files, dirs, _ = self.get_tree_map(path) + dirs = dirs[::-1] # reverse the order for deleting deepest directories first + for file_path in files: + conn.remove(file_path) + for dir_path in dirs: + conn.rmdir(dir_path) + conn.rmdir(path) def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: """ @@ -227,11 +242,14 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: b at that location. 
:param remote_full_path: full path to the remote file - :param local_full_path: full path to the local file + :param local_full_path: full path to the local file or a file-like buffer :param prefetch: controls whether prefetch is performed (default: True) """ - conn = self.get_conn() - conn.get(remote_full_path, local_full_path, prefetch=prefetch) + with self.get_conn() as conn: + if isinstance(local_full_path, BytesIO): + conn.getfo(remote_full_path, local_full_path, prefetch=prefetch) + else: + conn.get(remote_full_path, local_full_path, prefetch=prefetch) def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None: """ @@ -241,10 +259,13 @@ def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool from that location. :param remote_full_path: full path to the remote file - :param local_full_path: full path to the local file + :param local_full_path: full path to the local file or a file-like buffer """ - conn = self.get_conn() - conn.put(local_full_path, remote_full_path, confirm=confirm) + with self.get_conn() as conn: + if isinstance(local_full_path, BytesIO): + conn.putfo(local_full_path, remote_full_path, confirm=confirm) + else: + conn.put(local_full_path, remote_full_path, confirm=confirm) def delete_file(self, path: str) -> None: """ @@ -252,8 +273,8 @@ def delete_file(self, path: str) -> None: :param path: full path to the remote file """ - conn = self.get_conn() - conn.remove(path) + with self.get_conn() as conn: + conn.remove(path) def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: """ @@ -306,9 +327,9 @@ def get_mod_time(self, path: str) -> str: :param path: full path to the remote file """ - conn = self.get_conn() - ftp_mdtm = conn.stat(path).st_mtime - return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore + with self.get_conn() as conn: + ftp_mdtm = conn.stat(path).st_mtime + return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore def path_exists(self, path: str) -> bool: """ @@ -316,12 +337,12 @@ def path_exists(self, path: str) -> bool: :param path: full path to the remote file or directory """ - conn = self.get_conn() - try: - conn.stat(path) - except OSError: - return False - return True + with self.get_conn() as conn: + try: + conn.stat(path) + except OSError: + return False + return True @staticmethod def _is_path_match(path: str, prefix: str | None = None, delimiter: str | None = None) -> bool: @@ -415,9 +436,9 @@ def append_matching_path_callback(list_: list[str]) -> Callable: def test_connection(self) -> tuple[bool, str]: """Test the SFTP connection by calling path with directory.""" try: - conn = self.get_conn() - conn.normalize(".") - return True, "Connection successfully tested" + with self.get_conn() as conn: + conn.normalize(".") + return True, "Connection successfully tested" except Exception as e: return False, str(e) @@ -432,7 +453,6 @@ def get_file_by_pattern(self, path, fnmatch_pattern) -> str: for file in self.list_directory(path): if fnmatch(file, fnmatch_pattern): return file - return "" def get_files_by_pattern(self, path, fnmatch_pattern) -> list[str]: @@ -600,17 +620,13 @@ async def get_mod_time(self, path: str) -> str: # type: ignore[return] :param path: full path to the remote file """ - ssh_conn = None - try: - ssh_conn = await self._get_conn() - sftp_client = await ssh_conn.start_sftp_client() - ftp_mdtm = await sftp_client.stat(path) - modified_time = ftp_mdtm.mtime - 
mod_time = datetime.datetime.fromtimestamp(modified_time).strftime("%Y%m%d%H%M%S") # type: ignore[arg-type] - self.log.info("Found File %s last modified: %s", str(path), str(mod_time)) - return mod_time - except asyncssh.SFTPNoSuchFile: - raise AirflowException("No files matching") - finally: - if ssh_conn: - ssh_conn.close() + async with await self._get_conn() as ssh_conn: + try: + sftp_client = await ssh_conn.start_sftp_client() + ftp_mdtm = await sftp_client.stat(path) + modified_time = ftp_mdtm.mtime + mod_time = datetime.datetime.fromtimestamp(modified_time).strftime("%Y%m%d%H%M%S") # type: ignore[arg-type] + self.log.info("Found File %s last modified: %s", str(path), str(mod_time)) + return mod_time + except asyncssh.SFTPNoSuchFile: + raise AirflowException("No files matching") diff --git a/providers/sftp/src/airflow/providers/sftp/operators/sftp.py b/providers/sftp/src/airflow/providers/sftp/operators/sftp.py index 3984154661836..99365e0932d71 100644 --- a/providers/sftp/src/airflow/providers/sftp/operators/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/operators/sftp.py @@ -37,6 +37,7 @@ class SFTPOperation: PUT = "put" GET = "get" + DELETE = "delete" class SFTPOperator(BaseOperator): @@ -53,8 +54,8 @@ class SFTPOperator(BaseOperator): Nullable. If provided, it will replace the `remote_host` which was defined in `sftp_hook` or predefined in the connection of `ssh_conn_id`. :param local_filepath: local file path or list of local file paths to get or put. (templated) - :param remote_filepath: remote file path or list of remote file paths to get or put. (templated) - :param operation: specify operation 'get' or 'put', defaults to put + :param remote_filepath: remote file path or list of remote file paths to get, put, or delete. (templated) + :param operation: specify operation 'get', 'put', or 'delete', defaults to put :param confirm: specify if the SFTP operation should be confirmed, defaults to True :param create_intermediate_dirs: create missing intermediate directories when copying from remote to local and vice-versa. Default is False. 
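For reference, a minimal sketch of how the new "delete" operation introduced in this file could be used from a DAG; the task id, connection id and remote path are illustrative placeholders, not part of this change:

from airflow.providers.sftp.operators.sftp import SFTPOperator

# local_filepath is intentionally omitted: as the execute() hunks below enforce,
# it must not be provided for a delete, and remote directories are removed
# together with their contents via delete_directory(..., include_files=True).
delete_stale_export = SFTPOperator(
    task_id="delete_stale_export",
    ssh_conn_id="sftp_default",                   # assumed connection id
    remote_filepath="/tmp/exports/old_report.csv",  # placeholder path
    operation="delete",
)
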
@@ -84,7 +85,7 @@ def __init__( sftp_hook: SFTPHook | None = None, ssh_conn_id: str | None = None, remote_host: str | None = None, - local_filepath: str | list[str], + local_filepath: str | list[str] | None = None, remote_filepath: str | list[str], operation: str = SFTPOperation.PUT, confirm: bool = True, @@ -102,7 +103,9 @@ def __init__( self.remote_filepath = remote_filepath def execute(self, context: Any) -> str | list[str] | None: - if isinstance(self.local_filepath, str): + if self.local_filepath is None: + local_filepath_array = [] + elif isinstance(self.local_filepath, str): local_filepath_array = [self.local_filepath] else: local_filepath_array = self.local_filepath @@ -112,16 +115,21 @@ def execute(self, context: Any) -> str | list[str] | None: else: remote_filepath_array = self.remote_filepath - if len(local_filepath_array) != len(remote_filepath_array): + if self.operation.lower() in (SFTPOperation.GET, SFTPOperation.PUT) and len( + local_filepath_array + ) != len(remote_filepath_array): raise ValueError( f"{len(local_filepath_array)} paths in local_filepath " f"!= {len(remote_filepath_array)} paths in remote_filepath" ) - if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT): + if self.operation.lower() == SFTPOperation.DELETE and local_filepath_array: + raise ValueError("local_filepath should not be provided for delete operation") + + if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT, SFTPOperation.DELETE): raise TypeError( f"Unsupported operation value {self.operation}, " - f"expected {SFTPOperation.GET} or {SFTPOperation.PUT}." + f"expected {SFTPOperation.GET} or {SFTPOperation.PUT} or {SFTPOperation.DELETE}." ) file_msg = None @@ -144,32 +152,43 @@ def execute(self, context: Any) -> str | list[str] | None: ) self.sftp_hook.remote_host = self.remote_host - for _local_filepath, _remote_filepath in zip(local_filepath_array, remote_filepath_array): - if self.operation.lower() == SFTPOperation.GET: - local_folder = os.path.dirname(_local_filepath) - if self.create_intermediate_dirs: - Path(local_folder).mkdir(parents=True, exist_ok=True) - file_msg = f"from {_remote_filepath} to {_local_filepath}" - self.log.info("Starting to transfer %s", file_msg) + if self.operation.lower() in (SFTPOperation.GET, SFTPOperation.PUT): + for _local_filepath, _remote_filepath in zip(local_filepath_array, remote_filepath_array): + if self.operation.lower() == SFTPOperation.GET: + local_folder = os.path.dirname(_local_filepath) + if self.create_intermediate_dirs: + Path(local_folder).mkdir(parents=True, exist_ok=True) + file_msg = f"from {_remote_filepath} to {_local_filepath}" + self.log.info("Starting to transfer %s", file_msg) + if self.sftp_hook.isdir(_remote_filepath): + self.sftp_hook.retrieve_directory(_remote_filepath, _local_filepath) + else: + self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath) + elif self.operation.lower() == SFTPOperation.PUT: + remote_folder = os.path.dirname(_remote_filepath) + if self.create_intermediate_dirs: + self.sftp_hook.create_directory(remote_folder) + file_msg = f"from {_local_filepath} to {_remote_filepath}" + self.log.info("Starting to transfer file %s", file_msg) + if os.path.isdir(_local_filepath): + self.sftp_hook.store_directory( + _remote_filepath, _local_filepath, confirm=self.confirm + ) + else: + self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm) + elif self.operation.lower() == SFTPOperation.DELETE: + for _remote_filepath in remote_filepath_array: + file_msg = 
f"{_remote_filepath}" + self.log.info("Starting to delete %s", file_msg) if self.sftp_hook.isdir(_remote_filepath): - self.sftp_hook.retrieve_directory(_remote_filepath, _local_filepath) + self.sftp_hook.delete_directory(_remote_filepath, include_files=True) else: - self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath) - else: - remote_folder = os.path.dirname(_remote_filepath) - if self.create_intermediate_dirs: - self.sftp_hook.create_directory(remote_folder) - file_msg = f"from {_local_filepath} to {_remote_filepath}" - self.log.info("Starting to transfer file %s", file_msg) - if os.path.isdir(_local_filepath): - self.sftp_hook.store_directory( - _remote_filepath, _local_filepath, confirm=self.confirm - ) - else: - self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm) + self.sftp_hook.delete_file(_remote_filepath) except Exception as e: - raise AirflowException(f"Error while transferring {file_msg}, error: {e}") + raise AirflowException( + f"Error while processing {self.operation.upper()} operation {file_msg}, error: {e}" + ) return self.local_filepath diff --git a/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py b/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py index 54b27e450be9a..a35ee4010712d 100644 --- a/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py @@ -21,14 +21,15 @@ import json import os import shutil -from io import StringIO -from unittest import mock -from unittest.mock import AsyncMock, patch +from io import BytesIO, StringIO +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch import paramiko import pytest from asyncssh import SFTPAttrs, SFTPNoSuchFile from asyncssh.sftp import SFTPName +from paramiko.client import SSHClient +from paramiko.sftp_client import SFTPClient from airflow.exceptions import AirflowException from airflow.models import Connection @@ -87,7 +88,10 @@ def setup_test_cases(self, tmp_path_factory): file.write("Test file") with open(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), "a") as file: file.write("Test file") - os.mkfifo(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS)) + try: + os.mkfifo(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS)) + except AttributeError: + os.makedirs(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS)) self.temp_dir = str(temp_dir) @@ -99,14 +103,20 @@ def setup_test_cases(self, tmp_path_factory): self.update_connection(self.old_login) def test_get_conn(self): - output = self.hook.get_conn() - assert isinstance(output, paramiko.SFTPClient) + with self.hook.get_conn() as conn: + assert isinstance(conn, paramiko.SFTPClient) - def test_close_conn(self): - self.hook.conn = self.hook.get_conn() - assert self.hook.conn is not None - self.hook.close_conn() - assert self.hook.conn is None + @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") + def test_get_close_conn(self, mock_get_conn): + mock_sftp_client = MagicMock(spec=SFTPClient) + mock_ssh_client = MagicMock(spec=SSHClient) + mock_ssh_client.open_sftp.return_value = mock_sftp_client + mock_get_conn.return_value = mock_ssh_client + + with SFTPHook().get_conn() as conn: + assert conn == mock_sftp_client + + mock_sftp_client.close.assert_called_once() def test_describe_directory(self): output = self.hook.describe_directory(self.temp_dir) @@ -129,8 +139,9 @@ def test_mkdir(self): assert new_dir_name in output # test the directory has default permissions to 777 - umask 
umask = 0o022 - output = self.hook.get_conn().lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) - assert output.st_mode & 0o777 == 0o777 - umask + with self.hook.get_conn() as conn: + output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) + assert output.st_mode & 0o777 == 0o777 - umask def test_create_and_delete_directory(self): new_dir_name = "new_dir" @@ -139,8 +150,9 @@ def test_create_and_delete_directory(self): assert new_dir_name in output # test the directory has default permissions to 777 umask = 0o022 - output = self.hook.get_conn().lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) - assert output.st_mode & 0o777 == 0o777 - umask + with self.hook.get_conn() as conn: + output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) + assert output.st_mode & 0o777 == 0o777 - umask # test directory already exists for code coverage, should not raise an exception self.hook.create_directory(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) # test path already exists and is a file, should raise an exception @@ -167,6 +179,28 @@ def test_create_and_delete_directories(self): assert new_dir_path not in output assert base_dir not in output + def test_create_and_delete_directory_with_files(self): + new_dir = "new_dir" + sub_dir = "sub_dir" + additional_file = "additional_file.txt" + self.hook.create_directory(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir)) + output = self.hook.describe_directory(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) + assert new_dir in output + self.hook.create_directory(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir, sub_dir)) + self._create_additional_test_file(file_name=additional_file) + self.hook.store_file( + remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir, additional_file), + local_full_path=os.path.join(self.temp_dir, additional_file), + ) + output = self.hook.describe_directory(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir)) + assert sub_dir in output + assert additional_file in output + self.hook.delete_directory( + os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir), include_files=True + ) + output = self.hook.describe_directory(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) + assert new_dir not in output + def test_store_retrieve_and_delete_file(self): self.hook.store_file( remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), @@ -185,6 +219,24 @@ def test_store_retrieve_and_delete_file(self): output = self.hook.list_directory(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) assert output == [SUB_DIR, FIFO_FOR_TESTS] + def test_store_retrieve_and_delete_file_using_buffer(self): + file_contents = BytesIO(b"Test file") + self.hook.store_file( + remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=file_contents, + ) + output = self.hook.list_directory(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) + assert output == [SUB_DIR, FIFO_FOR_TESTS, TMP_FILE_FOR_TESTS] + retrieved_file_contents = BytesIO() + self.hook.retrieve_file( + remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=retrieved_file_contents, + ) + assert retrieved_file_contents.getvalue() == file_contents.getvalue() + self.hook.delete_file(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) + output = self.hook.list_directory(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) + assert 
output == [SUB_DIR, FIFO_FOR_TESTS] + def test_get_mod_time(self): self.hook.store_file( remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), @@ -195,14 +247,14 @@ def test_get_mod_time(self): ) assert len(output) == 14 - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_default(self, get_connection): connection = Connection(login="login", host="host") get_connection.return_value = connection hook = SFTPHook() assert hook.no_host_key_check is True - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_enabled(self, get_connection): connection = Connection(login="login", host="host", extra='{"no_host_key_check": true}') @@ -210,7 +262,7 @@ def test_no_host_key_check_enabled(self, get_connection): hook = SFTPHook() assert hook.no_host_key_check is True - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_disabled(self, get_connection): connection = Connection(login="login", host="host", extra='{"no_host_key_check": false}') @@ -218,7 +270,7 @@ def test_no_host_key_check_disabled(self, get_connection): hook = SFTPHook() assert hook.no_host_key_check is False - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_ciphers(self, get_connection): connection = Connection(login="login", host="host", extra='{"ciphers": ["A", "B", "C"]}') @@ -226,7 +278,7 @@ def test_ciphers(self, get_connection): hook = SFTPHook() assert hook.ciphers == ["A", "B", "C"] - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): connection = Connection(login="login", host="host", extra='{"no_host_key_check": "foo"}') @@ -234,7 +286,7 @@ def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): hook = SFTPHook() assert hook.no_host_key_check is False - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_ignore(self, get_connection): connection = Connection(login="login", host="host", extra='{"ignore_hostkey_verification": true}') @@ -242,14 +294,14 @@ def test_no_host_key_check_ignore(self, get_connection): hook = SFTPHook() assert hook.no_host_key_check is True - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_host_key_default(self, get_connection): connection = Connection(login="login", host="host") get_connection.return_value = connection hook = SFTPHook() assert hook.host_key is None - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_host_key(self, get_connection): connection = Connection( login="login", @@ -260,7 +312,7 @@ def test_host_key(self, get_connection): hook = SFTPHook() assert hook.host_key.get_base64() == TEST_HOST_KEY - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + 
@patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_host_key_with_type(self, get_connection): connection = Connection( login="login", @@ -271,14 +323,14 @@ def test_host_key_with_type(self, get_connection): hook = SFTPHook() assert hook.host_key.get_base64() == TEST_HOST_KEY - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_host_key_with_no_host_key_check(self, get_connection): connection = Connection(login="login", host="host", extra=json.dumps({"host_key": TEST_HOST_KEY})) get_connection.return_value = connection hook = SFTPHook() assert hook.host_key is not None - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_key_content_as_str(self, get_connection): file_obj = StringIO() TEST_PKEY.write_private_key(file_obj) @@ -299,7 +351,7 @@ def test_key_content_as_str(self, get_connection): assert hook.pkey == TEST_PKEY assert hook.key_file is None - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_key_file(self, get_connection): connection = Connection( login="login", @@ -356,37 +408,50 @@ def test_get_tree_map(self): assert dirs == [os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, SUB_DIR)] assert unknowns == [os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS)] - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") - def test_connection_failure(self, mock_get_connection): - connection = Connection( - login="login", - host="host", + @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") + def test_connection_failure(self, mock_get_conn): + mock_ssh_client = MagicMock(spec=SSHClient) + type(mock_ssh_client.open_sftp.return_value).normalize = PropertyMock( + side_effect=Exception("Connection Error") ) - mock_get_connection.return_value = connection - with mock.patch.object(SFTPHook, "get_conn") as get_conn: - type(get_conn.return_value).normalize = mock.PropertyMock( - side_effect=Exception("Connection Error") - ) + mock_get_conn.return_value = mock_ssh_client + + hook = SFTPHook() + status, msg = hook.test_connection() - hook = SFTPHook() - status, msg = hook.test_connection() assert status is False assert msg == "Connection Error" - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") - def test_connection_success(self, mock_get_connection): - connection = Connection( - login="login", - host="host", + @pytest.mark.parametrize( + "test_connection_side_effect", + [ + (lambda arg: (True, "Connection successfully tested")), + (lambda arg: RuntimeError("Test connection failed")), + ], + ) + @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") + def test_context_manager(self, mock_get_conn, test_connection_side_effect): + mock_sftp_client = MagicMock(spec=SFTPClient) + mock_ssh_client = MagicMock(spec=SSHClient) + mock_ssh_client.open_sftp.return_value = mock_sftp_client + mock_get_conn.return_value = mock_ssh_client + + type(mock_sftp_client.normalize.return_value).normalize = PropertyMock( + side_effect=test_connection_side_effect ) - mock_get_connection.return_value = connection - with mock.patch.object(SFTPHook, "get_conn") as get_conn: - get_conn.return_value.pwd = "/home/someuser" - hook = SFTPHook() + hook = SFTPHook() + if isinstance(test_connection_side_effect, RuntimeError): + with 
pytest.raises(RuntimeError, match="Test connection failed"): + hook.test_connection() + else: status, msg = hook.test_connection() - assert status is True - assert msg == "Connection successfully tested" + + assert status is True + assert msg == "Connection successfully tested" + + mock_ssh_client.open_sftp.assert_called_once() + mock_sftp_client.close.assert_called() def test_get_suffix_pattern_match(self): output = self.hook.get_file_by_pattern(self.temp_dir, "*.txt") @@ -447,6 +512,38 @@ def test_store_and_retrieve_directory(self): ) assert retrieved_dir_name in os.listdir(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) + @patch("paramiko.SSHClient") + @patch("paramiko.ProxyCommand") + def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client): + mock_sftp_client = MagicMock(spec=SFTPClient) + mock_ssh_client.open_sftp.return_value = mock_sftp_client + + mock_transport = MagicMock() + mock_ssh_client.return_value.get_transport.return_value = mock_transport + mock_proxy_command.return_value = MagicMock() + + host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" + + hook = SFTPHook( + remote_host="example.com", + username="user", + host_proxy_cmd=host_proxy_cmd, + ) + + with hook.get_conn(): + mock_proxy_command.assert_called_once_with(host_proxy_cmd) + mock_ssh_client.return_value.connect.assert_called_once_with( + hostname="example.com", + username="user", + timeout=None, + compress=True, + port=22, + sock=mock_proxy_command.return_value, + look_for_keys=True, + banner_timeout=30.0, + auth_timeout=None, + ) + class MockSFTPClient: def __init__(self): @@ -755,8 +852,9 @@ async def test_get_mod_time(self, mock_hook_get_conn): """ Assert that file attribute and return the modified time of the file """ - mock_hook_get_conn.return_value.start_sftp_client.return_value = MockSFTPClient() + mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() hook = SFTPHookAsync() + mod_time = await hook.get_mod_time("/path/exists/file") expected_value = datetime.datetime.fromtimestamp(1667302566).strftime("%Y%m%d%H%M%S") assert mod_time == expected_value @@ -767,36 +865,9 @@ async def test_get_mod_time_exception(self, mock_hook_get_conn): """ Assert that get_mod_time raise exception when file does not exist """ - mock_hook_get_conn.return_value.start_sftp_client.return_value = MockSFTPClient() + mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() hook = SFTPHookAsync() + with pytest.raises(AirflowException) as exc: await hook.get_mod_time("/path/does_not/exist/") assert str(exc.value) == "No files matching" - - @patch("paramiko.SSHClient") - @mock.patch("paramiko.ProxyCommand") - def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client): - mock_transport = mock.MagicMock() - mock_ssh_client.return_value.get_transport.return_value = mock_transport - mock_proxy_command.return_value = mock.MagicMock() - - host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" - hook = SFTPHook( - remote_host="example.com", - username="user", - host_proxy_cmd=host_proxy_cmd, - ) - hook.get_conn() - - mock_proxy_command.assert_called_once_with(host_proxy_cmd) - mock_ssh_client.return_value.connect.assert_called_once_with( - hostname="example.com", - username="user", - timeout=None, - compress=True, - port=22, - sock=mock_proxy_command.return_value, - look_for_keys=True, - banner_timeout=30.0, - auth_timeout=None, - ) diff --git 
a/providers/sftp/tests/provider_tests/sftp/operators/test_sftp.py b/providers/sftp/tests/provider_tests/sftp/operators/test_sftp.py index c0d2c03db8e31..a376fefda7d44 100644 --- a/providers/sftp/tests/provider_tests/sftp/operators/test_sftp.py +++ b/providers/sftp/tests/provider_tests/sftp/operators/test_sftp.py @@ -492,6 +492,60 @@ def test_return_str_when_local_filepath_was_str(self, mock_get): assert isinstance(return_value, str) assert return_value == local_filepath + @mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.delete_file") + def test_str_filepaths_delete(self, mock_delete): + remote_filepath = "/tmp/test" + SFTPOperator( + task_id="test_str_filepaths_delete", + sftp_hook=self.sftp_hook, + remote_filepath=remote_filepath, + operation=SFTPOperation.DELETE, + ).execute(None) + assert mock_delete.call_count == 1 + args, _ = mock_delete.call_args_list[0] + assert args == (remote_filepath,) + + @mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.delete_file") + def test_multiple_filepaths_delete(self, mock_delete): + remote_filepath = ["/tmp/rtest1", "/tmp/rtest2"] + SFTPOperator( + task_id="test_multiple_filepaths_delete", + sftp_hook=self.sftp_hook, + remote_filepath=remote_filepath, + operation=SFTPOperation.DELETE, + ).execute(None) + assert mock_delete.call_count == 2 + args0, _ = mock_delete.call_args_list[0] + args1, _ = mock_delete.call_args_list[1] + assert args0 == (remote_filepath[0],) + assert args1 == (remote_filepath[1],) + + @mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.delete_directory") + def test_str_dirpaths_delete(self, mock_delete): + remote_filepath = "/tmp" + SFTPOperator( + task_id="test_str_dirpaths_delete", + sftp_hook=self.sftp_hook, + remote_filepath=remote_filepath, + operation=SFTPOperation.DELETE, + ).execute(None) + assert mock_delete.call_count == 1 + args, _ = mock_delete.call_args_list[0] + assert args == (remote_filepath,) + + @mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.delete_file") + def test_local_filepath_exists_error_delete(self, mock_delete): + local_filepath = "/tmp" + remote_filepath = "/tmp_remote" + with pytest.raises(ValueError, match="local_filepath should not be provided for delete operation"): + SFTPOperator( + task_id="test_local_filepath_exists_error_delete", + sftp_hook=self.sftp_hook, + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=SFTPOperation.DELETE, + ).execute(None) + @pytest.mark.parametrize( "operation, expected", TEST_GET_PUT_PARAMS, diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst index 741d73a62e3d4..2d7076d120f55 100644 --- a/providers/snowflake/docs/connections/snowflake.rst +++ b/providers/snowflake/docs/connections/snowflake.rst @@ -64,6 +64,7 @@ Extra (optional) * ``insecure_mode``: Turn off OCSP certificate checks. For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community `_. * ``host``: Target Snowflake hostname to connect to (e.g., for local testing with LocalStack). * ``port``: Target Snowflake port to connect to (e.g., for local testing with LocalStack). + * ``ocsp_fail_open``: Specify `ocsp_fail_open `_. 
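For readers following the new ``ocsp_fail_open`` extra documented above, here is a minimal sketch of how it could be set on a Snowflake connection and observed on the hook. The connection id, credentials, and the call to the private ``_get_conn_params()`` helper are illustrative assumptions, not part of the patch; per the hook change, the value is forwarded to the Snowflake connector (and to ``connect_args`` when building the SQLAlchemy engine) only when the extra is explicitly present.

```python
import json
import os

from airflow.models.connection import Connection
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

# Hypothetical connection that disables OCSP fail-open via the "extra" field.
conn = Connection(
    conn_id="snowflake_ocsp_demo",
    conn_type="snowflake",
    login="user",
    password="pw",
    extra=json.dumps({"account": "my_account", "warehouse": "my_wh", "ocsp_fail_open": False}),
)
os.environ["AIRFLOW_CONN_SNOWFLAKE_OCSP_DEMO"] = conn.get_uri()

hook = SnowflakeHook(snowflake_conn_id="snowflake_ocsp_demo")
# Private helper, used here only to show the pass-through; an explicit False is preserved
# because the hook checks "is not None" rather than truthiness.
print(hook._get_conn_params().get("ocsp_fail_open"))  # expected: False
```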
URI format example ^^^^^^^^^^^^^^^^^^ diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 45e12666b88d0..5777968b8d87c 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -299,6 +299,12 @@ def _get_conn_params(self) -> dict[str, str | None]: if snowflake_port: conn_config["port"] = snowflake_port + # if a value for ocsp_fail_open is set, pass it along. + # Note the check is for `is not None` so that we can pass along `False` as a value. + ocsp_fail_open = extra_dict.get("ocsp_fail_open") + if ocsp_fail_open is not None: + conn_config["ocsp_fail_open"] = _try_to_boolean(ocsp_fail_open) + return conn_config def get_uri(self) -> str: @@ -320,6 +326,7 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str: "client_request_mfa_token", "client_store_temporary_credential", "json_result_force_utf8_decoding", + "ocsp_fail_open", ] } ) @@ -345,6 +352,9 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): if "json_result_force_utf8_decoding" in conn_params: engine_kwargs.setdefault("connect_args", {}) engine_kwargs["connect_args"]["json_result_force_utf8_decoding"] = True + if "ocsp_fail_open" in conn_params: + engine_kwargs.setdefault("connect_args", {}) + engine_kwargs["connect_args"]["ocsp_fail_open"] = conn_params["ocsp_fail_open"] for key in ["session_parameters", "private_key"]: if conn_params.get(key): engine_kwargs.setdefault("connect_args", {}) diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py index e948d338bd5dc..d2b7479dff38b 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py @@ -358,6 +358,11 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator): LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes lifetime RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54 minutes + template_fields: Sequence[str] = tuple( + set(SQLExecuteQueryOperator.template_fields) | {"snowflake_conn_id"} + ) + conn_id_field = "snowflake_conn_id" + def __init__( self, *, diff --git a/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py index 775e93827297b..b1a65b4293b66 100644 --- a/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py @@ -277,6 +277,60 @@ class TestPytestSnowflakeHook: "json_result_force_utf8_decoding": True, }, ), + ( + { + **BASE_CONNECTION_KWARGS, + "extra": { + **BASE_CONNECTION_KWARGS["extra"], + "ocsp_fail_open": True, + }, + }, + ( + "snowflake://user:pw@airflow.af_region/db/public?" 
+ "application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh" + ), + { + "account": "airflow", + "application": "AIRFLOW", + "authenticator": "snowflake", + "database": "db", + "password": "pw", + "region": "af_region", + "role": "af_role", + "schema": "public", + "session_parameters": None, + "user": "user", + "warehouse": "af_wh", + "ocsp_fail_open": True, + }, + ), + ( + { + **BASE_CONNECTION_KWARGS, + "extra": { + **BASE_CONNECTION_KWARGS["extra"], + "ocsp_fail_open": False, + }, + }, + ( + "snowflake://user:pw@airflow.af_region/db/public?" + "application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh" + ), + { + "account": "airflow", + "application": "AIRFLOW", + "authenticator": "snowflake", + "database": "db", + "password": "pw", + "region": "af_region", + "role": "af_role", + "schema": "public", + "session_parameters": None, + "user": "user", + "warehouse": "af_wh", + "ocsp_fail_open": False, + }, + ), ], ) def test_hook_should_support_prepare_basic_conn_params_and_uri( @@ -530,6 +584,23 @@ def test_get_sqlalchemy_engine_should_support_private_key_auth(self, non_encrypt assert "private_key" in mock_create_engine.call_args.kwargs["connect_args"] assert mock_create_engine.return_value == conn + def test_get_sqlalchemy_engine_should_support_ocsp_fail_open(self): + connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) + connection_kwargs["extra"]["ocsp_fail_open"] = "False" + + with ( + mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()), + mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine, + ): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + conn = hook.get_sqlalchemy_engine() + mock_create_engine.assert_called_once_with( + "snowflake://user:pw@airflow.af_region/db/public" + "?application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh", + connect_args={"ocsp_fail_open": False}, + ) + assert mock_create_engine.return_value == conn + def test_hook_parameters_should_take_precedence(self): with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri() diff --git a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index b93d6a1ad3789..6561e8121c465 100644 --- a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -20,9 +20,10 @@ from collections import defaultdict from collections.abc import Container, Sequence from functools import cached_property -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast -from flask import session, url_for +from fastapi import FastAPI +from flask import session from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.models.resource_details import ( @@ -34,6 +35,7 @@ VariableDetails, ) from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand +from airflow.configuration import conf from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities from airflow.providers.amazon.aws.auth_manager.avp.facade import ( @@ -43,11 +45,7 @@ from airflow.providers.amazon.aws.auth_manager.cli.definition import ( AWS_AUTH_MANAGER_COMMANDS, ) -from 
airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import ( - AwsSecurityManagerOverride, -) from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser -from airflow.providers.amazon.aws.auth_manager.views.auth import AwsAuthManagerAuthenticationViews from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: @@ -61,7 +59,6 @@ IsAuthorizedVariableRequest, ) from airflow.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails - from airflow.www.extensions.init_appbuilder import AirflowAppBuilder class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]): @@ -72,8 +69,6 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]): authentication and authorization in Airflow. """ - appbuilder: AirflowAppBuilder | None = None - def __init__(self) -> None: if not AIRFLOW_V_3_0_PLUS: raise AirflowOptionalProviderFeatureException( @@ -87,12 +82,27 @@ def __init__(self) -> None: def avp_facade(self): return AwsAuthManagerAmazonVerifiedPermissionsFacade() + @cached_property + def fastapi_endpoint(self) -> str: + return conf.get("fastapi", "base_url") + def get_user(self) -> AwsAuthManagerUser | None: return session["aws_user"] if self.is_logged_in() else None def is_logged_in(self) -> bool: return "aws_user" in session + def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser: + return AwsAuthManagerUser(**token) + + def serialize_user(self, user: AwsAuthManagerUser) -> dict[str, Any]: + return { + "user_id": user.get_id(), + "groups": user.get_groups(), + "username": user.username, + "email": user.email, + } + def is_authorized_configuration( self, *, @@ -367,14 +377,10 @@ def _has_access_to_menu_item(request: IsAuthorizedRequest): return accessible_items def get_url_login(self, **kwargs) -> str: - return url_for("AwsAuthManagerAuthenticationViews.login") + return f"{self.fastapi_endpoint}/auth/login" def get_url_logout(self) -> str: - return url_for("AwsAuthManagerAuthenticationViews.logout") - - @cached_property - def security_manager(self) -> AwsSecurityManagerOverride: - return AwsSecurityManagerOverride(self.appbuilder) + raise NotImplementedError() @staticmethod def get_cli_commands() -> list[CLICommand]: @@ -387,9 +393,20 @@ def get_cli_commands() -> list[CLICommand]: ), ] - def register_views(self) -> None: - if self.appbuilder: - self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews()) + def get_fastapi_app(self) -> FastAPI | None: + from airflow.providers.amazon.aws.auth_manager.router.login import login_router + + app = FastAPI( + title="AWS auth manager sub application", + description=( + "This is the AWS auth manager fastapi sub application. This API is only available if the " + "auth manager used in the Airflow environment is AWS auth manager. " + "This sub application provides login routes." 
+ ), + ) + app.include_router(login_router) + + return app @staticmethod def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest: diff --git a/providers/tests/cncf/kubernetes/operators/__init__.py b/providers/src/airflow/providers/amazon/aws/auth_manager/router/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/operators/__init__.py rename to providers/src/airflow/providers/amazon/aws/auth_manager/router/__init__.py diff --git a/providers/src/airflow/providers/amazon/aws/auth_manager/router/login.py b/providers/src/airflow/providers/amazon/aws/auth_manager/router/login.py new file mode 100644 index 0000000000000..aca770dca29e6 --- /dev/null +++ b/providers/src/airflow/providers/amazon/aws/auth_manager/router/login.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +from typing import Any + +import anyio +from fastapi import HTTPException, Request +from starlette import status +from starlette.responses import RedirectResponse + +from airflow.api_fastapi.app import get_auth_manager +from airflow.api_fastapi.common.router import AirflowRouter +from airflow.configuration import conf +from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME +from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser + +try: + from onelogin.saml2.auth import OneLogin_Saml2_Auth + from onelogin.saml2.errors import OneLogin_Saml2_Error + from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser +except ImportError: + raise ImportError( + "AWS auth manager requires the python3-saml library but it is not installed by default. 
" + "Please install the python3-saml library by running: " + "pip install apache-airflow-providers-amazon[python3-saml]" + ) + +log = logging.getLogger(__name__) +login_router = AirflowRouter(tags=["AWSAuthManagerLogin"]) + + +@login_router.get("/login") +def login(request: Request): + """Authenticate the user.""" + saml_auth = _init_saml_auth(request) + callback_url = saml_auth.login() + return RedirectResponse(url=callback_url) + + +@login_router.post("/login_callback") +def login_callback(request: Request): + """Authenticate the user.""" + saml_auth = _init_saml_auth(request) + try: + saml_auth.process_response() + except OneLogin_Saml2_Error as e: + log.exception(e) + raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to authenticate") + errors = saml_auth.get_errors() + is_authenticated = saml_auth.is_authenticated() + if not is_authenticated: + error_reason = saml_auth.get_last_error_reason() + log.error("Failed to authenticate") + log.error("Errors: %s", errors) + log.error("Error reason: %s", error_reason) + raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to authenticate: {error_reason}") + + attributes = saml_auth.get_attributes() + user = AwsAuthManagerUser( + user_id=attributes["id"][0], + groups=attributes["groups"], + username=saml_auth.get_nameid(), + email=attributes["email"][0] if "email" in attributes else None, + ) + return RedirectResponse(url=f"/webapp?token={get_auth_manager().get_jwt_token(user)}", status_code=303) + + +def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth: + request_data = _prepare_request(request) + base_url = conf.get(section="fastapi", key="base_url") + settings = { + # We want to keep this flag on in case of errors. + # It provides an error reasons, if turned off, it does not + "debug": True, + "sp": { + "entityId": "aws-auth-manager-saml-client", + "assertionConsumerService": { + "url": f"{base_url}/auth/login_callback", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", + }, + }, + } + merged_settings = OneLogin_Saml2_IdPMetadataParser.merge_settings(_get_idp_data(), settings) + return OneLogin_Saml2_Auth(request_data, merged_settings) + + +def _prepare_request(request: Request) -> dict: + host = request.headers.get("host", request.client.host if request.client else "localhost") + data: dict[str, Any] = { + "https": "on" if request.url.scheme == "https" else "off", + "http_host": host, + "server_port": request.url.port, + "script_name": request.url.path, + "get_data": request.query_params, + "post_data": {}, + } + form_data = anyio.from_thread.run(request.form) + if "SAMLResponse" in form_data: + data["post_data"]["SAMLResponse"] = form_data["SAMLResponse"] + if "RelayState" in form_data: + data["post_data"]["RelayState"] = form_data["RelayState"] + return data + + +def _get_idp_data() -> dict: + saml_metadata_url = conf.get_mandatory_value(CONF_SECTION_NAME, CONF_SAML_METADATA_URL_KEY) + return OneLogin_Saml2_IdPMetadataParser.parse_remote(saml_metadata_url) diff --git a/providers/src/airflow/providers/amazon/aws/auth_manager/views/auth.py b/providers/src/airflow/providers/amazon/aws/auth_manager/views/auth.py deleted file mode 100644 index e08c2a7a6e100..0000000000000 --- a/providers/src/airflow/providers/amazon/aws/auth_manager/views/auth.py +++ /dev/null @@ -1,151 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import logging -from functools import cached_property - -from flask import make_response, redirect, request, session, url_for -from flask_appbuilder import expose - -from airflow.configuration import conf -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME -from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser -from airflow.www.app import csrf -from airflow.www.views import AirflowBaseView - -try: - from onelogin.saml2.auth import OneLogin_Saml2_Auth - from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser -except ImportError: - raise ImportError( - "AWS auth manager requires the python3-saml library but it is not installed by default. " - "Please install the python3-saml library by running: " - "pip install apache-airflow-providers-amazon[python3-saml]" - ) - -logger = logging.getLogger(__name__) - - -class AwsAuthManagerAuthenticationViews(AirflowBaseView): - """ - Views specific to AWS auth manager authentication mechanism. - - Some code below is inspired from - https://github.com/SAML-Toolkits/python3-saml/blob/6988bdab7a203abfe8dc264992f7e350c67aef3d/demo-flask/index.py - """ - - @cached_property - def idp_data(self) -> dict: - saml_metadata_url = conf.get_mandatory_value(CONF_SECTION_NAME, CONF_SAML_METADATA_URL_KEY) - return OneLogin_Saml2_IdPMetadataParser.parse_remote(saml_metadata_url) - - @expose("/login") - def login(self): - """Start login process.""" - saml_auth = self._init_saml_auth() - return redirect(saml_auth.login()) - - @expose("/logout", methods=("GET", "POST")) - def logout(self): - """Start logout process.""" - session.clear() - saml_auth = self._init_saml_auth() - - return redirect(saml_auth.logout()) - - @csrf.exempt - @expose("/login_callback", methods=("GET", "POST")) - def login_callback(self): - """ - Redirect the user to this callback after successful login. - - CSRF protection needs to be disabled otherwise the callback won't work. 
- """ - saml_auth = self._init_saml_auth() - saml_auth.process_response() - errors = saml_auth.get_errors() - is_authenticated = saml_auth.is_authenticated() - if not is_authenticated: - error_reason = saml_auth.get_last_error_reason() - logger.error("Failed to authenticate") - logger.error("Errors: %s", errors) - logger.error("Error reason: %s", error_reason) - raise AirflowException(f"Failed to authenticate: {error_reason}") - - attributes = saml_auth.get_attributes() - user = AwsAuthManagerUser( - user_id=attributes["id"][0], - groups=attributes["groups"], - username=saml_auth.get_nameid(), - email=attributes["email"][0] if "email" in attributes else None, - ) - session["aws_user"] = user - - return redirect(url_for("Airflow.index")) - - @csrf.exempt - @expose("/logout_callback", methods=("GET", "POST")) - def logout_callback(self): - raise NotImplementedError("AWS Identity center does not support SLO (Single Logout Service)") - - @expose("/login_metadata") - def login_metadata(self): - saml_auth = self._init_saml_auth() - settings = saml_auth.get_settings() - metadata = settings.get_sp_metadata() - errors = settings.validate_metadata(metadata) - - if len(errors) == 0: - resp = make_response(metadata, 200) - resp.headers["Content-Type"] = "text/xml" - else: - resp = make_response(", ".join(errors), 500) - return resp - - @staticmethod - def _prepare_flask_request() -> dict: - return { - "https": "on" if request.scheme == "https" else "off", - "http_host": request.host, - "script_name": request.path, - "get_data": request.args.copy(), - "post_data": request.form.copy(), - } - - def _init_saml_auth(self) -> OneLogin_Saml2_Auth: - request_data = self._prepare_flask_request() - base_url = conf.get(section="webserver", key="base_url") - settings = { - # We want to keep this flag on in case of errors. 
- # It provides an error reasons, if turned off, it does not - "debug": True, - "sp": { - "entityId": f"{base_url}/login_metadata", - "assertionConsumerService": { - "url": f"{base_url}/login_callback", - "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", - }, - "singleLogoutService": { - "url": f"{base_url}/logout_callback", - "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", - }, - }, - } - merged_settings = OneLogin_Saml2_IdPMetadataParser.merge_settings(settings, self.idp_data) - return OneLogin_Saml2_Auth(request_data, merged_settings) diff --git a/providers/src/airflow/providers/amazon/aws/hooks/base_aws.py b/providers/src/airflow/providers/amazon/aws/hooks/base_aws.py index 2c919a4ff5183..04c21380155f8 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/base_aws.py @@ -58,7 +58,6 @@ from airflow.providers_manager import ProvidersManager from airflow.utils.helpers import exactly_one from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.log.secrets_masker import mask_secret BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource]) @@ -68,6 +67,12 @@ from botocore.credentials import ReadOnlyCredentials from airflow.models.connection import Connection + from airflow.sdk.execution_time.secrets_masker import mask_secret +else: + try: + from airflow.sdk.execution_time.secrets_masker import mask_secret + except ImportError: + from airflow.utils.log.secrets_masker import mask_secret _loader = botocore.loaders.Loader() """ diff --git a/providers/src/airflow/providers/amazon/aws/hooks/ecr.py b/providers/src/airflow/providers/amazon/aws/hooks/ecr.py index 62da7a7636c37..2feb04402ecfb 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/ecr.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/ecr.py @@ -23,11 +23,17 @@ from typing import TYPE_CHECKING from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.utils.log.secrets_masker import mask_secret if TYPE_CHECKING: from datetime import datetime + from airflow.sdk.execution_time.secrets_masker import mask_secret +else: + try: + from airflow.sdk.execution_time.secrets_masker import mask_secret + except ImportError: + from airflow.utils.log.secrets_masker import mask_secret + logger = logging.getLogger(__name__) diff --git a/providers/src/airflow/providers/amazon/aws/hooks/emr.py b/providers/src/airflow/providers/amazon/aws/hooks/emr.py index 17ecaae2fd7e1..bd0b7d2039df3 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/emr.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/emr.py @@ -22,7 +22,9 @@ import warnings from typing import Any +import tenacity from botocore.exceptions import ClientError +from tenacity import retry_if_exception, stop_after_attempt, wait_fixed from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -311,6 +313,15 @@ def cancel_running_jobs( return count +def is_connection_being_updated_exception(exception: BaseException) -> bool: + return ( + isinstance(exception, ClientError) + and exception.response["Error"]["Code"] == "ValidationException" + and "is not reachable as its connection is currently being updated" + in exception.response["Error"]["Message"] + ) + + class EmrContainerHook(AwsBaseHook): """ Interact with Amazon EMR Containers (Amazon EMR on EKS). 
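As a side note on the new ``is_connection_being_updated_exception`` helper: it is an ordinary predicate, so it can drive a ``tenacity`` retry policy around any EMR-on-EKS call that intermittently fails while the underlying EKS connection is being refreshed. In the sketch below, the ``ClientError`` is hand-built purely to show what the predicate matches, and ``create_virtual_cluster_with_retry`` is a hypothetical wrapper; the patch itself applies the same policy directly to ``EmrContainerHook.create_emr_on_eks_cluster`` in the next hunk.

```python
from botocore.exceptions import ClientError
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_fixed

from airflow.providers.amazon.aws.hooks.emr import is_connection_being_updated_exception

# Hand-built ClientError mimicking the transient EMR-on-EKS response this predicate targets.
transient_error = ClientError(
    error_response={
        "Error": {
            "Code": "ValidationException",
            "Message": "Cluster demo-eks is not reachable as its connection is currently being updated",
        }
    },
    operation_name="CreateVirtualCluster",
)
assert is_connection_being_updated_exception(transient_error)
assert not is_connection_being_updated_exception(ValueError("unrelated error"))


@retry(
    retry=retry_if_exception(is_connection_being_updated_exception),
    stop=stop_after_attempt(5),
    wait=wait_fixed(10),
)
def create_virtual_cluster_with_retry(client, **kwargs):
    # Hypothetical wrapper around the boto3 emr-containers call; non-matching
    # exceptions propagate immediately, matching ones are retried up to 5 times.
    return client.create_virtual_cluster(**kwargs)
```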
@@ -348,6 +359,15 @@ def __init__(self, *args: Any, virtual_cluster_id: str | None = None, **kwargs: super().__init__(client_type="emr-containers", *args, **kwargs) # type: ignore self.virtual_cluster_id = virtual_cluster_id + # Retry this method when the ``create_virtual_cluster`` raises + # "Cluster XXX is not reachable as its connection is currently being updated". + # Even though the EKS cluster status is ``ACTIVE``, ``create_virtual_cluster`` can raise this error. + # Retrying is the only option. + @tenacity.retry( + retry=retry_if_exception(is_connection_being_updated_exception), + stop=stop_after_attempt(5), + wait=wait_fixed(10), + ) def create_emr_on_eks_cluster( self, virtual_cluster_name: str, diff --git a/providers/src/airflow/providers/amazon/aws/hooks/ssm.py b/providers/src/airflow/providers/amazon/aws/hooks/ssm.py index 4cc8c4f9afd42..15f2c64ce8d07 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/ssm.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/ssm.py @@ -17,10 +17,19 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.utils.log.secrets_masker import mask_secret from airflow.utils.types import NOTSET, ArgNotSet +if TYPE_CHECKING: + from airflow.sdk.execution_time.secrets_masker import mask_secret +else: + try: + from airflow.sdk.execution_time.secrets_masker import mask_secret + except ImportError: + from airflow.utils.log.secrets_masker import mask_secret + class SsmHook(AwsBaseHook): """ diff --git a/providers/src/airflow/providers/amazon/aws/links/ec2.py b/providers/src/airflow/providers/amazon/aws/links/ec2.py new file mode 100644 index 0000000000000..38a23956cddbb --- /dev/null +++ b/providers/src/airflow/providers/amazon/aws/links/ec2.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink + + +class EC2InstanceLink(BaseAwsLink): + """Helper class for constructing Amazon EC2 instance links.""" + + name = "Instance" + key = "_instance_id" + format_str = ( + BASE_AWS_CONSOLE_LINK + "/ec2/home?region={region_name}#InstanceDetails:instanceId={instance_id}" + ) + + +class EC2InstanceDashboardLink(BaseAwsLink): + """ + Helper class for constructing Amazon EC2 console links. + + This is useful for displaying the list of EC2 instances, rather + than a single instance. 
+ """ + + name = "EC2 Instances" + key = "_instance_dashboard" + format_str = BASE_AWS_CONSOLE_LINK + "/ec2/home?region={region_name}#Instances:instanceId=:{instance_ids}" + + @staticmethod + def format_instance_id_filter(instance_ids: list[str]) -> str: + return ",:".join(instance_ids) diff --git a/providers/src/airflow/providers/amazon/aws/operators/ec2.py b/providers/src/airflow/providers/amazon/aws/operators/ec2.py index 5b25b27fd0555..f3d0e9fc2af25 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/ec2.py +++ b/providers/src/airflow/providers/amazon/aws/operators/ec2.py @@ -23,6 +23,10 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.providers.amazon.aws.links.ec2 import ( + EC2InstanceDashboardLink, + EC2InstanceLink, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -47,6 +51,7 @@ class EC2StartInstanceOperator(BaseOperator): between each instance state checks until operation is completed """ + operator_extra_links = (EC2InstanceLink(),) template_fields: Sequence[str] = ("instance_id", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -71,6 +76,13 @@ def execute(self, context: Context): self.log.info("Starting EC2 instance %s", self.instance_id) instance = ec2_hook.get_instance(instance_id=self.instance_id) instance.start() + EC2InstanceLink.persist( + context=context, + operator=self, + aws_partition=ec2_hook.conn_partition, + instance_id=self.instance_id, + region_name=ec2_hook.conn_region_name, + ) ec2_hook.wait_for_state( instance_id=self.instance_id, target_state="running", @@ -97,6 +109,7 @@ class EC2StopInstanceOperator(BaseOperator): between each instance state checks until operation is completed """ + operator_extra_links = (EC2InstanceLink(),) template_fields: Sequence[str] = ("instance_id", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -120,7 +133,15 @@ def execute(self, context: Context): ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) self.log.info("Stopping EC2 instance %s", self.instance_id) instance = ec2_hook.get_instance(instance_id=self.instance_id) + EC2InstanceLink.persist( + context=context, + operator=self, + aws_partition=ec2_hook.conn_partition, + instance_id=self.instance_id, + region_name=ec2_hook.conn_region_name, + ) instance.stop() + ec2_hook.wait_for_state( instance_id=self.instance_id, target_state="stopped", @@ -154,6 +175,7 @@ class EC2CreateInstanceOperator(BaseOperator): in the `running` state before returning. """ + operator_extra_links = (EC2InstanceDashboardLink(),) template_fields: Sequence[str] = ( "image_id", "max_count", @@ -198,6 +220,15 @@ def execute(self, context: Context): )["Instances"] instance_ids = self._on_kill_instance_ids = [instance["InstanceId"] for instance in instances] + # Console link is for EC2 dashboard list, not individual instances when more than 1 instance + + EC2InstanceDashboardLink.persist( + context=context, + operator=self, + region_name=ec2_hook.conn_region_name, + aws_partition=ec2_hook.conn_partition, + instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(instance_ids), + ) for instance_id in instance_ids: self.log.info("Created EC2 instance %s", instance_id) @@ -311,6 +342,7 @@ class EC2RebootInstanceOperator(BaseOperator): in the `running` state before returning. 
""" + operator_extra_links = (EC2InstanceDashboardLink(),) template_fields: Sequence[str] = ("instance_ids", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -341,6 +373,14 @@ def execute(self, context: Context): self.log.info("Rebooting EC2 instances %s", ", ".join(self.instance_ids)) ec2_hook.conn.reboot_instances(InstanceIds=self.instance_ids) + # Console link is for EC2 dashboard list, not individual instances + EC2InstanceDashboardLink.persist( + context=context, + operator=self, + region_name=ec2_hook.conn_region_name, + aws_partition=ec2_hook.conn_partition, + instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.instance_ids), + ) if self.wait_for_completion: ec2_hook.get_waiter("instance_running").wait( InstanceIds=self.instance_ids, @@ -374,6 +414,7 @@ class EC2HibernateInstanceOperator(BaseOperator): in the `stopped` state before returning. """ + operator_extra_links = (EC2InstanceDashboardLink(),) template_fields: Sequence[str] = ("instance_ids", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -404,6 +445,15 @@ def execute(self, context: Context): self.log.info("Hibernating EC2 instances %s", ", ".join(self.instance_ids)) instances = ec2_hook.get_instances(instance_ids=self.instance_ids) + # Console link is for EC2 dashboard list, not individual instances + EC2InstanceDashboardLink.persist( + context=context, + operator=self, + region_name=ec2_hook.conn_region_name, + aws_partition=ec2_hook.conn_partition, + instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.instance_ids), + ) + for instance in instances: hibernation_options = instance.get("HibernationOptions") if not hibernation_options or not hibernation_options["Configured"]: diff --git a/providers/src/airflow/providers/amazon/provider.yaml b/providers/src/airflow/providers/amazon/provider.yaml index 43569a28827ab..5192052800079 100644 --- a/providers/src/airflow/providers/amazon/provider.yaml +++ b/providers/src/airflow/providers/amazon/provider.yaml @@ -891,7 +891,8 @@ extra-links: - airflow.providers.amazon.aws.links.comprehend.ComprehendDocumentClassifierLink - airflow.providers.amazon.aws.links.datasync.DataSyncTaskLink - airflow.providers.amazon.aws.links.datasync.DataSyncTaskExecutionLink - + - airflow.providers.amazon.aws.links.ec2.EC2InstanceLink + - airflow.providers.amazon.aws.links.ec2.EC2InstanceDashboardLink connection-types: - hook-class-name: airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook diff --git a/providers/src/airflow/providers/cncf/kubernetes/pod_generator_deprecated.py b/providers/src/airflow/providers/cncf/kubernetes/pod_generator_deprecated.py deleted file mode 100644 index 9a978cbd08df9..0000000000000 --- a/providers/src/airflow/providers/cncf/kubernetes/pod_generator_deprecated.py +++ /dev/null @@ -1,309 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Backwards compatibility for Pod generation. - -This module provides an interface between the previous Pod -API and outputs a kubernetes.client.models.V1Pod. -The advantage being that the full Kubernetes API -is supported and no serialization need be written. -""" - -from __future__ import annotations - -import copy -import uuid - -import re2 -from kubernetes.client import models as k8s - -MAX_POD_ID_LEN = 253 - -MAX_LABEL_LEN = 63 - - -class PodDefaults: - """Static defaults for Pods.""" - - XCOM_MOUNT_PATH = "/airflow/xcom" - SIDECAR_CONTAINER_NAME = "airflow-xcom-sidecar" - XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;' - VOLUME_MOUNT = k8s.V1VolumeMount(name="xcom", mount_path=XCOM_MOUNT_PATH) - VOLUME = k8s.V1Volume(name="xcom", empty_dir=k8s.V1EmptyDirVolumeSource()) - SIDECAR_CONTAINER = k8s.V1Container( - name=SIDECAR_CONTAINER_NAME, - command=["sh", "-c", XCOM_CMD], - image="alpine", - volume_mounts=[VOLUME_MOUNT], - resources=k8s.V1ResourceRequirements( - requests={ - "cpu": "1m", - } - ), - ) - - -def make_safe_label_value(string): - """ - Normalize a provided label to be of valid length and characters. - - Valid label values must be 63 characters or less and must be empty or begin and - end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), - dots (.), and alphanumerics between. - - If the label value is greater than 63 chars once made safe, or differs in any - way from the original value sent to this function, then we need to truncate to - 53 chars, and append it with a unique hash. - """ - from airflow.utils.hashlib_wrapper import md5 - - safe_label = re2.sub(r"^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$", "", string) - - if len(safe_label) > MAX_LABEL_LEN or string != safe_label: - safe_hash = md5(string.encode()).hexdigest()[:9] - safe_label = safe_label[: MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash - - return safe_label - - -class PodGenerator: - """ - Contains Kubernetes Airflow Worker configuration logic. - - Represents a kubernetes pod and manages execution of a single pod. - Any configuration that is container specific gets applied to - the first container in the list of containers. - - :param image: The docker image - :param name: name in the metadata section (not the container name) - :param namespace: pod namespace - :param volume_mounts: list of kubernetes volumes mounts - :param envs: A dict containing the environment variables - :param cmds: The command to be run on the first container - :param args: The arguments to be run on the pod - :param labels: labels for the pod metadata - :param node_selectors: node selectors for the pod - :param ports: list of ports. Applies to the first container. - :param volumes: Volumes to be attached to the first container - :param image_pull_policy: Specify a policy to cache or always pull an image - :param restart_policy: The restart policy of the pod - :param image_pull_secrets: Any image pull secrets to be given to the pod. 
- If more than one secret is required, provide a comma separated list: - secret_a,secret_b - :param init_containers: A list of init containers - :param service_account_name: Identity for processes that run in a Pod - :param resources: Resource requirements for the first containers - :param annotations: annotations for the pod - :param affinity: A dict containing a group of affinity scheduling rules - :param hostnetwork: If True enable host networking on the pod - :param tolerations: A list of kubernetes tolerations - :param security_context: A dict containing the security context for the pod - :param configmaps: Any configmap refs to read ``configmaps`` for environments from. - If more than one configmap is required, provide a comma separated list - configmap_a,configmap_b - :param dnspolicy: Specify a dnspolicy for the pod - :param schedulername: Specify a schedulername for the pod - :param pod: The fully specified pod. Mutually exclusive with `path_or_string` - :param extract_xcom: Whether to bring up a container for xcom - :param priority_class_name: priority class name for the launched Pod - """ - - def __init__( - self, - image: str | None = None, - name: str | None = None, - namespace: str | None = None, - volume_mounts: list[k8s.V1VolumeMount | dict] | None = None, - envs: dict[str, str] | None = None, - cmds: list[str] | None = None, - args: list[str] | None = None, - labels: dict[str, str] | None = None, - node_selectors: dict[str, str] | None = None, - ports: list[k8s.V1ContainerPort | dict] | None = None, - volumes: list[k8s.V1Volume | dict] | None = None, - image_pull_policy: str | None = None, - restart_policy: str | None = None, - image_pull_secrets: str | None = None, - init_containers: list[k8s.V1Container] | None = None, - service_account_name: str | None = None, - resources: k8s.V1ResourceRequirements | dict | None = None, - annotations: dict[str, str] | None = None, - affinity: dict | None = None, - hostnetwork: bool = False, - tolerations: list | None = None, - security_context: k8s.V1PodSecurityContext | dict | None = None, - configmaps: list[str] | None = None, - dnspolicy: str | None = None, - schedulername: str | None = None, - extract_xcom: bool = False, - priority_class_name: str | None = None, - ): - self.pod = k8s.V1Pod() - self.pod.api_version = "v1" - self.pod.kind = "Pod" - - # Pod Metadata - self.metadata = k8s.V1ObjectMeta() - self.metadata.labels = labels - self.metadata.name = name - self.metadata.namespace = namespace - self.metadata.annotations = annotations - - # Pod Container - self.container = k8s.V1Container(name="base") - self.container.image = image - self.container.env = [] - - if envs: - if isinstance(envs, dict): - for key, val in envs.items(): - self.container.env.append(k8s.V1EnvVar(name=key, value=val)) - elif isinstance(envs, list): - self.container.env.extend(envs) - - configmaps = configmaps or [] - self.container.env_from = [] - for configmap in configmaps: - self.container.env_from.append( - k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap)) - ) - - self.container.command = cmds or [] - self.container.args = args or [] - if image_pull_policy: - self.container.image_pull_policy = image_pull_policy - self.container.ports = ports or [] - self.container.resources = resources - self.container.volume_mounts = volume_mounts or [] - - # Pod Spec - self.spec = k8s.V1PodSpec(containers=[]) - self.spec.security_context = security_context - self.spec.tolerations = tolerations - if dnspolicy: - self.spec.dns_policy = 
dnspolicy - self.spec.scheduler_name = schedulername - self.spec.host_network = hostnetwork - self.spec.affinity = affinity - self.spec.service_account_name = service_account_name - self.spec.init_containers = init_containers - self.spec.volumes = volumes or [] - self.spec.node_selector = node_selectors - if restart_policy: - self.spec.restart_policy = restart_policy - self.spec.priority_class_name = priority_class_name - - self.spec.image_pull_secrets = [] - - if image_pull_secrets: - for image_pull_secret in image_pull_secrets.split(","): - self.spec.image_pull_secrets.append(k8s.V1LocalObjectReference(name=image_pull_secret)) - - # Attach sidecar - self.extract_xcom = extract_xcom - - def gen_pod(self) -> k8s.V1Pod: - """Generate pod.""" - result = None - - if result is None: - result = self.pod - result.spec = self.spec - result.metadata = self.metadata - result.spec.containers = [self.container] - - result.metadata.name = self.make_unique_pod_id(result.metadata.name) - - if self.extract_xcom: - result = self.add_sidecar(result) - - return result - - @staticmethod - def add_sidecar(pod: k8s.V1Pod) -> k8s.V1Pod: - """Add sidecar.""" - pod_cp = copy.deepcopy(pod) - pod_cp.spec.volumes = pod.spec.volumes or [] - pod_cp.spec.volumes.insert(0, PodDefaults.VOLUME) - pod_cp.spec.containers[0].volume_mounts = pod_cp.spec.containers[0].volume_mounts or [] - pod_cp.spec.containers[0].volume_mounts.insert(0, PodDefaults.VOLUME_MOUNT) - pod_cp.spec.containers.append(PodDefaults.SIDECAR_CONTAINER) - - return pod_cp - - @staticmethod - def from_obj(obj) -> k8s.V1Pod | None: - """Convert to pod from obj.""" - if obj is None: - return None - - if isinstance(obj, PodGenerator): - return obj.gen_pod() - - if not isinstance(obj, dict): - raise TypeError( - "Cannot convert a non-dictionary or non-PodGenerator " - "object into a KubernetesExecutorConfig" - ) - - # We do not want to extract constant here from ExecutorLoader because it is just - # A name in dictionary rather than executor selection mechanism and it causes cyclic import - namespaced = obj.get("KubernetesExecutor", {}) - - if not namespaced: - return None - - resources = namespaced.get("resources") - - if resources is None: - requests = { - "cpu": namespaced.get("request_cpu"), - "memory": namespaced.get("request_memory"), - "ephemeral-storage": namespaced.get("ephemeral-storage"), - } - limits = { - "cpu": namespaced.get("limit_cpu"), - "memory": namespaced.get("limit_memory"), - "ephemeral-storage": namespaced.get("ephemeral-storage"), - } - all_resources = list(requests.values()) + list(limits.values()) - if all(r is None for r in all_resources): - resources = None - else: - resources = k8s.V1ResourceRequirements(requests=requests, limits=limits) - namespaced["resources"] = resources - return PodGenerator(**namespaced).gen_pod() - - @staticmethod - def make_unique_pod_id(dag_id): - r""" - Generate a unique Pod name. 
- - Kubernetes pod names must be <= 253 chars and must pass the following regex for - validation - ``^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`` - - :param dag_id: a dag_id with only alphanumeric characters - :return: ``str`` valid Pod name of appropriate length - """ - if not dag_id: - return None - - safe_uuid = uuid.uuid4().hex - safe_pod_id = dag_id[: MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid - - return safe_pod_id diff --git a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py index d185d6d30cd37..12ea0792f0089 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -95,7 +95,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator): """ operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink()) - template_fields = ("notebook_params",) + template_fields = ("notebook_params", "job_clusters") caller = "_CreateDatabricksWorkflowOperator" def __init__( diff --git a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py index 7aefd2971d4e6..8d2b1ae1eee9e 100644 --- a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -241,7 +241,6 @@ def pull_xcom(self, context: Context) -> list: key=self.key, task_ids=self.task_id, dag_id=self.dag_id, - map_indexes=map_index, ) or [] ) diff --git a/providers/src/airflow/providers/smtp/CHANGELOG.rst b/providers/src/airflow/providers/smtp/CHANGELOG.rst index 038fe52acaa0a..e6d91b7c7abf4 100644 --- a/providers/src/airflow/providers/smtp/CHANGELOG.rst +++ b/providers/src/airflow/providers/smtp/CHANGELOG.rst @@ -27,6 +27,28 @@ Changelog --------- + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ +.. warning:: + The argument ``from_email`` is now an optional kwarg in ``SmtpNotifier``, and the argument ``to`` became the first + positional argument. + + Configuring the ``SmtpNotifier`` and ``SmtpHook`` default values via Airflow SMTP configurations is not supported + anymore. You can instead use the SMTP connection configuration to set the default values, where you can use: + + * the connection extra field ``ssl_context`` instead of the configuration ``smtp_provider.ssl_context`` or + ``email.ssl_context`` in the SMTP hook. + * the connection extra field ``from_email`` instead of the configuration ``smtp.smtp_mail_from`` in ``SmtpNotifier``. + * the connection extra field ``subject_template`` instead of the configuration ``smtp.templated_email_subject_path`` + in ``SmtpNotifier``. + * the connection extra field ``html_content_template`` instead of the configuration + ``smtp.templated_html_content_path`` in ``SmtpNotifier``. + + 1.9.0 ..... 
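To make the SMTP 2.0.0 migration path concrete, here is a hedged sketch of the new calling convention described in the changelog above. The connection id, host, addresses, and template paths are illustrative assumptions; the point is that ``to`` is now the first positional argument, ``from_email`` is optional, and the former ``[smtp]``/``[smtp_provider]`` settings move into the connection's ``extra`` field.

```python
import json

from airflow.models.connection import Connection
from airflow.providers.smtp.notifications.smtp import SmtpNotifier

# Shown as an object for clarity; in practice the connection would live in the
# metadata DB, an environment variable, or a secrets backend.
smtp_conn = Connection(
    conn_id="smtp_default",
    conn_type="smtp",
    host="smtp.example.com",
    login="mailer",
    password="secret",
    port=587,
    extra=json.dumps(
        {
            "from_email": "airflow@example.com",  # replaces smtp.smtp_mail_from
            "ssl_context": "default",  # replaces smtp_provider.ssl_context / email.ssl_context
            "subject_template": "/opt/templates/email_subject.jinja2",
            "html_content_template": "/opt/templates/email.html",
        }
    ),
)

# `to` is positional; `from_email` may be omitted and taken from the connection extra.
notifier = SmtpNotifier("alerts@example.com", smtp_conn_id="smtp_default")
```

At ``notify()`` time the notifier falls back to the connection's ``from_email`` and raises a ``ValueError`` if neither the argument nor the extra is set, as shown in the hook and notifier hunks that follow.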
diff --git a/providers/src/airflow/providers/smtp/hooks/smtp.py b/providers/src/airflow/providers/smtp/hooks/smtp.py index 74707fad5f0a7..991dea54867c9 100644 --- a/providers/src/airflow/providers/smtp/hooks/smtp.py +++ b/providers/src/airflow/providers/smtp/hooks/smtp.py @@ -114,25 +114,15 @@ def _build_client(self) -> smtplib.SMTP_SSL | smtplib.SMTP: smtp_kwargs["timeout"] = self.timeout if self.use_ssl: - from airflow.configuration import conf - - extra_ssl_context = self.conn.extra_dejson.get("ssl_context", None) - if extra_ssl_context: - ssl_context_string = extra_ssl_context - else: - ssl_context_string = conf.get("smtp_provider", "SSL_CONTEXT", fallback=None) - if ssl_context_string is None: - ssl_context_string = conf.get("email", "SSL_CONTEXT", fallback=None) - if ssl_context_string is None: - ssl_context_string = "default" - if ssl_context_string == "default": + ssl_context_string = self.ssl_context + if ssl_context_string is None or ssl_context_string == "default": ssl_context = ssl.create_default_context() elif ssl_context_string == "none": ssl_context = None else: raise RuntimeError( - f"The email.ssl_context configuration variable must " - f"be set to 'default' or 'none' and is '{ssl_context_string}'." + f"The connection extra field `ssl_context` must " + f"be set to 'default' or 'none' but it is set to '{ssl_context_string}'." ) smtp_kwargs["context"] = ssl_context return SMTP(**smtp_kwargs) @@ -411,6 +401,10 @@ def subject_template(self) -> str | None: def html_content_template(self) -> str | None: return self.conn.extra_dejson.get("html_content_template") + @property + def ssl_context(self) -> str | None: + return self.conn.extra_dejson.get("ssl_context") + @staticmethod def _read_template(template_path: str) -> str: """ diff --git a/providers/src/airflow/providers/smtp/notifications/smtp.py b/providers/src/airflow/providers/smtp/notifications/smtp.py index 01ca19f80e7a2..85f71172d2334 100644 --- a/providers/src/airflow/providers/smtp/notifications/smtp.py +++ b/providers/src/airflow/providers/smtp/notifications/smtp.py @@ -22,7 +22,6 @@ from pathlib import Path from typing import Any -from airflow.configuration import conf from airflow.notifications.basenotifier import BaseNotifier from airflow.providers.smtp.hooks.smtp import SmtpHook @@ -67,10 +66,8 @@ class SmtpNotifier(BaseNotifier): def __init__( self, - # TODO: Move from_email to keyword parameter in next major release so that users do not - # need to specify from_email. No argument here will lead to defaults from conf being used. 
- from_email: str | None, to: str | Iterable[str], + from_email: str | None = None, subject: str | None = None, html_content: str | None = None, files: list[str] | None = None, @@ -85,7 +82,7 @@ def __init__( ): super().__init__() self.smtp_conn_id = smtp_conn_id - self.from_email = from_email or conf.get("smtp", "smtp_mail_from") + self.from_email = from_email self.to = to self.files = files self.cc = cc @@ -110,16 +107,20 @@ def hook(self) -> SmtpHook: def notify(self, context): """Send a email via smtp server.""" fields_to_re_render = [] + if self.from_email is None: + if self.hook.from_email is not None: + self.from_email = self.hook.from_email + else: + raise ValueError("You should provide `from_email` or define it in the connection") + fields_to_re_render.append("from_email") if self.subject is None: smtp_default_templated_subject_path: str if self.hook.subject_template: smtp_default_templated_subject_path = self.hook.subject_template else: - smtp_default_templated_subject_path = conf.get( - "smtp", - "templated_email_subject_path", - fallback=(Path(__file__).parent / "templates" / "email_subject.jinja2").as_posix(), - ) + smtp_default_templated_subject_path = ( + Path(__file__).parent / "templates" / "email_subject.jinja2" + ).as_posix() self.subject = self._read_template(smtp_default_templated_subject_path) fields_to_re_render.append("subject") if self.html_content is None: @@ -127,11 +128,9 @@ def notify(self, context): if self.hook.html_content_template: smtp_default_templated_html_content_path = self.hook.html_content_template else: - smtp_default_templated_html_content_path = conf.get( - "smtp", - "templated_html_content_path", - fallback=(Path(__file__).parent / "templates" / "email.html").as_posix(), - ) + smtp_default_templated_html_content_path = ( + Path(__file__).parent / "templates" / "email.html" + ).as_posix() self.html_content = self._read_template(smtp_default_templated_html_content_path) fields_to_re_render.append("html_content") if fields_to_re_render: diff --git a/providers/src/airflow/providers/smtp/provider.yaml b/providers/src/airflow/providers/smtp/provider.yaml index c25064ae74c57..9face29ef6318 100644 --- a/providers/src/airflow/providers/smtp/provider.yaml +++ b/providers/src/airflow/providers/smtp/provider.yaml @@ -69,43 +69,3 @@ connection-types: notifications: - airflow.providers.smtp.notifications.smtp.SmtpNotifier - -config: - smtp_provider: - description: "Options for SMTP provider." - options: - ssl_context: - description: | - ssl context to use when using SMTP and IMAP SSL connections. By default, the context is "default" - which sets it to ``ssl.create_default_context()`` which provides the right balance between - compatibility and security, it however requires that certificates in your operating system are - updated and that SMTP/IMAP servers of yours have valid certificates that have corresponding public - keys installed on your machines. You can switch it to "none" if you want to disable checking - of the certificates, but it is not recommended as it allows MITM (man-in-the-middle) attacks - if your infrastructure is not sufficiently secured. It should only be set temporarily while you - are fixing your certificate configuration. This can be typically done by upgrading to newer - version of the operating system you run Airflow components on,by upgrading/refreshing proper - certificates in the OS or by updating certificates for your mail servers. 
- - If you do not set this option explicitly, it will use Airflow "email.ssl_context" configuration, - but if this configuration is not present, it will use "default" value. - type: string - version_added: 1.3.0 - example: "default" - default: ~ - templated_email_subject_path: - description: | - Allows overriding of the standard templated email subject line when the SmtpNotifier is used. - Must provide a path to the template. - type: string - version_added: 1.6.1 - example: "path/to/override/email_subject.html" - default: ~ - templated_html_content_path: - description: | - Allows overriding of the standard templated email path when the SmtpNotifier is used. Must provide - a path to the template. - type: string - version_added: 1.6.1 - example: "path/to/override/email.html" - default: ~ diff --git a/providers/src/airflow/providers/yandex/.latest-doc-only-change.txt b/providers/src/airflow/providers/yandex/.latest-doc-only-change.txt deleted file mode 100644 index 2a10512ecd4bf..0000000000000 --- a/providers/src/airflow/providers/yandex/.latest-doc-only-change.txt +++ /dev/null @@ -1 +0,0 @@ -857ca4c06c9008593674cabdd28d3c30e3e7f97b diff --git a/providers/src/airflow/providers/yandex/CHANGELOG.rst b/providers/src/airflow/providers/yandex/CHANGELOG.rst deleted file mode 100644 index e71425d305686..0000000000000 --- a/providers/src/airflow/providers/yandex/CHANGELOG.rst +++ /dev/null @@ -1,454 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - -.. NOTE TO CONTRIBUTORS: - Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes - and you want to add an explanation to the users on how they are supposed to deal with them. - The changelog is updated and maintained semi-automatically by release manager. - -``apache-airflow-providers-yandex`` - - -Changelog ---------- - -4.0.0 -..... - -.. note:: - This release of provider is only available for Airflow 2.9+ as explained in the - `Apache Airflow providers support policy `_. - -Breaking changes -~~~~~~~~~~~~~~~~ - -.. warning:: - All deprecated classes, parameters and features have been removed from the {provider_name} provider package. - The following breaking changes were introduced: - - * removed ``YandexCloudBaseHook.provider_user_agent`` . Use ``utils.user_agent.provider_user_agent`` instead. - * removed ``connection_id`` parameter from ``YandexCloudBaseHook``. Use ``yandex_conn_id`` parameter. - * removed ``yandex.hooks.yandexcloud_dataproc`` module. - * removed ``yandex.operators.yandexcloud_dataproc`` module. - * removed implicit passing of ``yandex_conn_id`` in ``DataprocBaseOperator``. Please pass it as a parameter. 
- -* ``Remove Provider Deprecations in Yandex provider (#44754)`` - -Misc -~~~~ - -* ``Bump minimum Airflow version in providers to Airflow 2.9.0 (#44956)`` -* ``Update DAG example links in multiple providers documents (#44034)`` - - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Use Python 3.9 as target version for Ruff & Black rules (#44298)`` - * ``Prepare docs for Nov 1st wave of providers (#44011)`` - * ``Split providers out of the main "airflow/" tree into a UV workspace project (#42505)`` - -.. Review and move the new changes to one of the sections above: - * ``Update path of example dags in docs (#45069)`` - -3.12.0 -...... - -.. note:: - This release of provider is only available for Airflow 2.8+ as explained in the - `Apache Airflow providers support policy `_. - -Bug Fixes -~~~~~~~~~ - -* ``providers/yandex: fix typing (#40997)`` - -Misc -~~~~ - -* ``Bump minimum Airflow version in providers to Airflow 2.8.0 (#41396)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Prepare docs for Aug 1st wave of providers (#41230)`` - * ``Prepare docs 1st wave July 2024 (#40644)`` - * ``Enable enforcing pydocstyle rule D213 in ruff. (#40448)`` - -3.11.2 -...... - -Bug Fixes -~~~~~~~~~ - -* ``Exclude yandex versions 0.289.0, 0.290.0 (#39974)`` - -Misc -~~~~ - -* ``Fix typos in Providers docs and Yandex hook (#40277)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Limit yandex provider to avoid mypy errors (#39990)`` - * ``Workaround new yandexcloud breaking dataproc integration (#39964)`` - -3.11.1 -...... - -Misc -~~~~ - -* `` AIP-21: yandexcloud: rename files, emit deprecation warning (#39618)`` -* ``yandex provider: bump version for yq http client package (#39548)`` -* ``Faster 'airflow_version' imports (#39552)`` -* ``add doc about Yandex Query operator (#39445)`` -* ``Simplify 'airflow_version' imports (#39497)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Reapply templates for all providers (#39554)`` - -3.11.0 -...... - -.. note:: - This release of provider is only available for Airflow 2.7+ as explained in the - `Apache Airflow providers support policy `_. - -Misc -~~~~ - -* ``Bump minimum Airflow version in providers to Airflow 2.7.0 (#39240)`` - -3.10.0 -...... - -Features -~~~~~~~~ - -* ``Add Yandex Query support from Yandex.Cloud (#37458)`` - -Misc -~~~~ - -* ``support iam token from metadata, simplify code (#38411)`` -* ``Avoid use of 'assert' outside of the tests (#37718)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Prepare docs 1st wave (RC1) April 2024 (#38863)`` - * ``docs: yandex provider grammatical improvements (#38589)`` - * ``Bump ruff to 0.3.3 (#38240)`` - * ``Prepare docs 1st wave (RC1) March 2024 (#37876)`` - * ``Add comment about versions updated by release manager (#37488)`` - -3.9.0 -..... - -Features -~~~~~~~~ - -* ``Add secrets-backends section into the Yandex provider yaml definition (#37065)`` - -Bug Fixes -~~~~~~~~~ - -* ``fix: using endpoint from connection if not specified (#37076)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. 
Do not delete the lines(!): - * ``D401 Support in Providers (simple) (#37258)`` - * ``docs: update description in airflow provider.yaml (#37096)`` - -3.8.0 -..... - -Features -~~~~~~~~ - -* ``feat: add Yandex Cloud Lockbox secrets backend (#36449)`` - - -Bug Fixes -~~~~~~~~~ - -* ``Fix stacklevel in warnings.warn into the providers (#36831)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Prepare docs 1st wave of Providers January 2024 (#36640)`` - * ``Speed up autocompletion of Breeze by simplifying provider state (#36499)`` - * ``Provide the logger_name param in providers hooks in order to override the logger name (#36675)`` - * ``Revert "Provide the logger_name param in providers hooks in order to override the logger name (#36675)" (#37015)`` - * ``Prepare docs 2nd wave of Providers January 2024 (#36945)`` - -3.7.1 -..... - -Bug Fixes -~~~~~~~~~ - -* ``Follow BaseHook connection fields method signature in child classes (#36086)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - -3.7.0 -..... - -.. note:: - This release of provider is only available for Airflow 2.6+ as explained in the - `Apache Airflow providers support policy `_. - -Misc -~~~~ - -* ``Bump minimum Airflow version in providers to Airflow 2.6.0 (#36017)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Fix and reapply templates for provider documentation (#35686)`` - * ``Prepare docs 2nd wave of Providers November 2023 (#35836)`` - * ``Use reproducible builds for provider packages (#35693)`` - -3.6.0 -..... - -Features -~~~~~~~~ - -* ``Yandex dataproc deduce default service account (#35059)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Prepare docs 3rd wave of Providers October 2023 - FIX (#35233)`` - * ``Prepare docs 3rd wave of Providers October 2023 (#35187)`` - * ``Pre-upgrade 'ruff==0.0.292' changes in providers (#35053)`` - -3.5.0 -..... - -.. note:: - This release of provider is only available for Airflow 2.5+ as explained in the - `Apache Airflow providers support policy `_. - -Misc -~~~~ - -* ``Bump min airflow version of providers (#34728)`` - -3.4.0 -..... - -.. note:: - This release dropped support for Python 3.7 - -Features -~~~~~~~~ - -* ``add support for Yandex Dataproc cluster labels (#29811)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Add note about dropping Python 3.7 for providers (#32015)`` - * ``Add D400 pydocstyle check - Providers (#31427)`` - * ``Add full automation for min Airflow version for providers (#30994)`` - * ``Add mechanism to suspend providers (#30422)`` - * ``Resume yandex provider (#33574)`` - * ``Remove spurious headers for provider changelogs (#32373)`` - * ``Improve provider documentation and README structure (#32125)`` - * ``Use '__version__' in providers not 'version' (#31393)`` - * ``Use 'AirflowProviderDeprecationWarning' in providers (#30975)`` - * ``Bump minimum Airflow version in providers (#30917)`` - * ``Suspend Yandex provider due to protobuf limitation (#30667)`` - -3.3.0 -..... - -Features -~~~~~~~~ - -* ``support Yandex SDK feature "endpoint" (#29635)`` - -3.2.0 -..... - -.. 
note:: - This release of provider is only available for Airflow 2.3+ as explained in the - Apache Airflow providers support policy `_. - -Features -~~~~~~~~ - -* In YandexCloudBaseHook, non-prefixed extra fields are supported and are preferred (#27040). E.g. ``folder_id`` will be preferred if ``extra__yandexcloud__folder_id`` is also present. - -Misc -~~~~ - -* ``Move min airflow version to 2.3.0 for all providers (#27196)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Enable string normalization in python formatting - providers (#27205)`` - * ``Update docs for September Provider's release (#26731)`` - * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` - * ``pRepare docs for November 2022 wave of Providers (#27613)`` - * ``Prepare for follow-up release for November providers (#27774)`` - -3.1.0 -..... - -Features -~~~~~~~~ - -* ``YandexCloud provider: Support new Yandex SDK features for DataProc (#25158)`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Add documentation for July 2022 Provider's release (#25030)`` - * ``Move provider dependencies to inside provider folders (#24672)`` - * ``Remove 'hook-class-names' from provider.yaml (#24702)`` - -3.0.0 -..... - -Breaking changes -~~~~~~~~~~~~~~~~ - -.. note:: - This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow - Apache Airflow providers support policy `_. - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Migrate Yandex example DAGs to new design AIP-47 (#24082)`` - * ``Add explanatory note for contributors about updating Changelog (#24229)`` - * ``Prepare docs for May 2022 provider's release (#24231)`` - * ``Update package description to remove double min-airflow specification (#24292)`` - -2.2.3 -..... - -Bug Fixes -~~~~~~~~~ - -* ``Fix mistakenly added install_requires for all providers (#22382)`` - -2.2.2 -..... - -Misc -~~~~~ - -* ``Add Trove classifiers in PyPI (Framework :: Apache Airflow :: Provider)`` - -2.2.1 -..... - -Misc -~~~~ - -* ``Support for Python 3.10`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Fixed changelog for January 2022 (delayed) provider's release (#21439)`` - * ``Add documentation for January 2021 providers release (#21257)`` - * ``Add optional features in providers. (#21074)`` - * ``Remove ':type' lines now sphinx-autoapi supports typehints (#20951)`` - * ``Fix spelling (#22054)`` - -2.2.0 -..... - -Features -~~~~~~~~ - -* ``YandexCloud provider: Support new Yandex SDK features: log_group_id, user-agent, maven packages (#20103)`` - - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Fix mypy for providers: elasticsearch, oracle, yandex (#20344)`` - * ``Fixup string concatenations (#19099)`` - * ``Update documentation for November 2021 provider's release (#19882)`` - * ``Prepare documentation for October Provider's release (#19321)`` - * ``Update documentation for September providers release (#18613)`` - * ``Static start_date and default arg cleanup for misc. 
provider example DAGs (#18597)`` - * ``Inclusive Language (#18349)`` - * ``Use typed Context EVERYWHERE (#20565)`` - * ``Fix template_fields type to have MyPy friendly Sequence type (#20571)`` - * ``Update documentation for provider December 2021 release (#20523)`` - -2.1.0 -..... - -Misc -~~~~ - -* ``Optimise connection importing for Airflow 2.2.0`` - - -Features -~~~~~~~~ - -* ``Add autoscaling subcluster support and remove defaults (#17033)`` - - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Update description about the new ''connection-types'' provider meta-data (#17767)`` - * ``Import Hooks lazily individually in providers manager (#17682)`` - * ``Prepares docs for Rc2 release of July providers (#17116)`` - * ``Remove/refactor default_args pattern for miscellaneous providers (#16872)`` - * ``Prepare documentation for July release of providers. (#17015)`` - * ``Removes pylint from our toolchain (#16682)`` - -2.0.0 -..... - -Breaking changes -~~~~~~~~~~~~~~~~ - -* ``Auto-apply apply_default decorator (#15667)`` - -.. warning:: Due to apply_default decorator removal, this version of the provider requires Airflow 2.1.0+. - If your Airflow version is < 2.1.0, and you want to install this provider version, first upgrade - Airflow to at least version 2.1.0. Otherwise your Airflow package version will be upgraded - automatically and you will have to manually run ``airflow upgrade db`` to complete the migration. - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Adds interactivity when generating provider documentation. (#15518)`` - * ``Prepares provider release after PIP 21 compatibility (#15576)`` - * ``Update docstrings to adhere to sphinx standards (#14918)`` - * ``Remove Backport Providers (#14886)`` - * ``Update documentation for broken package releases (#14734)`` - * ``Updated documentation for June 2021 provider release (#16294)`` - * ``Fix Sphinx Issues with Docstrings (#14968)`` - * ``More documentation update for June providers release (#16405)`` - * ``Synchronizes updated changelog after buggfix release (#16464)`` - -1.0.1 -..... - -Updated documentation and readme files. - -1.0.0 -..... - -Initial version of the provider. 
diff --git a/providers/standard/tests/provider_tests/standard/decorators/test_python.py b/providers/standard/tests/provider_tests/standard/decorators/test_python.py index 84a58faae72e2..ae661a17dd583 100644 --- a/providers/standard/tests/provider_tests/standard/decorators/test_python.py +++ b/providers/standard/tests/provider_tests/standard/decorators/test_python.py @@ -18,7 +18,7 @@ import sys import typing from collections import namedtuple -from datetime import date, timedelta +from datetime import date from typing import Union import pytest @@ -26,15 +26,10 @@ from airflow.decorators import setup, task as task_decorator, teardown from airflow.decorators.base import DecoratedMappedOperator from airflow.exceptions import AirflowException, XComNotFound -from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import DAG -from airflow.models.expandinput import DictOfListsExpandInput from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap -from airflow.models.xcom_arg import PlainXComArg, XComArg from airflow.utils import timezone from airflow.utils.state import State -from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType @@ -44,10 +39,17 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg + from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils.types import DagRunTriggeredByType else: + from airflow.models.baseoperator import BaseOperator + from airflow.models.dag import DAG # type: ignore[assignment] + from airflow.models.expandinput import DictOfListsExpandInput from airflow.models.mappedoperator import MappedOperator + from airflow.models.xcom_arg import XComArg + from airflow.utils.task_group import TaskGroup pytestmark = pytest.mark.db_test @@ -733,6 +735,11 @@ def double(number: int): def test_partial_mapped_decorator() -> None: + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.definitions.xcom_arg import PlainXComArg + else: + from airflow.models.xcom_arg import PlainXComArg # type: ignore[attr-defined, no-redef] + @task_decorator def product(number: int, multiple: int): return number * multiple @@ -789,38 +796,12 @@ def task2(arg1, arg2): ... dec = run.task_instance_scheduling_decisions(session=session) assert [ti.task_id for ti in dec.schedulable_tis] == ["task2"] ti = dec.schedulable_tis[0] - unmapped = ti.task.unmap((ti.get_template_context(session), session)) - assert set(unmapped.op_kwargs) == {"arg1", "arg2"} - - -def test_mapped_decorator_converts_partial_kwargs(dag_maker, session): - with dag_maker(session=session): - - @task_decorator - def task1(arg): - return ["x" * arg] - - @task_decorator(retry_delay=30) - def task2(arg1, arg2): ... - task2.partial(arg1=1).expand(arg2=task1.expand(arg=[1, 2])) - - run = dag_maker.create_dagrun() - - # Expand and run task1. - dec = run.task_instance_scheduling_decisions(session=session) - assert [ti.task_id for ti in dec.schedulable_tis] == ["task1", "task1"] - for ti in dec.schedulable_tis: - ti.run(session=session) - assert not isinstance(ti.task, MappedOperator) - assert ti.task.retry_delay == timedelta(seconds=300) # Operator default. - - # Expand task2. 
- dec = run.task_instance_scheduling_decisions(session=session) - assert [ti.task_id for ti in dec.schedulable_tis] == ["task2", "task2"] - for ti in dec.schedulable_tis: + if AIRFLOW_V_3_0_PLUS: + unmapped = ti.task.unmap((ti.get_template_context(session),)) + else: unmapped = ti.task.unmap((ti.get_template_context(session), session)) - assert unmapped.retry_delay == timedelta(seconds=30) + assert set(unmapped.op_kwargs) == {"arg1", "arg2"} def test_mapped_render_template_fields(dag_maker, session): diff --git a/providers/standard/tests/provider_tests/standard/operators/test_datetime.py b/providers/standard/tests/provider_tests/standard/operators/test_datetime.py index b15ab5a117d7a..36df3942e9402 100644 --- a/providers/standard/tests/provider_tests/standard/operators/test_datetime.py +++ b/providers/standard/tests/provider_tests/standard/operators/test_datetime.py @@ -74,13 +74,13 @@ def base_tests_setup(self, dag_maker): self.branch_1.set_upstream(self.branch_op) self.branch_2.set_upstream(self.branch_op) - self.dr = dag_maker.create_dagrun( - run_id="manual__", - start_date=DEFAULT_DATE, - logical_date=DEFAULT_DATE, - state=State.RUNNING, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - ) + self.dr = dag_maker.create_dagrun( + run_id="manual__", + start_date=DEFAULT_DATE, + logical_date=DEFAULT_DATE, + state=State.RUNNING, + data_interval=(DEFAULT_DATE, DEFAULT_DATE), + ) def teardown_method(self): with create_session() as session: diff --git a/providers/standard/tests/provider_tests/standard/operators/test_python.py b/providers/standard/tests/provider_tests/standard/operators/test_python.py index d87de0f62dfb4..9f0f3515c696a 100644 --- a/providers/standard/tests/provider_tests/standard/operators/test_python.py +++ b/providers/standard/tests/provider_tests/standard/operators/test_python.py @@ -42,7 +42,6 @@ from slugify import slugify from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.decorators import task_group from airflow.exceptions import ( AirflowException, DeserializingResultError, @@ -110,8 +109,12 @@ def base_tests_setup(self, request, create_serialized_task_instance_of_operator, self.run_id = f"run_{slugify(request.node.name, max_length=40)}" self.ds_templated = self.default_date.date().isoformat() self.ti_maker = create_serialized_task_instance_of_operator + self.dag_maker = dag_maker self.dag_non_serialized = self.dag_maker(self.dag_id, template_searchpath=TEMPLATE_SEARCHPATH).dag + # We need to enter the context in order for the factory to create things + with self.dag_maker: + ...
clear_db_runs() yield clear_db_runs() @@ -138,6 +141,10 @@ def default_kwargs(**kwargs): return kwargs def create_dag_run(self) -> DagRun: + from airflow.models.serialized_dag import SerializedDagModel + + # Update the serialized DAG with any tasks added after initial dag was created + self.dag_maker.serialized_model = SerializedDagModel(self.dag_non_serialized) return self.dag_maker.create_dagrun( state=DagRunState.RUNNING, start_date=self.dag_maker.start_date, @@ -753,39 +760,6 @@ def test_xcom_push_skipped_tasks(self): "skipped": ["empty_task"] } - def test_mapped_xcom_push_skipped_tasks(self, session): - with self.dag_non_serialized: - - @task_group - def group(x): - short_op_push_xcom = ShortCircuitOperator( - task_id="push_xcom_from_shortcircuit", - python_callable=lambda arg: arg % 2 == 0, - op_kwargs={"arg": x}, - ) - empty_task = EmptyOperator(task_id="empty_task") - short_op_push_xcom >> empty_task - - group.expand(x=[0, 1]) - dr = self.create_dag_run() - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run() - # dr.run(start_date=self.default_date, end_date=self.default_date) - tis = dr.get_task_instances() - - assert ( - tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="return_value", map_indexes=0) - is True - ) - assert ( - tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=0) - is None - ) - assert tis[0].xcom_pull( - task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=1 - ) == {"skipped": ["group.empty_task"]} - virtualenv_string_args: list[str] = [] diff --git a/providers/tests/cncf/kubernetes/resource_convert/__init__.py b/providers/tests/amazon/aws/auth_manager/router/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/resource_convert/__init__.py rename to providers/tests/amazon/aws/auth_manager/router/__init__.py diff --git a/providers/tests/amazon/aws/auth_manager/views/test_auth.py b/providers/tests/amazon/aws/auth_manager/router/test_login.py similarity index 57% rename from providers/tests/amazon/aws/auth_manager/views/test_auth.py rename to providers/tests/amazon/aws/auth_manager/router/test_login.py index 2521dd9e43e6c..f68eb7b7fe23e 100644 --- a/providers/tests/amazon/aws/auth_manager/views/test_auth.py +++ b/providers/tests/amazon/aws/auth_manager/router/test_login.py @@ -19,11 +19,16 @@ from unittest.mock import Mock, patch import pytest -from flask import session, url_for -from airflow.exceptions import AirflowException from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS -from airflow.www import app as application + +if not AIRFLOW_V_3_0_PLUS: + pytest.skip("AWS auth manager is only compatible with Airflow >= 3.0.0", allow_module_level=True) + +from fastapi.testclient import TestClient +from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser + +from airflow.api_fastapi.app import create_app from tests_common.test_utils.config import conf_vars @@ -47,7 +52,7 @@ @pytest.fixture -def aws_app(): +def test_client(): with conf_vars( { ( @@ -58,42 +63,26 @@ def aws_app(): } ): with ( - patch( - "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" - ) as mock_parser, + patch.object(OneLogin_Saml2_IdPMetadataParser, "parse_remote") as mock_parse_remote, patch( "airflow.providers.amazon.aws.auth_manager.avp.facade.AwsAuthManagerAmazonVerifiedPermissionsFacade.is_policy_store_schema_up_to_date" ) as 
mock_is_policy_store_schema_up_to_date, ): mock_is_policy_store_schema_up_to_date.return_value = True - mock_parser.parse_remote.return_value = SAML_METADATA_PARSED - return application.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) - - -@pytest.mark.skipif( - not AIRFLOW_V_3_0_PLUS, reason="AWS auth manager is only compatible with Airflow >= 3.0.0" -) -@pytest.mark.db_test -class TestAwsAuthManagerAuthenticationViews: - def test_login_redirect_to_identity_center(self, aws_app): - with aws_app.test_client() as client: - response = client.get("/login") - assert response.status_code == 302 - assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/assertion/") - - def test_logout_redirect_to_identity_center(self, aws_app): - with aws_app.test_client() as client: - response = client.post("/logout") - assert response.status_code == 302 - assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/logout/") - - def test_login_metadata_return_xml_file(self, aws_app): - with aws_app.test_client() as client: - response = client.get("/login_metadata") - assert response.status_code == 200 - assert response.headers["Content-Type"] == "text/xml" - - def test_login_callback_set_user_in_session(self): + mock_parse_remote.return_value = SAML_METADATA_PARSED + yield TestClient(create_app()) + + +class TestLoginRouter: + def test_login(self, test_client): + response = test_client.get("/auth/login", follow_redirects=False) + assert response.status_code == 307 + assert "location" in response.headers + assert response.headers["location"].startswith( + "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/" + ) + + def test_login_callback_successful(self): with conf_vars( { ( @@ -104,18 +93,16 @@ def test_login_callback_set_user_in_session(self): } ): with ( + patch.object(OneLogin_Saml2_IdPMetadataParser, "parse_remote") as mock_parse_remote, patch( - "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" - ) as mock_parser, - patch( - "airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth" + "airflow.providers.amazon.aws.auth_manager.router.login._init_saml_auth" ) as mock_init_saml_auth, patch( "airflow.providers.amazon.aws.auth_manager.avp.facade.AwsAuthManagerAmazonVerifiedPermissionsFacade.is_policy_store_schema_up_to_date" ) as mock_is_policy_store_schema_up_to_date, ): mock_is_policy_store_schema_up_to_date.return_value = True - mock_parser.parse_remote.return_value = SAML_METADATA_PARSED + mock_parse_remote.return_value = SAML_METADATA_PARSED auth = Mock() auth.is_authenticated.return_value = True @@ -126,16 +113,13 @@ def test_login_callback_set_user_in_session(self): "email": ["email"], } mock_init_saml_auth.return_value = auth - app = application.create_app(testing=True) - with app.test_client() as client: - response = client.get("/login_callback") - assert response.status_code == 302 - assert response.location == url_for("Airflow.index") - assert session["aws_user"] is not None - assert session["aws_user"].get_id() == "1" - assert session["aws_user"].get_name() == "user_id" - - def test_login_callback_raise_exception_if_errors(self): + client = TestClient(create_app()) + response = client.post("/auth/login_callback", follow_redirects=False) + assert response.status_code == 303 + assert "location" in response.headers + assert response.headers["location"].startswith("/webapp?token=") + + def test_login_callback_unsuccessful(self): with conf_vars( { 
( @@ -146,28 +130,20 @@ def test_login_callback_raise_exception_if_errors(self): } ): with ( + patch.object(OneLogin_Saml2_IdPMetadataParser, "parse_remote") as mock_parse_remote, patch( - "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" - ) as mock_parser, - patch( - "airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth" + "airflow.providers.amazon.aws.auth_manager.router.login._init_saml_auth" ) as mock_init_saml_auth, patch( "airflow.providers.amazon.aws.auth_manager.avp.facade.AwsAuthManagerAmazonVerifiedPermissionsFacade.is_policy_store_schema_up_to_date" ) as mock_is_policy_store_schema_up_to_date, ): mock_is_policy_store_schema_up_to_date.return_value = True - mock_parser.parse_remote.return_value = SAML_METADATA_PARSED + mock_parse_remote.return_value = SAML_METADATA_PARSED auth = Mock() auth.is_authenticated.return_value = False mock_init_saml_auth.return_value = auth - app = application.create_app(testing=True) - with app.test_client() as client: - with pytest.raises(AirflowException): - client.get("/login_callback") - - def test_logout_callback_raise_not_implemented_error(self, aws_app): - with aws_app.test_client() as client: - with pytest.raises(NotImplementedError): - client.get("/logout_callback") + client = TestClient(create_app()) + response = client.post("/auth/login_callback") + assert response.status_code == 500 diff --git a/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py index 797f96ee993b9..797d6e0c504f7 100644 --- a/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -38,11 +38,7 @@ VariableDetails, ) from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities -from airflow.providers.amazon.aws.auth_manager.avp.facade import AwsAuthManagerAmazonVerifiedPermissionsFacade from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager -from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import ( - AwsSecurityManagerOverride, -) from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser from airflow.security.permissions import ( RESOURCE_AUDIT_LOG, @@ -54,7 +50,6 @@ from airflow.www.extensions.init_appbuilder import init_appbuilder from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.www import check_content_in_response if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod @@ -718,56 +713,9 @@ def test_filter_permitted_dag_ids(self, methods, user, auth_manager, test_user): auth_manager.avp_facade.get_batch_is_authorized_results.assert_called() assert result == {"dag_2"} - @patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for") - def test_get_url_login(self, mock_url_for, auth_manager): - auth_manager.get_url_login() - mock_url_for.assert_called_once_with("AwsAuthManagerAuthenticationViews.login") - - @patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for") - def test_get_url_logout(self, mock_url_for, auth_manager): - auth_manager.get_url_logout() - mock_url_for.assert_called_once_with("AwsAuthManagerAuthenticationViews.logout") - - @pytest.mark.db_test - def test_security_manager_return_default_security_manager(self, auth_manager_with_appbuilder): - assert isinstance(auth_manager_with_appbuilder.security_manager, 
AwsSecurityManagerOverride) + def test_get_url_login(self, auth_manager): + result = auth_manager.get_url_login() + assert result == "http://localhost:29091/auth/login" def test_get_cli_commands_return_cli_commands(self, auth_manager): assert len(auth_manager.get_cli_commands()) > 0 - - @pytest.mark.db_test - @patch( - "airflow.providers.amazon.aws.auth_manager.views.auth.conf.get_mandatory_value", return_value="test" - ) - def test_register_views(self, mock_get_mandatory_value, auth_manager_with_appbuilder): - from airflow.providers.amazon.aws.auth_manager.views.auth import AwsAuthManagerAuthenticationViews - - with patch.object(AwsAuthManagerAuthenticationViews, "idp_data"): - auth_manager_with_appbuilder.appbuilder.add_view_no_menu = Mock() - auth_manager_with_appbuilder.register_views() - auth_manager_with_appbuilder.appbuilder.add_view_no_menu.assert_called_once() - assert isinstance( - auth_manager_with_appbuilder.appbuilder.add_view_no_menu.call_args.args[0], - AwsAuthManagerAuthenticationViews, - ) - - @pytest.mark.db_test - @patch.object(AwsAuthManagerAmazonVerifiedPermissionsFacade, "get_batch_is_authorized_single_result") - @patch.object(AwsAuthManagerAmazonVerifiedPermissionsFacade, "get_batch_is_authorized_results") - @patch.object(AwsAuthManagerAmazonVerifiedPermissionsFacade, "is_authorized") - def test_aws_auth_manager_index( - self, - mock_is_authorized, - mock_get_batch_is_authorized_results, - mock_get_batch_is_authorized_single_result, - client_admin, - ): - """ - Load the index page using AWS auth manager. Mock all interactions with Amazon Verified Permissions. - """ - mock_is_authorized.return_value = True - mock_get_batch_is_authorized_results.return_value = [] - mock_get_batch_is_authorized_single_result.return_value = {"decision": "ALLOW"} - with client_admin.test_client() as client: - response = client.get("/login_callback", follow_redirects=True) - check_content_in_response("
DAGs
", response, 200) diff --git a/providers/tests/amazon/aws/links/test_ec2.py b/providers/tests/amazon/aws/links/test_ec2.py new file mode 100644 index 0000000000000..922b12275e5aa --- /dev/null +++ b/providers/tests/amazon/aws/links/test_ec2.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.links.ec2 import EC2InstanceDashboardLink, EC2InstanceLink + +from providers.tests.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase + + +class TestEC2InstanceLink(BaseAwsLinksTestCase): + link_class = EC2InstanceLink + + INSTANCE_ID = "i-xxxxxxxxxxxx" + + def test_extra_link(self): + self.assert_extra_link_url( + expected_url=( + "https://console.aws.amazon.com/ec2/home" + f"?region=eu-west-1#InstanceDetails:instanceId={self.INSTANCE_ID}" + ), + region_name="eu-west-1", + aws_partition="aws", + instance_id=self.INSTANCE_ID, + ) + + +class TestEC2InstanceDashboardLink(BaseAwsLinksTestCase): + link_class = EC2InstanceDashboardLink + + BASE_URL = "https://console.aws.amazon.com/ec2/home" + INSTANCE_IDS = ["i-xxxxxxxxxxxx", "i-yyyyyyyyyyyy"] + + def test_instance_id_filter(self): + instance_list = ",:".join(self.INSTANCE_IDS) + result = EC2InstanceDashboardLink.format_instance_id_filter(self.INSTANCE_IDS) + assert result == instance_list + + def test_extra_link(self): + instance_list = ",:".join(self.INSTANCE_IDS) + self.assert_extra_link_url( + expected_url=(f"{self.BASE_URL}?region=eu-west-1#Instances:instanceId=:{instance_list}"), + region_name="eu-west-1", + aws_partition="aws", + instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.INSTANCE_IDS), + ) diff --git a/providers/tests/amazon/aws/operators/test_dms.py b/providers/tests/amazon/aws/operators/test_dms.py index 0b414234d9fa2..f128ab75f2b20 100644 --- a/providers/tests/amazon/aws/operators/test_dms.py +++ b/providers/tests/amazon/aws/operators/test_dms.py @@ -477,25 +477,35 @@ def test_init(self): @pytest.mark.db_test @mock.patch.object(DmsHook, "conn") def test_template_fields_native(self, mock_conn, session): - execution_date = timezone.datetime(2020, 1, 1) + logical_date = timezone.datetime(2020, 1, 1) Variable.set("test_filter", self.filter, session=session) dag = DAG( "test_dms", schedule=None, - start_date=execution_date, + start_date=logical_date, render_template_as_native_obj=True, ) op = DmsDescribeReplicationConfigsOperator( task_id="test_task", filter="{{ var.value.test_filter }}", dag=dag ) - dag_run = DagRun( - dag_id=dag.dag_id, - run_id="test", - run_type=DagRunType.MANUAL, - state=DagRunState.RUNNING, - ) + if AIRFLOW_V_3_0_PLUS: + dag_run = DagRun( + dag_id=dag.dag_id, + run_id="test", + run_type=DagRunType.MANUAL, + state=DagRunState.RUNNING, + logical_date=logical_date, + ) + else: 
+ dag_run = DagRun( + dag_id=dag.dag_id, + run_id="test", + run_type=DagRunType.MANUAL, + state=DagRunState.RUNNING, + execution_date=logical_date, + ) ti = TaskInstance(task=op) ti.dag_run = dag_run session.add(ti) diff --git a/providers/tests/smtp/hooks/test_smtp.py b/providers/tests/smtp/hooks/test_smtp.py index ead0d02229b0d..e713b84423102 100644 --- a/providers/tests/smtp/hooks/test_smtp.py +++ b/providers/tests/smtp/hooks/test_smtp.py @@ -31,8 +31,6 @@ from airflow.utils import db from airflow.utils.session import create_session -from tests_common.test_utils.config import conf_vars - pytestmark = pytest.mark.db_test @@ -223,21 +221,7 @@ def test_send_mime_ssl(self, create_default_context, mock_smtp, mock_smtp_ssl): @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") @patch("ssl.create_default_context") - def test_send_mime_ssl_none_email_context(self, create_default_context, mock_smtp, mock_smtp_ssl): - mock_smtp_ssl.return_value = Mock() - with conf_vars({("smtp", "smtp_ssl"): "True", ("email", "ssl_context"): "none"}): - with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp( - to="to", subject="subject", html_content="content", from_email="from" - ) - assert not mock_smtp.called - assert not create_default_context.called - mock_smtp_ssl.assert_called_once_with(host="smtp_server_address", port=465, timeout=30, context=None) - - @patch("smtplib.SMTP_SSL") - @patch("smtplib.SMTP") - @patch("ssl.create_default_context") - def test_send_mime_ssl_extra_context(self, create_default_context, mock_smtp, mock_smtp_ssl): + def test_send_mime_ssl_extra_none_context(self, create_default_context, mock_smtp, mock_smtp_ssl): mock_smtp_ssl.return_value = Mock() conn = Connection( conn_id="smtp_ssl_extra", @@ -246,72 +230,55 @@ def test_send_mime_ssl_extra_context(self, create_default_context, mock_smtp, mo login=None, password="None", port=465, - extra=json.dumps(dict(ssl_context="none", from_email="from")), + extra=json.dumps(dict(use_ssl=True, ssl_context="none", from_email="from")), ) db.merge_conn(conn) - with conf_vars({("smtp", "smtp_ssl"): "True", ("smtp_provider", "ssl_context"): "default"}): - with SmtpHook(smtp_conn_id="smtp_ssl_extra") as smtp_hook: - smtp_hook.send_email_smtp( - to="to", subject="subject", html_content="content", from_email="from" - ) - assert not mock_smtp.called - assert not create_default_context.called - mock_smtp_ssl.assert_called_once_with(host="smtp_server_address", port=465, timeout=30, context=None) - - @patch("smtplib.SMTP_SSL") - @patch("smtplib.SMTP") - @patch("ssl.create_default_context") - def test_send_mime_ssl_none_smtp_provider_context(self, create_default_context, mock_smtp, mock_smtp_ssl): - mock_smtp_ssl.return_value = Mock() - with conf_vars({("smtp", "smtp_ssl"): "True", ("smtp_provider", "ssl_context"): "none"}): - with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp( - to="to", subject="subject", html_content="content", from_email="from" - ) + with SmtpHook(smtp_conn_id="smtp_ssl_extra") as smtp_hook: + smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") assert not mock_smtp.called - assert not create_default_context.called + create_default_context.assert_not_called() mock_smtp_ssl.assert_called_once_with(host="smtp_server_address", port=465, timeout=30, context=None) @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") @patch("ssl.create_default_context") - def test_send_mime_ssl_none_smtp_provider_default_email_context( - self, create_default_context, mock_smtp, mock_smtp_ssl - ): + def 
test_send_mime_ssl_extra_default_context(self, create_default_context, mock_smtp, mock_smtp_ssl): mock_smtp_ssl.return_value = Mock() - with conf_vars( - { - ("smtp", "smtp_ssl"): "True", - ("email", "ssl_context"): "default", - ("smtp_provider", "ssl_context"): "none", - } - ): - with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp( - to="to", subject="subject", html_content="content", from_email="from" - ) + conn = Connection( + conn_id="smtp_ssl_extra", + conn_type="smtp", + host="smtp_server_address", + login=None, + password="None", + port=465, + extra=json.dumps(dict(use_ssl=True, ssl_context="default", from_email="from")), + ) + db.merge_conn(conn) + with SmtpHook() as smtp_hook: + smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") assert not mock_smtp.called - assert not create_default_context.called - mock_smtp_ssl.assert_called_once_with(host="smtp_server_address", port=465, timeout=30, context=None) + assert create_default_context.called + mock_smtp_ssl.assert_called_once_with( + host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value + ) @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") @patch("ssl.create_default_context") - def test_send_mime_ssl_default_smtp_provider_none_email_context( - self, create_default_context, mock_smtp, mock_smtp_ssl - ): + def test_send_mime_default_context(self, create_default_context, mock_smtp, mock_smtp_ssl): mock_smtp_ssl.return_value = Mock() - with conf_vars( - { - ("smtp", "smtp_ssl"): "True", - ("email", "ssl_context"): "none", - ("smtp_provider", "ssl_context"): "default", - } - ): - with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp( - to="to", subject="subject", html_content="content", from_email="from" - ) + conn = Connection( + conn_id="smtp_ssl_extra", + conn_type="smtp", + host="smtp_server_address", + login=None, + password="None", + port=465, + extra=json.dumps(dict(use_ssl=True, from_email="from")), + ) + db.merge_conn(conn) + with SmtpHook() as smtp_hook: + smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") assert not mock_smtp.called assert create_default_context.called mock_smtp_ssl.assert_called_once_with( diff --git a/providers/tests/smtp/notifications/test_smtp.py b/providers/tests/smtp/notifications/test_smtp.py index aa95f3c6f2fba..bacabf87b4d35 100644 --- a/providers/tests/smtp/notifications/test_smtp.py +++ b/providers/tests/smtp/notifications/test_smtp.py @@ -22,7 +22,6 @@ import pytest -from airflow.configuration import conf from airflow.providers.smtp.hooks.smtp import SmtpHook from airflow.providers.smtp.notifications.smtp import ( SmtpNotifier, @@ -31,7 +30,6 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils import timezone -from tests_common.test_utils.config import conf_vars from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS pytestmark = pytest.mark.db_test @@ -125,14 +123,14 @@ def test_notifier_with_defaults(self, mock_smtphook_hook, create_task_instance): ti = create_task_instance(dag_id="dag", task_id="op", logical_date=timezone.datetime(2018, 1, 1)) context = {"dag": ti.dag_run.dag, "ti": ti} notifier = SmtpNotifier( - from_email=conf.get("smtp", "smtp_mail_from"), + from_email="any email", to="test_reciver@test.com", ) mock_smtphook_hook.return_value.subject_template = None mock_smtphook_hook.return_value.html_content_template = None notifier(context) 
mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email=conf.get("smtp", "smtp_mail_from"), + from_email="any email", to="test_reciver@test.com", subject="DAG dag - Task op - Run ID test in State None", html_content=mock.ANY, @@ -147,50 +145,6 @@ def test_notifier_with_defaults(self, mock_smtphook_hook, create_task_instance): content = mock_smtphook_hook.return_value.__enter__().send_email_smtp.call_args.kwargs["html_content"] assert f"{NUM_TRY} of 1" in content - @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") - def test_notifier_with_nondefault_conf_vars(self, mock_smtphook_hook, create_task_instance): - ti = create_task_instance(dag_id="dag", task_id="op", logical_date=timezone.datetime(2018, 1, 1)) - context = {"dag": ti.dag_run.dag, "ti": ti} - - mock_smtphook_hook.return_value.from_email = None - mock_smtphook_hook.return_value.subject_template = None - mock_smtphook_hook.return_value.html_content_template = None - - with ( - tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, - tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, - ): - f_subject.write("Task {{ ti.task_id }} failed") - f_subject.flush() - - f_content.write("Mock content goes here") - f_content.flush() - - with conf_vars( - { - ("smtp", "templated_html_content_path"): f_content.name, - ("smtp", "templated_email_subject_path"): f_subject.name, - } - ): - notifier = SmtpNotifier( - from_email=conf.get("smtp", "smtp_mail_from"), - to="test_reciver@test.com", - ) - notifier(context) - mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email=conf.get("smtp", "smtp_mail_from"), - to="test_reciver@test.com", - subject="Task op failed", - html_content="Mock content goes here", - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, - ) - @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_with_nondefault_connection_extra(self, mock_smtphook_hook, create_task_instance): ti = create_task_instance(dag_id="dag", task_id="op", logical_date=timezone.datetime(2018, 1, 1)) @@ -206,15 +160,15 @@ def test_notifier_with_nondefault_connection_extra(self, mock_smtphook_hook, cre f_content.write("Mock content goes here") f_content.flush() + mock_smtphook_hook.return_value.from_email = "{{ ti.task_id }}@test.com" mock_smtphook_hook.return_value.subject_template = f_subject.name mock_smtphook_hook.return_value.html_content_template = f_content.name notifier = SmtpNotifier( - from_email=conf.get("smtp", "smtp_mail_from"), to="test_reciver@test.com", ) notifier(context) mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email=conf.get("smtp", "smtp_mail_from"), + from_email="op@test.com", to="test_reciver@test.com", subject="Task op failed", html_content="Mock content goes here", diff --git a/providers/yandex/README.rst b/providers/yandex/README.rst index ef6c09aeba288..756fe2d21fb1f 100644 --- a/providers/yandex/README.rst +++ b/providers/yandex/README.rst @@ -1,3 +1,4 @@ + .. Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information @@ -15,4 +16,50 @@ specific language governing permissions and limitations under the License. -This content will be overridden by pre-commit hook + .. NOTE! 
THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + + .. IF YOU WANT TO MODIFY TEMPLATE FOR THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + `PROVIDER_README_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +Package ``apache-airflow-providers-yandex`` + +Release: ``4.0.0`` + + +This package is for Yandex, including: + + - `Yandex.Cloud `__ + + +Provider package +---------------- + +This is a provider package for ``yandex`` provider. All classes for this provider package +are in ``airflow.providers.yandex`` python package. + +You can find package information and changelog for the provider +in the `documentation `_. + +Installation +------------ + +You can install this package on top of an existing Airflow 2 installation (see ``Requirements`` below +for the minimum Airflow version supported) via +``pip install apache-airflow-providers-yandex`` + +The package supports the following python versions: 3.9,3.10,3.11,3.12 + +Requirements +------------ + +======================= ================== +PIP package Version required +======================= ================== +``apache-airflow`` ``>=2.9.0`` +``yandexcloud`` ``>=0.308.0`` +``yandex-query-client`` ``>=0.1.4`` +======================= ================== + +The changelog for the provider package can be found in the +`changelog `_. diff --git a/providers/yandex/docs/operators/dataproc.rst b/providers/yandex/docs/operators/dataproc.rst index 03dfd3acae817..b7188e2ea52f6 100644 --- a/providers/yandex/docs/operators/dataproc.rst +++ b/providers/yandex/docs/operators/dataproc.rst @@ -34,4 +34,4 @@ that can be integrated with Apache Hadoop and other storage systems. Using the operators ^^^^^^^^^^^^^^^^^^^ To learn how to use Data Proc operators, -see `example DAGs `_. +see `example DAGs `_. diff --git a/providers/yandex/docs/operators/yq.rst b/providers/yandex/docs/operators/yq.rst index 23bd4ac336160..08a90bb817220 100644 --- a/providers/yandex/docs/operators/yq.rst +++ b/providers/yandex/docs/operators/yq.rst @@ -25,4 +25,4 @@ Yandex Query Operators Using the operators ^^^^^^^^^^^^^^^^^^^ To learn how to use Yandex Query operator, -see `example DAG `__. +see `example DAG `__. 
diff --git a/providers/src/airflow/providers/yandex/provider.yaml b/providers/yandex/provider.yaml similarity index 92% rename from providers/src/airflow/providers/yandex/provider.yaml rename to providers/yandex/provider.yaml index 22e06e21478bf..792fb195868c6 100644 --- a/providers/src/airflow/providers/yandex/provider.yaml +++ b/providers/yandex/provider.yaml @@ -52,29 +52,24 @@ versions: - 1.0.1 - 1.0.0 -dependencies: - - apache-airflow>=2.9.0 - - yandexcloud>=0.308.0 - - yandex-query-client>=0.1.4 - integrations: - integration-name: Yandex.Cloud external-doc-url: https://cloud.yandex.com/ - logo: /integration-logos/yandex/Yandex-Cloud.png + logo: /docs/integration-logos/Yandex-Cloud.png tags: [service] - integration-name: Yandex.Cloud Dataproc external-doc-url: https://cloud.yandex.com/dataproc how-to-guide: - /docs/apache-airflow-providers-yandex/operators/dataproc.rst - logo: /integration-logos/yandex/Yandex-Cloud.png + logo: /docs/integration-logos/Yandex-Cloud.png tags: [service] - integration-name: Yandex.Cloud YQ external-doc-url: https://cloud.yandex.com/en/services/query how-to-guide: - /docs/apache-airflow-providers-yandex/operators/yq.rst - logo: /integration-logos/yandex/Yandex-Cloud.png + logo: /docs/integration-logos/Yandex-Cloud.png tags: [service] operators: diff --git a/providers/yandex/pyproject.toml b/providers/yandex/pyproject.toml new file mode 100644 index 0000000000000..375951ae5d876 --- /dev/null +++ b/providers/yandex/pyproject.toml @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! 
+ +# IF YOU WANT TO MODIFY THIS FILE EXCEPT DEPENDENCIES, YOU SHOULD MODIFY THE TEMPLATE +# `pyproject_TEMPLATE.toml.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +[build-system] +requires = ["flit_core==3.10.1"] +build-backend = "flit_core.buildapi" + +[project] +name = "apache-airflow-providers-yandex" +version = "4.0.0" +description = "Provider package apache-airflow-providers-yandex for Apache Airflow" +readme = "README.rst" +authors = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +maintainers = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +keywords = [ "airflow-provider", "yandex", "airflow", "integration" ] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Framework :: Apache Airflow", + "Framework :: Apache Airflow :: Provider", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: System :: Monitoring", +] +requires-python = "~=3.9" + +# The dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +dependencies = [ + "apache-airflow>=2.9.0", + "yandexcloud>=0.308.0", + "yandex-query-client>=0.1.4", +] + +[project.urls] +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-yandex/4.0.0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-yandex/4.0.0/changelog.html" +"Bug Tracker" = "https://github.com/apache/airflow/issues" +"Source Code" = "https://github.com/apache/airflow" +"Slack Chat" = "https://s.apache.org/airflow-slack" +"Twitter" = "https://x.com/ApacheAirflow" +"YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/" + +[project.entry-points."apache_airflow_provider"] +provider_info = "airflow.providers.yandex.get_provider_info:get_provider_info" + +[tool.flit.module] +name = "airflow.providers.yandex" + +[tool.pytest.ini_options] +ignore = "tests/system/" diff --git a/providers/yandex/src/airflow/providers/yandex/LICENSE b/providers/yandex/src/airflow/providers/yandex/LICENSE new file mode 100644 index 0000000000000..11069edd79019 --- /dev/null +++ b/providers/yandex/src/airflow/providers/yandex/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
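For context before the ``get_provider_info.py`` diff below: the new ``pyproject.toml`` above declares an ``apache_airflow_provider`` entry point pointing at ``airflow.providers.yandex.get_provider_info:get_provider_info``. The following is a hedged, illustrative sketch of how a tool could resolve that entry point and read the generated metadata; Airflow's own provider discovery is more involved, and the selectable ``entry_points`` API shown here requires Python 3.10+::

    # Hedged sketch: resolve the "apache_airflow_provider" entry point declared in
    # providers/yandex/pyproject.toml and call the generated get_provider_info().
    from importlib.metadata import entry_points

    for ep in entry_points(group="apache_airflow_provider"):  # Python 3.10+ selectable API
        info = ep.load()()  # e.g. airflow.providers.yandex.get_provider_info:get_provider_info
        if info.get("package-name") == "apache-airflow-providers-yandex":
            print(info["versions"][0])   # "4.0.0"
            print(info["dependencies"])  # ["apache-airflow>=2.9.0", ...]

The dictionary keys used here (``package-name``, ``versions``, ``dependencies``) match the structure added in the generated file below.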
diff --git a/providers/src/airflow/providers/yandex/__init__.py b/providers/yandex/src/airflow/providers/yandex/__init__.py similarity index 100% rename from providers/src/airflow/providers/yandex/__init__.py rename to providers/yandex/src/airflow/providers/yandex/__init__.py diff --git a/providers/yandex/src/airflow/providers/yandex/get_provider_info.py b/providers/yandex/src/airflow/providers/yandex/get_provider_info.py new file mode 100644 index 0000000000000..c2f62621f04c9 --- /dev/null +++ b/providers/yandex/src/airflow/providers/yandex/get_provider_info.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! +# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `get_provider_info_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +def get_provider_info(): + return { + "package-name": "apache-airflow-providers-yandex", + "name": "Yandex", + "description": "This package is for Yandex, including:\n\n - `Yandex.Cloud `__\n", + "state": "ready", + "source-date-epoch": 1734537491, + "versions": [ + "4.0.0", + "3.12.0", + "3.11.2", + "3.11.1", + "3.11.0", + "3.10.0", + "3.9.0", + "3.8.0", + "3.7.1", + "3.7.0", + "3.6.0", + "3.5.0", + "3.4.0", + "3.3.0", + "3.2.0", + "3.1.0", + "3.0.0", + "2.2.3", + "2.2.2", + "2.2.1", + "2.2.0", + "2.1.0", + "2.0.0", + "1.0.1", + "1.0.0", + ], + "integrations": [ + { + "integration-name": "Yandex.Cloud", + "external-doc-url": "https://cloud.yandex.com/", + "logo": "/docs/integration-logos/Yandex-Cloud.png", + "tags": ["service"], + }, + { + "integration-name": "Yandex.Cloud Dataproc", + "external-doc-url": "https://cloud.yandex.com/dataproc", + "how-to-guide": ["/docs/apache-airflow-providers-yandex/operators/dataproc.rst"], + "logo": "/docs/integration-logos/Yandex-Cloud.png", + "tags": ["service"], + }, + { + "integration-name": "Yandex.Cloud YQ", + "external-doc-url": "https://cloud.yandex.com/en/services/query", + "how-to-guide": ["/docs/apache-airflow-providers-yandex/operators/yq.rst"], + "logo": "/docs/integration-logos/Yandex-Cloud.png", + "tags": ["service"], + }, + ], + "operators": [ + { + "integration-name": "Yandex.Cloud Dataproc", + "python-modules": ["airflow.providers.yandex.operators.dataproc"], + }, + { + "integration-name": "Yandex.Cloud YQ", + "python-modules": ["airflow.providers.yandex.operators.yq"], + }, + ], + "hooks": [ + {"integration-name": "Yandex.Cloud", "python-modules": ["airflow.providers.yandex.hooks.yandex"]}, + { + "integration-name": "Yandex.Cloud Dataproc", + "python-modules": ["airflow.providers.yandex.hooks.dataproc"], + }, + {"integration-name": "Yandex.Cloud YQ", "python-modules": ["airflow.providers.yandex.hooks.yq"]}, + ], + "connection-types": [ 
+ { + "hook-class-name": "airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook", + "connection-type": "yandexcloud", + } + ], + "secrets-backends": ["airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend"], + "extra-links": ["airflow.providers.yandex.links.yq.YQLink"], + "config": { + "yandex": { + "description": "This section contains settings for Yandex Cloud integration.", + "options": { + "sdk_user_agent_prefix": { + "description": "Prefix for User-Agent header in Yandex.Cloud SDK requests\n", + "version_added": "3.6.0", + "type": "string", + "example": None, + "default": "", + } + }, + } + }, + "dependencies": ["apache-airflow>=2.9.0", "yandexcloud>=0.308.0", "yandex-query-client>=0.1.4"], + } diff --git a/providers/tests/cncf/kubernetes/sensors/__init__.py b/providers/yandex/src/airflow/providers/yandex/hooks/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/sensors/__init__.py rename to providers/yandex/src/airflow/providers/yandex/hooks/__init__.py diff --git a/providers/src/airflow/providers/yandex/hooks/dataproc.py b/providers/yandex/src/airflow/providers/yandex/hooks/dataproc.py similarity index 100% rename from providers/src/airflow/providers/yandex/hooks/dataproc.py rename to providers/yandex/src/airflow/providers/yandex/hooks/dataproc.py diff --git a/providers/src/airflow/providers/yandex/hooks/yandex.py b/providers/yandex/src/airflow/providers/yandex/hooks/yandex.py similarity index 100% rename from providers/src/airflow/providers/yandex/hooks/yandex.py rename to providers/yandex/src/airflow/providers/yandex/hooks/yandex.py diff --git a/providers/src/airflow/providers/yandex/hooks/yq.py b/providers/yandex/src/airflow/providers/yandex/hooks/yq.py similarity index 100% rename from providers/src/airflow/providers/yandex/hooks/yq.py rename to providers/yandex/src/airflow/providers/yandex/hooks/yq.py diff --git a/providers/tests/cncf/kubernetes/triggers/__init__.py b/providers/yandex/src/airflow/providers/yandex/links/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/triggers/__init__.py rename to providers/yandex/src/airflow/providers/yandex/links/__init__.py diff --git a/providers/src/airflow/providers/yandex/links/yq.py b/providers/yandex/src/airflow/providers/yandex/links/yq.py similarity index 100% rename from providers/src/airflow/providers/yandex/links/yq.py rename to providers/yandex/src/airflow/providers/yandex/links/yq.py diff --git a/providers/tests/cncf/kubernetes/utils/__init__.py b/providers/yandex/src/airflow/providers/yandex/operators/__init__.py similarity index 100% rename from providers/tests/cncf/kubernetes/utils/__init__.py rename to providers/yandex/src/airflow/providers/yandex/operators/__init__.py diff --git a/providers/src/airflow/providers/yandex/operators/dataproc.py b/providers/yandex/src/airflow/providers/yandex/operators/dataproc.py similarity index 100% rename from providers/src/airflow/providers/yandex/operators/dataproc.py rename to providers/yandex/src/airflow/providers/yandex/operators/dataproc.py diff --git a/providers/src/airflow/providers/yandex/operators/yq.py b/providers/yandex/src/airflow/providers/yandex/operators/yq.py similarity index 100% rename from providers/src/airflow/providers/yandex/operators/yq.py rename to providers/yandex/src/airflow/providers/yandex/operators/yq.py diff --git a/providers/tests/system/cncf/kubernetes/__init__.py b/providers/yandex/src/airflow/providers/yandex/secrets/__init__.py similarity index 100% rename from 
providers/tests/system/cncf/kubernetes/__init__.py rename to providers/yandex/src/airflow/providers/yandex/secrets/__init__.py diff --git a/providers/src/airflow/providers/yandex/secrets/lockbox.py b/providers/yandex/src/airflow/providers/yandex/secrets/lockbox.py similarity index 100% rename from providers/src/airflow/providers/yandex/secrets/lockbox.py rename to providers/yandex/src/airflow/providers/yandex/secrets/lockbox.py index d65131ab2cb1a..0381486bab76e 100644 --- a/providers/src/airflow/providers/yandex/secrets/lockbox.py +++ b/providers/yandex/src/airflow/providers/yandex/secrets/lockbox.py @@ -21,14 +21,14 @@ from functools import cached_property from typing import Any +import yandexcloud + import yandex.cloud.lockbox.v1.payload_pb2 as payload_pb import yandex.cloud.lockbox.v1.payload_service_pb2 as payload_service_pb import yandex.cloud.lockbox.v1.payload_service_pb2_grpc as payload_service_pb_grpc import yandex.cloud.lockbox.v1.secret_pb2 as secret_pb import yandex.cloud.lockbox.v1.secret_service_pb2 as secret_service_pb import yandex.cloud.lockbox.v1.secret_service_pb2_grpc as secret_service_pb_grpc -import yandexcloud - from airflow.models import Connection from airflow.providers.yandex.utils.credentials import get_credentials from airflow.providers.yandex.utils.defaults import default_conn_name diff --git a/providers/tests/system/microsoft/winrm/__init__.py b/providers/yandex/src/airflow/providers/yandex/utils/__init__.py similarity index 100% rename from providers/tests/system/microsoft/winrm/__init__.py rename to providers/yandex/src/airflow/providers/yandex/utils/__init__.py diff --git a/providers/src/airflow/providers/yandex/utils/credentials.py b/providers/yandex/src/airflow/providers/yandex/utils/credentials.py similarity index 100% rename from providers/src/airflow/providers/yandex/utils/credentials.py rename to providers/yandex/src/airflow/providers/yandex/utils/credentials.py diff --git a/providers/src/airflow/providers/yandex/utils/defaults.py b/providers/yandex/src/airflow/providers/yandex/utils/defaults.py similarity index 100% rename from providers/src/airflow/providers/yandex/utils/defaults.py rename to providers/yandex/src/airflow/providers/yandex/utils/defaults.py diff --git a/providers/src/airflow/providers/yandex/utils/fields.py b/providers/yandex/src/airflow/providers/yandex/utils/fields.py similarity index 100% rename from providers/src/airflow/providers/yandex/utils/fields.py rename to providers/yandex/src/airflow/providers/yandex/utils/fields.py diff --git a/providers/src/airflow/providers/yandex/utils/user_agent.py b/providers/yandex/src/airflow/providers/yandex/utils/user_agent.py similarity index 100% rename from providers/src/airflow/providers/yandex/utils/user_agent.py rename to providers/yandex/src/airflow/providers/yandex/utils/user_agent.py diff --git a/providers/yandex/tests/conftest.py b/providers/yandex/tests/conftest.py new file mode 100644 index 0000000000000..068fe6bbf5ae9 --- /dev/null +++ b/providers/yandex/tests/conftest.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pathlib + +import pytest + +pytest_plugins = "tests_common.pytest_plugin" + + +@pytest.hookimpl(tryfirst=True) +def pytest_configure(config: pytest.Config) -> None: + deprecations_ignore_path = pathlib.Path(__file__).parent.joinpath("deprecations_ignore.yml") + dep_path = [deprecations_ignore_path] if deprecations_ignore_path.exists() else [] + config.inicfg["airflow_deprecations_ignore"] = ( + config.inicfg.get("airflow_deprecations_ignore", []) + dep_path # type: ignore[assignment,operator] + ) diff --git a/providers/yandex/tests/provider_tests/__init__.py b/providers/yandex/tests/provider_tests/__init__.py new file mode 100644 index 0000000000000..e8fd22856438c --- /dev/null +++ b/providers/yandex/tests/provider_tests/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/tests/system/yandex/__init__.py b/providers/yandex/tests/provider_tests/yandex/__init__.py similarity index 100% rename from providers/tests/system/yandex/__init__.py rename to providers/yandex/tests/provider_tests/yandex/__init__.py diff --git a/providers/tests/yandex/__init__.py b/providers/yandex/tests/provider_tests/yandex/hooks/__init__.py similarity index 100% rename from providers/tests/yandex/__init__.py rename to providers/yandex/tests/provider_tests/yandex/hooks/__init__.py diff --git a/providers/tests/yandex/hooks/test_dataproc.py b/providers/yandex/tests/provider_tests/yandex/hooks/test_dataproc.py similarity index 99% rename from providers/tests/yandex/hooks/test_dataproc.py rename to providers/yandex/tests/provider_tests/yandex/hooks/test_dataproc.py index c960e573c1fd6..212d1e3a9c05a 100644 --- a/providers/tests/yandex/hooks/test_dataproc.py +++ b/providers/yandex/tests/provider_tests/yandex/hooks/test_dataproc.py @@ -23,8 +23,8 @@ yandexlcloud = pytest.importorskip("yandexcloud") -from airflow.models import Connection -from airflow.providers.yandex.hooks.dataproc import DataprocHook +from airflow.models import Connection # noqa: E402 +from airflow.providers.yandex.hooks.dataproc import DataprocHook # noqa: E402 # Airflow connection with type "yandexcloud" must be created CONNECTION_ID = "yandexcloud_default" diff --git a/providers/tests/yandex/hooks/test_yandex.py b/providers/yandex/tests/provider_tests/yandex/hooks/test_yandex.py similarity index 100% rename from providers/tests/yandex/hooks/test_yandex.py rename to providers/yandex/tests/provider_tests/yandex/hooks/test_yandex.py index 92188cd07f319..7907c6a434001 100644 --- a/providers/tests/yandex/hooks/test_yandex.py +++ b/providers/yandex/tests/provider_tests/yandex/hooks/test_yandex.py @@ -21,12 +21,12 @@ import pytest -yandexcloud = pytest.importorskip("yandexcloud") - from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook from tests_common.test_utils.config import conf_vars +yandexcloud = pytest.importorskip("yandexcloud") + class TestYandexHook: @mock.patch("airflow.hooks.base.BaseHook.get_connection") diff --git a/providers/tests/yandex/hooks/test_yq.py b/providers/yandex/tests/provider_tests/yandex/hooks/test_yq.py similarity index 99% rename from providers/tests/yandex/hooks/test_yq.py rename to providers/yandex/tests/provider_tests/yandex/hooks/test_yq.py index 3defe65d6c8e5..c93a5f53895f6 100644 --- a/providers/tests/yandex/hooks/test_yq.py +++ b/providers/yandex/tests/provider_tests/yandex/hooks/test_yq.py @@ -21,15 +21,14 @@ from unittest import mock import pytest - -yandexcloud = pytest.importorskip("yandexcloud") - import responses from responses import matchers from airflow.models import Connection from airflow.providers.yandex.hooks.yq import YQHook +yandexcloud = pytest.importorskip("yandexcloud") + OAUTH_TOKEN = "my_oauth_token" IAM_TOKEN = "my_iam_token" SERVICE_ACCOUNT_AUTH_KEY_JSON = """{"id":"my_id", "service_account_id":"my_sa1", "private_key":"my_pk"}""" diff --git a/providers/tests/yandex/hooks/__init__.py b/providers/yandex/tests/provider_tests/yandex/links/__init__.py similarity index 100% rename from providers/tests/yandex/hooks/__init__.py rename to providers/yandex/tests/provider_tests/yandex/links/__init__.py diff --git a/providers/tests/yandex/links/test_yq.py b/providers/yandex/tests/provider_tests/yandex/links/test_yq.py similarity index 100% rename from 
providers/tests/yandex/links/test_yq.py rename to providers/yandex/tests/provider_tests/yandex/links/test_yq.py diff --git a/providers/tests/yandex/links/__init__.py b/providers/yandex/tests/provider_tests/yandex/operators/__init__.py similarity index 100% rename from providers/tests/yandex/links/__init__.py rename to providers/yandex/tests/provider_tests/yandex/operators/__init__.py diff --git a/providers/tests/yandex/operators/test_dataproc.py b/providers/yandex/tests/provider_tests/yandex/operators/test_dataproc.py similarity index 100% rename from providers/tests/yandex/operators/test_dataproc.py rename to providers/yandex/tests/provider_tests/yandex/operators/test_dataproc.py index 649631547de75..cde731d2d407b 100644 --- a/providers/tests/yandex/operators/test_dataproc.py +++ b/providers/yandex/tests/provider_tests/yandex/operators/test_dataproc.py @@ -21,8 +21,6 @@ import pytest -yandexcloud = pytest.importorskip("yandexcloud") - from airflow.models.dag import DAG from airflow.providers.yandex.operators.dataproc import ( DataprocCreateClusterOperator, @@ -33,6 +31,8 @@ DataprocDeleteClusterOperator, ) +yandexcloud = pytest.importorskip("yandexcloud") + # Airflow connection with type "yandexcloud" CONNECTION_ID = "yandexcloud_default" diff --git a/providers/tests/yandex/operators/test_yq.py b/providers/yandex/tests/provider_tests/yandex/operators/test_yq.py similarity index 99% rename from providers/tests/yandex/operators/test_yq.py rename to providers/yandex/tests/provider_tests/yandex/operators/test_yq.py index 3c415f6ac6b65..127e4eb837972 100644 --- a/providers/tests/yandex/operators/test_yq.py +++ b/providers/yandex/tests/provider_tests/yandex/operators/test_yq.py @@ -21,11 +21,6 @@ from unittest.mock import MagicMock, call, patch import pytest - -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -yandexcloud = pytest.importorskip("yandexcloud") - import responses from responses import matchers @@ -33,6 +28,10 @@ from airflow.models.dag import DAG from airflow.providers.yandex.operators.yq import YQExecuteQueryOperator +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +yandexcloud = pytest.importorskip("yandexcloud") + OAUTH_TOKEN = "my_oauth_token" FOLDER_ID = "my_folder_id" diff --git a/providers/tests/yandex/operators/__init__.py b/providers/yandex/tests/provider_tests/yandex/secrets/__init__.py similarity index 100% rename from providers/tests/yandex/operators/__init__.py rename to providers/yandex/tests/provider_tests/yandex/secrets/__init__.py diff --git a/providers/tests/yandex/secrets/test_lockbox.py b/providers/yandex/tests/provider_tests/yandex/secrets/test_lockbox.py similarity index 99% rename from providers/tests/yandex/secrets/test_lockbox.py rename to providers/yandex/tests/provider_tests/yandex/secrets/test_lockbox.py index 7b9ab19d3e87a..148b09931d1c1 100644 --- a/providers/tests/yandex/secrets/test_lockbox.py +++ b/providers/yandex/tests/provider_tests/yandex/secrets/test_lockbox.py @@ -20,9 +20,6 @@ from unittest.mock import MagicMock, Mock, patch import pytest - -yandexcloud = pytest.importorskip("yandexcloud") - import yandex.cloud.lockbox.v1.payload_pb2 as payload_pb import yandex.cloud.lockbox.v1.secret_pb2 as secret_pb import yandex.cloud.lockbox.v1.secret_service_pb2 as secret_service_pb @@ -30,6 +27,8 @@ from airflow.providers.yandex.secrets.lockbox import LockboxSecretBackend from airflow.providers.yandex.utils.defaults import default_conn_name +yandexcloud = pytest.importorskip("yandexcloud") + class 
TestLockboxSecretBackend: @patch("airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend._get_secret_value") diff --git a/providers/tests/yandex/secrets/__init__.py b/providers/yandex/tests/provider_tests/yandex/utils/__init__.py similarity index 100% rename from providers/tests/yandex/secrets/__init__.py rename to providers/yandex/tests/provider_tests/yandex/utils/__init__.py diff --git a/providers/tests/yandex/utils/test_credentials.py b/providers/yandex/tests/provider_tests/yandex/utils/test_credentials.py similarity index 100% rename from providers/tests/yandex/utils/test_credentials.py rename to providers/yandex/tests/provider_tests/yandex/utils/test_credentials.py diff --git a/providers/tests/yandex/utils/test_defaults.py b/providers/yandex/tests/provider_tests/yandex/utils/test_defaults.py similarity index 100% rename from providers/tests/yandex/utils/test_defaults.py rename to providers/yandex/tests/provider_tests/yandex/utils/test_defaults.py diff --git a/providers/tests/yandex/utils/test_fields.py b/providers/yandex/tests/provider_tests/yandex/utils/test_fields.py similarity index 100% rename from providers/tests/yandex/utils/test_fields.py rename to providers/yandex/tests/provider_tests/yandex/utils/test_fields.py diff --git a/providers/tests/yandex/utils/test_user_agent.py b/providers/yandex/tests/provider_tests/yandex/utils/test_user_agent.py similarity index 100% rename from providers/tests/yandex/utils/test_user_agent.py rename to providers/yandex/tests/provider_tests/yandex/utils/test_user_agent.py index 8e017a6e674a8..58cd4d3ed3968 100644 --- a/providers/tests/yandex/utils/test_user_agent.py +++ b/providers/yandex/tests/provider_tests/yandex/utils/test_user_agent.py @@ -20,10 +20,10 @@ import pytest -yandexcloud = pytest.importorskip("yandexcloud") - from airflow.providers.yandex.utils.user_agent import provider_user_agent +yandexcloud = pytest.importorskip("yandexcloud") + def test_provider_user_agent(): user_agent = provider_user_agent() diff --git a/providers/tests/yandex/utils/__init__.py b/providers/yandex/tests/system/yandex/__init__.py similarity index 100% rename from providers/tests/yandex/utils/__init__.py rename to providers/yandex/tests/system/yandex/__init__.py diff --git a/providers/tests/system/yandex/example_yandexcloud.py b/providers/yandex/tests/system/yandex/example_yandexcloud.py similarity index 100% rename from providers/tests/system/yandex/example_yandexcloud.py rename to providers/yandex/tests/system/yandex/example_yandexcloud.py index 3cb7226208baa..1e1d4ae417d9a 100644 --- a/providers/tests/system/yandex/example_yandexcloud.py +++ b/providers/yandex/tests/system/yandex/example_yandexcloud.py @@ -18,6 +18,9 @@ from datetime import datetime +from google.protobuf.json_format import MessageToDict +from yandexcloud.operations import OperationError + import yandex.cloud.dataproc.v1.cluster_pb2 as cluster_pb import yandex.cloud.dataproc.v1.cluster_service_pb2 as cluster_service_pb import yandex.cloud.dataproc.v1.cluster_service_pb2_grpc as cluster_service_grpc_pb @@ -26,9 +29,6 @@ import yandex.cloud.dataproc.v1.job_service_pb2 as job_service_pb import yandex.cloud.dataproc.v1.job_service_pb2_grpc as job_service_grpc_pb import yandex.cloud.dataproc.v1.subcluster_pb2 as subcluster_pb -from google.protobuf.json_format import MessageToDict -from yandexcloud.operations import OperationError - from airflow import DAG from airflow.decorators import task from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook diff --git 
a/providers/tests/system/yandex/example_yandexcloud_dataproc.py b/providers/yandex/tests/system/yandex/example_yandexcloud_dataproc.py similarity index 100% rename from providers/tests/system/yandex/example_yandexcloud_dataproc.py rename to providers/yandex/tests/system/yandex/example_yandexcloud_dataproc.py diff --git a/providers/tests/system/yandex/example_yandexcloud_dataproc_lightweight.py b/providers/yandex/tests/system/yandex/example_yandexcloud_dataproc_lightweight.py similarity index 100% rename from providers/tests/system/yandex/example_yandexcloud_dataproc_lightweight.py rename to providers/yandex/tests/system/yandex/example_yandexcloud_dataproc_lightweight.py diff --git a/providers/tests/system/yandex/example_yandexcloud_yq.py b/providers/yandex/tests/system/yandex/example_yandexcloud_yq.py similarity index 100% rename from providers/tests/system/yandex/example_yandexcloud_yq.py rename to providers/yandex/tests/system/yandex/example_yandexcloud_yq.py diff --git a/pyproject.toml b/pyproject.toml index 8f024bf37b6b2..52856e41ae4c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -644,8 +644,10 @@ dev = [ "apache-airflow-providers-apache-cassandra", "apache-airflow-providers-apache-drill", "apache-airflow-providers-apache-druid", + "apache-airflow-providers-apache-flink", "apache-airflow-providers-apache-hive", "apache-airflow-providers-apache-iceberg", + "apache-airflow-providers-apache-impala", "apache-airflow-providers-apache-kafka", "apache-airflow-providers-apache-kylin", "apache-airflow-providers-apache-livy", @@ -656,6 +658,8 @@ dev = [ "apache-airflow-providers-asana", "apache-airflow-providers-atlassian-jira", "apache-airflow-providers-celery", + "apache-airflow-providers-cloudant", + "apache-airflow-providers-cncf-kubernetes", "apache-airflow-providers-cohere", "apache-airflow-providers-common-compat", "apache-airflow-providers-common-io", @@ -679,6 +683,7 @@ dev = [ "apache-airflow-providers-microsoft-mssql", "apache-airflow-providers-microsoft-psrp", "apache-airflow-providers-jdbc", + "apache-airflow-providers-microsoft-winrm", "apache-airflow-providers-mongo", "apache-airflow-providers-openlineage", "apache-airflow-providers-hashicorp", @@ -717,6 +722,7 @@ dev = [ "apache-airflow-providers-trino", "apache-airflow-providers-vertica", "apache-airflow-providers-weaviate", + "apache-airflow-providers-yandex", "apache-airflow-providers-ydb", "apache-airflow-providers-zendesk", "apache-airflow-task-sdk", @@ -732,8 +738,10 @@ apache-airflow-providers-apache-beam = { workspace = true } apache-airflow-providers-apache-cassandra = { workspace = true } apache-airflow-providers-apache-drill = { workspace = true } apache-airflow-providers-apache-druid = { workspace = true } +apache-airflow-providers-apache-flink = { workspace = true } apache-airflow-providers-apache-hive = { workspace = true } apache-airflow-providers-apache-iceberg = {workspace = true} +apache-airflow-providers-apache-impala = { workspace = true } apache-airflow-providers-apache-kafka = { workspace = true } apache-airflow-providers-apache-kylin = { workspace = true } apache-airflow-providers-apache-livy = { workspace = true } @@ -744,6 +752,8 @@ apache-airflow-providers-apprise = { workspace = true } apache-airflow-providers-asana = { workspace = true } apache-airflow-providers-atlassian-jira = { workspace = true } apache-airflow-providers-celery = {workspace = true} +apache-airflow-providers-cloudant = { workspace = true } +apache-airflow-providers-cncf-kubernetes = { workspace = true } 
apache-airflow-providers-cohere = { workspace = true } apache-airflow-providers-common-compat = { workspace = true } apache-airflow-providers-common-io = { workspace = true } @@ -767,6 +777,7 @@ apache-airflow-providers-influxdb = { workspace = true } apache-airflow-providers-microsoft-mssql = { workspace = true } apache-airflow-providers-microsoft-psrp = { workspace = true } apache-airflow-providers-jdbc = { workspace = true } +apache-airflow-providers-microsoft-winrm = { workspace = true } apache-airflow-providers-mongo = { workspace = true } apache-airflow-providers-openlineage = { workspace = true } apache-airflow-providers-hashicorp = { workspace = true } @@ -805,6 +816,7 @@ apache-airflow-providers-teradata = { workspace = true } apache-airflow-providers-trino = { workspace = true } apache-airflow-providers-vertica = { workspace = true } apache-airflow-providers-weaviate = { workspace = true } +apache-airflow-providers-yandex = { workspace = true } apache-airflow-providers-ydb = { workspace = true } apache-airflow-providers-zendesk = { workspace = true } apache-airflow-task-sdk = { workspace = true } @@ -818,8 +830,10 @@ members = [ "providers/apache/cassandra", "providers/apache/drill", "providers/apache/druid", + "providers/apache/flink", "providers/apache/hive", "providers/apache/iceberg", + "providers/apache/impala", "providers/apache/kafka", "providers/apache/kylin", "providers/apache/livy", @@ -830,6 +844,8 @@ members = [ "providers/asana", "providers/atlassian/jira", "providers/celery", + "providers/cloudant", + "providers/cncf/kubernetes", "providers/cohere", "providers/common/compat", "providers/common/io", @@ -856,6 +872,7 @@ members = [ "providers/jenkins", "providers/microsoft/mssql", "providers/microsoft/psrp", + "providers/microsoft/winrm", "providers/mongo", "providers/mysql", "providers/neo4j", @@ -891,6 +908,7 @@ members = [ "providers/trino", "providers/vertica", "providers/weaviate", + "providers/yandex", "providers/ydb", "providers/zendesk", "task_sdk", diff --git a/scripts/ci/docker-compose/remove-sources.yml b/scripts/ci/docker-compose/remove-sources.yml index 056da8cf4fa7b..13a9d7d18fc85 100644 --- a/scripts/ci/docker-compose/remove-sources.yml +++ b/scripts/ci/docker-compose/remove-sources.yml @@ -38,8 +38,10 @@ services: - ../../../empty:/opt/airflow/providers/apache/cassandra/src - ../../../empty:/opt/airflow/providers/apache/drill/src - ../../../empty:/opt/airflow/providers/apache/druid/src + - ../../../empty:/opt/airflow/providers/apache/flink/src - ../../../empty:/opt/airflow/providers/apache/hive/src - ../../../empty:/opt/airflow/providers/apache/iceberg/src + - ../../../empty:/opt/airflow/providers/apache/impala/src - ../../../empty:/opt/airflow/providers/apache/kafka/src - ../../../empty:/opt/airflow/providers/apache/kylin/src - ../../../empty:/opt/airflow/providers/apache/livy/src @@ -50,6 +52,8 @@ services: - ../../../empty:/opt/airflow/providers/asana/src - ../../../empty:/opt/airflow/providers/atlassian/jira/src - ../../../empty:/opt/airflow/providers/celery/src + - ../../../empty:/opt/airflow/providers/cloudant/src + - ../../../empty:/opt/airflow/providers/cncf/kubernetes/src - ../../../empty:/opt/airflow/providers/cohere/src - ../../../empty:/opt/airflow/providers/common/compat/src - ../../../empty:/opt/airflow/providers/common/io/src @@ -76,6 +80,7 @@ services: - ../../../empty:/opt/airflow/providers/jenkins/src - ../../../empty:/opt/airflow/providers/microsoft/mssql/src - ../../../empty:/opt/airflow/providers/microsoft/psrp/src + - 
../../../empty:/opt/airflow/providers/microsoft/winrm/src - ../../../empty:/opt/airflow/providers/mongo/src - ../../../empty:/opt/airflow/providers/mysql/src - ../../../empty:/opt/airflow/providers/neo4j/src @@ -111,6 +116,7 @@ services: - ../../../empty:/opt/airflow/providers/trino/src - ../../../empty:/opt/airflow/providers/vertica/src - ../../../empty:/opt/airflow/providers/weaviate/src + - ../../../empty:/opt/airflow/providers/yandex/src - ../../../empty:/opt/airflow/providers/ydb/src - ../../../empty:/opt/airflow/providers/zendesk/src # END automatically generated volumes by generate-volumes-for-sources pre-commit diff --git a/scripts/ci/docker-compose/tests-sources.yml b/scripts/ci/docker-compose/tests-sources.yml index a1214bfb32229..80d3e6aa58a4c 100644 --- a/scripts/ci/docker-compose/tests-sources.yml +++ b/scripts/ci/docker-compose/tests-sources.yml @@ -45,8 +45,10 @@ services: - ../../../providers/apache/cassandra/tests:/opt/airflow/providers/apache/cassandra/tests - ../../../providers/apache/drill/tests:/opt/airflow/providers/apache/drill/tests - ../../../providers/apache/druid/tests:/opt/airflow/providers/apache/druid/tests + - ../../../providers/apache/flink/tests:/opt/airflow/providers/apache/flink/tests - ../../../providers/apache/hive/tests:/opt/airflow/providers/apache/hive/tests - ../../../providers/apache/iceberg/tests:/opt/airflow/providers/apache/iceberg/tests + - ../../../providers/apache/impala/tests:/opt/airflow/providers/apache/impala/tests - ../../../providers/apache/kafka/tests:/opt/airflow/providers/apache/kafka/tests - ../../../providers/apache/kylin/tests:/opt/airflow/providers/apache/kylin/tests - ../../../providers/apache/livy/tests:/opt/airflow/providers/apache/livy/tests @@ -57,6 +59,8 @@ services: - ../../../providers/asana/tests:/opt/airflow/providers/asana/tests - ../../../providers/atlassian/jira/tests:/opt/airflow/providers/atlassian/jira/tests - ../../../providers/celery/tests:/opt/airflow/providers/celery/tests + - ../../../providers/cloudant/tests:/opt/airflow/providers/cloudant/tests + - ../../../providers/cncf/kubernetes/tests:/opt/airflow/providers/cncf/kubernetes/tests - ../../../providers/cohere/tests:/opt/airflow/providers/cohere/tests - ../../../providers/common/compat/tests:/opt/airflow/providers/common/compat/tests - ../../../providers/common/io/tests:/opt/airflow/providers/common/io/tests @@ -83,6 +87,7 @@ services: - ../../../providers/jenkins/tests:/opt/airflow/providers/jenkins/tests - ../../../providers/microsoft/mssql/tests:/opt/airflow/providers/microsoft/mssql/tests - ../../../providers/microsoft/psrp/tests:/opt/airflow/providers/microsoft/psrp/tests + - ../../../providers/microsoft/winrm/tests:/opt/airflow/providers/microsoft/winrm/tests - ../../../providers/mongo/tests:/opt/airflow/providers/mongo/tests - ../../../providers/mysql/tests:/opt/airflow/providers/mysql/tests - ../../../providers/neo4j/tests:/opt/airflow/providers/neo4j/tests @@ -118,6 +123,7 @@ services: - ../../../providers/trino/tests:/opt/airflow/providers/trino/tests - ../../../providers/vertica/tests:/opt/airflow/providers/vertica/tests - ../../../providers/weaviate/tests:/opt/airflow/providers/weaviate/tests + - ../../../providers/yandex/tests:/opt/airflow/providers/yandex/tests - ../../../providers/ydb/tests:/opt/airflow/providers/ydb/tests - ../../../providers/zendesk/tests:/opt/airflow/providers/zendesk/tests # END automatically generated volumes by generate-volumes-for-sources pre-commit diff --git a/scripts/ci/kubernetes/k8s_requirements.txt 
b/scripts/ci/kubernetes/k8s_requirements.txt index b25efb0abfd9f..7fd367c2e014b 100644 --- a/scripts/ci/kubernetes/k8s_requirements.txt +++ b/scripts/ci/kubernetes/k8s_requirements.txt @@ -2,3 +2,4 @@ -e ./providers/standard -e ./providers -e ./task_sdk +-e ./providers/cncf/kubernetes diff --git a/scripts/ci/pre_commit/compile_ui_assets_dev.py b/scripts/ci/pre_commit/compile_ui_assets_dev.py index d820db8701eba..9cc2f985ee9b5 100755 --- a/scripts/ci/pre_commit/compile_ui_assets_dev.py +++ b/scripts/ci/pre_commit/compile_ui_assets_dev.py @@ -34,23 +34,57 @@ AIRFLOW_SOURCES_PATH = Path(__file__).parents[3].resolve() UI_CACHE_DIR = AIRFLOW_SOURCES_PATH / ".build" / "ui" + + +UI_DIRECTORY = AIRFLOW_SOURCES_PATH / "airflow" / "ui" UI_HASH_FILE = UI_CACHE_DIR / "hash.txt" UI_ASSET_OUT_FILE = UI_CACHE_DIR / "asset_compile.out" UI_ASSET_OUT_DEV_MODE_FILE = UI_CACHE_DIR / "asset_compile_dev_mode.out" + +SIMPLE_AUTH_MANAGER_UI_DIRECTORY = AIRFLOW_SOURCES_PATH / "airflow" / "auth" / "managers" / "simple" / "ui" +SIMPLE_AUTH_MANAGER_UI_HASH_FILE = UI_CACHE_DIR / "simple-auth-manager-hash.txt" +SIMPLE_AUTH_MANAGER_UI_ASSET_OUT_FILE = UI_CACHE_DIR / "simple_auth_manager_asset_compile.out" +SIMPLE_AUTH_MANAGER_UI_ASSET_OUT_DEV_MODE_FILE = ( + UI_CACHE_DIR / "simple_auth_manager_asset_compile_dev_mode.out" +) + if __name__ == "__main__": - ui_directory = AIRFLOW_SOURCES_PATH / "airflow" / "ui" UI_CACHE_DIR.mkdir(parents=True, exist_ok=True) + + env = os.environ.copy() + env["FORCE_COLOR"] = "true" + if UI_HASH_FILE.exists(): # cleanup hash of ui so that next compile-assets recompiles them UI_HASH_FILE.unlink() - env = os.environ.copy() - env["FORCE_COLOR"] = "true" UI_ASSET_OUT_FILE.unlink(missing_ok=True) + + if SIMPLE_AUTH_MANAGER_UI_HASH_FILE.exists(): + # cleanup hash of ui so that next compile-assets recompiles them + SIMPLE_AUTH_MANAGER_UI_HASH_FILE.unlink() + SIMPLE_AUTH_MANAGER_UI_ASSET_OUT_FILE.unlink(missing_ok=True) + with open(UI_ASSET_OUT_DEV_MODE_FILE, "w") as f: subprocess.run( ["pnpm", "install", "--frozen-lockfile", "--config.confirmModulesPurge=false"], - cwd=os.fspath(ui_directory), + cwd=os.fspath(UI_DIRECTORY), + check=True, + stdout=f, + stderr=subprocess.STDOUT, + ) + subprocess.Popen( + ["pnpm", "dev"], + cwd=os.fspath(UI_DIRECTORY), + env=env, + stdout=f, + stderr=subprocess.STDOUT, + ) + + with open(SIMPLE_AUTH_MANAGER_UI_ASSET_OUT_DEV_MODE_FILE, "w") as f: + subprocess.run( + ["pnpm", "install", "--frozen-lockfile", "--config.confirmModulesPurge=false"], + cwd=os.fspath(SIMPLE_AUTH_MANAGER_UI_DIRECTORY), check=True, stdout=f, stderr=subprocess.STDOUT, @@ -58,7 +92,7 @@ subprocess.run( ["pnpm", "dev"], check=True, - cwd=os.fspath(ui_directory), + cwd=os.fspath(SIMPLE_AUTH_MANAGER_UI_DIRECTORY), env=env, stdout=f, stderr=subprocess.STDOUT, diff --git a/scripts/ci/pre_commit/template_context_key_sync.py b/scripts/ci/pre_commit/template_context_key_sync.py index 33c694f853a89..525a49ca59642 100755 --- a/scripts/ci/pre_commit/template_context_key_sync.py +++ b/scripts/ci/pre_commit/template_context_key_sync.py @@ -32,6 +32,9 @@ CONTEXT_HINT = ROOT_DIR.joinpath("task_sdk", "src", "airflow", "sdk", "definitions", "context.py") TEMPLATES_REF_RST = ROOT_DIR.joinpath("docs", "apache-airflow", "templates-ref.rst") +# These are only conditionally set +IGNORE = {"ds", "ds_nodash", "ts", "ts_nodash", "ts_nodash_with_tz", "logical_date"} + def _iter_template_context_keys_from_original_return() -> typing.Iterator[str]: ti_mod = ast.parse(TASKRUNNER_PY.read_text("utf-8"), str(TASKRUNNER_PY)) @@ 
-70,12 +73,13 @@ def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]: raise ValueError("'context' is not assigned a dictionary literal") yield from extract_keys_from_dict(context_assignment.value) - # Handle keys added conditionally in `if self._ti_context_from_server` + # Handle keys added conditionally in `if x := self._ti_context_from_server` for stmt in fn_get_template_context.body: if ( isinstance(stmt, ast.If) - and isinstance(stmt.test, ast.Attribute) - and stmt.test.attr == "_ti_context_from_server" + and isinstance(stmt.test, ast.NamedExpr) + and isinstance(stmt.test.value, ast.Attribute) + and stmt.test.value.attr == "_ti_context_from_server" ): for sub_stmt in stmt.body: # Get keys from `context_from_server` assignment @@ -154,7 +158,7 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str], ("Context type hint", hint_keys), ("templates-ref", docs_keys), ] - canonical_keys = set.union(*(s for _, s in check_candidates)) + canonical_keys = set.union(*(s for _, s in check_candidates)) - IGNORE def _check_one(identifier: str, keys: set[str]) -> int: if missing := canonical_keys.difference(keys): diff --git a/scripts/docker/entrypoint_ci.sh b/scripts/docker/entrypoint_ci.sh index 826e05e045378..8b1cbdd622205 100755 --- a/scripts/docker/entrypoint_ci.sh +++ b/scripts/docker/entrypoint_ci.sh @@ -250,17 +250,12 @@ function check_boto_upgrade() { echo echo "${COLOR_BLUE}Upgrading boto3, botocore to latest version to run Amazon tests with them${COLOR_RESET}" echo - # shellcheck disable=SC2086 - ${PACKAGING_TOOL_CMD} uninstall ${EXTRA_UNINSTALL_FLAGS} aiobotocore s3fs yandexcloud opensearch-py || true - # We need to include few dependencies to pass pip check with other dependencies: - # * oss2 as dependency as otherwise jmespath will be bumped (sync with alibaba provider) - # * cryptography is kept for snowflake-connector-python limitation (sync with snowflake provider) set -x # shellcheck disable=SC2086 - ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade boto3 botocore \ - "oss2>=2.14.0" "cryptography<43.0.0" "opensearch-py" + ${PACKAGING_TOOL_CMD} uninstall ${EXTRA_UNINSTALL_FLAGS} aiobotocore s3fs || true + # shellcheck disable=SC2086 + ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade boto3 botocore set +x - pip check } # Download minimum supported version of sqlalchemy to run tests with it diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index 06b3a9404065f..22872e3f4f9d6 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "msgspec>=0.19.0", "psutil>=6.1.0", "structlog>=24.4.0", - "retryhttp>=1.2.0", + "retryhttp>=1.2.0,!=1.3.0", ] classifiers = [ "Framework :: Apache Airflow", @@ -106,7 +106,7 @@ exclude_also = [ [dependency-groups] codegen = [ - "datamodel-code-generator[http]>=0.26.3", + "datamodel-code-generator[http]>=0.26.5", ] [tool.black] diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 821e589ad522f..b1a85fc78b2c8 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -39,6 +39,7 @@ ConnectionResponse, DagRunType, PrevSuccessfulDagRunResponse, + TerminalStateNonSuccess, TerminalTIState, TIDeferredStatePayload, TIEnterRunningPayload, @@ -132,10 +133,12 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext: resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json()) return 
TIRunContext.model_validate_json(resp.read()) - def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime): + def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime): """Tell the API server that this TI has reached a terminal state.""" + if state == TerminalTIState.SUCCESS: + raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead") # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing. - body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state)) + body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state)) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events): @@ -253,6 +256,17 @@ class XComOperations: def __init__(self, client: Client): self.client = client + def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> int: + """Get the number of mapped XCom values.""" + resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}") + + # content_range: str | None + if not (content_range := resp.headers["Content-Range"]) or not content_range.startswith( + "map_indexes " + ): + raise RuntimeError(f"Unable to parse Content-Range header from HEAD {resp.request.url}") + return int(content_range[len("map_indexes ") :]) + def get( self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int | None = None ) -> XComResponse: @@ -260,7 +274,7 @@ def get( # TODO: check if we need to use map_index as params in the uri # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81 params = {} - if map_index is not None: + if map_index is not None and map_index >= 0: params.update({"map_index": map_index}) try: resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) @@ -290,7 +304,7 @@ def set( # TODO: check if we need to use map_index as params in the uri # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81 params = {} - if map_index: + if map_index is not None and map_index >= 0: params = {"map_index": map_index} self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value) # Any error from the server will anyway be propagated down to the supervisor, @@ -344,6 +358,7 @@ def noop_handler(request: httpx.Request) -> httpx.Response: "logical_date": "2021-01-01T00:00:00Z", "start_date": "2021-01-01T00:00:00Z", "run_type": DagRunType.MANUAL, + "run_after": "2021-01-01T00:00:00Z", }, "max_tries": 0, }, @@ -427,7 +442,7 @@ def xcoms(self) -> XComOperations: @lru_cache() # type: ignore[misc] @property def assets(self) -> AssetOperations: - """Operations related to XComs.""" + """Operations related to Assets.""" return AssetOperations(self) diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 1d6d0eb4156c3..ac459e817f2eb 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -1,3 +1,7 @@ +# generated by datamodel-codegen: +# filename: http://0.0.0.0:9091/execution/openapi.json +# version: 0.26.5 + # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file # distributed with this work for additional information @@ -14,11 +18,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# generated by datamodel-codegen: -# filename: http://0.0.0.0:9091/execution/openapi.json -# version: 0.26.3 - from __future__ import annotations from datetime import datetime, timedelta @@ -26,7 +25,7 @@ from typing import Annotated, Any, Literal from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, RootModel class AssetProfile(BaseModel): @@ -38,6 +37,9 @@ class AssetProfile(BaseModel): AssetUriRef will have uri and asset_type defined. """ + model_config = ConfigDict( + extra="forbid", + ) name: Annotated[str | None, Field(title="Name")] = None uri: Annotated[str | None, Field(title="Uri")] = None asset_type: Annotated[str, Field(title="Asset Type")] @@ -94,6 +96,10 @@ class IntermediateTIState(str, Enum): DEFERRED = "deferred" +class JsonValue(RootModel[Any]): + root: Any + + class PrevSuccessfulDagRunResponse(BaseModel): """ Schema for response with previous successful DagRun information for Task Template Context. @@ -110,6 +116,9 @@ class TIDeferredStatePayload(BaseModel): Schema for updating TaskInstance to a deferred state. """ + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred" classpath: Annotated[str, Field(title="Classpath")] trigger_kwargs: Annotated[dict[str, Any] | None, Field(title="Trigger Kwargs")] = None @@ -122,6 +131,9 @@ class TIEnterRunningPayload(BaseModel): Schema for updating TaskInstance to 'RUNNING' state with minimal required fields. """ + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["running"] | None, Field(title="State")] = "running" hostname: Annotated[str, Field(title="Hostname")] unixname: Annotated[str, Field(title="Unixname")] @@ -134,6 +146,9 @@ class TIHeartbeatInfo(BaseModel): Schema for TaskInstance heartbeat endpoint. """ + model_config = ConfigDict( + extra="forbid", + ) hostname: Annotated[str, Field(title="Hostname")] pid: Annotated[int, Field(title="Pid")] @@ -143,6 +158,9 @@ class TIRescheduleStatePayload(BaseModel): Schema for updating TaskInstance to a up_for_reschedule state. """ + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["up_for_reschedule"] | None, Field(title="State")] = "up_for_reschedule" reschedule_date: Annotated[datetime, Field(title="Reschedule Date")] end_date: Annotated[datetime, Field(title="End Date")] @@ -153,6 +171,9 @@ class TIRuntimeCheckPayload(BaseModel): Payload for performing Runtime checks on the TaskInstance model as requested by the SDK. """ + model_config = ConfigDict( + extra="forbid", + ) inlets: Annotated[list[AssetProfile] | None, Field(title="Inlets")] = None outlets: Annotated[list[AssetProfile] | None, Field(title="Outlets")] = None @@ -162,6 +183,9 @@ class TISuccessStatePayload(BaseModel): Schema for updating TaskInstance to success state. """ + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["success"] | None, Field(title="State")] = "success" end_date: Annotated[datetime, Field(title="End Date")] task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task Outlets")] = None @@ -173,15 +197,17 @@ class TITargetStatePayload(BaseModel): Schema for updating TaskInstance to a target state, excluding terminal and running states. 
""" + model_config = ConfigDict( + extra="forbid", + ) state: IntermediateTIState -class TerminalTIState(str, Enum): +class TerminalStateNonSuccess(str, Enum): """ - States that a Task Instance can be in that indicate it has reached a terminal state. + TaskInstance states that can be reported without extra information. """ - SUCCESS = "success" FAILED = "failed" SKIPPED = "skipped" REMOVED = "removed" @@ -224,40 +250,59 @@ class XComResponse(BaseModel): value: Annotated[Any, Field(title="Value")] -class BundleInfo(BaseModel): - name: str - version: str | None = None - - class TaskInstance(BaseModel): """ Schema for TaskInstance model with minimal required fields needed for Runtime. """ + model_config = ConfigDict( + extra="forbid", + ) id: Annotated[UUID, Field(title="Id")] task_id: Annotated[str, Field(title="Task Id")] dag_id: Annotated[str, Field(title="Dag Id")] run_id: Annotated[str, Field(title="Run Id")] try_number: Annotated[int, Field(title="Try Number")] - map_index: Annotated[int, Field(title="Map Index")] = -1 + map_index: Annotated[int | None, Field(title="Map Index")] = -1 hostname: Annotated[str | None, Field(title="Hostname")] = None +class BundleInfo(BaseModel): + """ + Schema for telling task which bundle to run with. + """ + + name: Annotated[str, Field(title="Name")] + version: Annotated[str | None, Field(title="Version")] = None + + +class TerminalTIState(str, Enum): + SUCCESS = "success" + FAILED = "failed" + SKIPPED = "skipped" + REMOVED = "removed" + FAIL_WITHOUT_RETRY = "fail_without_retry" + + class DagRun(BaseModel): """ Schema for DagRun model with minimal required fields needed for Runtime. """ + model_config = ConfigDict( + extra="forbid", + ) dag_id: Annotated[str, Field(title="Dag Id")] run_id: Annotated[str, Field(title="Run Id")] - logical_date: Annotated[datetime, Field(title="Logical Date")] + logical_date: Annotated[datetime | None, Field(title="Logical Date")] data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None + run_after: Annotated[datetime, Field(title="Run After")] start_date: Annotated[datetime, Field(title="Start Date")] end_date: Annotated[datetime | None, Field(title="End Date")] = None run_type: DagRunType conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None - external_trigger: Annotated[bool, Field(title="External Trigger")] = False + external_trigger: Annotated[bool | None, Field(title="External Trigger")] = False class HTTPValidationError(BaseModel): @@ -273,6 +318,7 @@ class TIRunContext(BaseModel): max_tries: Annotated[int, Field(title="Max Tries")] variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None + upstream_map_indexes: Annotated[dict[str, int] | None, Field(title="Upstream Map Indexes")] = None class TITerminalStatePayload(BaseModel): @@ -280,5 +326,8 @@ class TITerminalStatePayload(BaseModel): Schema for updating TaskInstance to a terminal state except SUCCESS state. 
""" - state: TerminalTIState + model_config = ConfigDict( + extra="forbid", + ) + state: TerminalStateNonSuccess end_date: Annotated[datetime, Field(title="End Date")] diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py b/task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py new file mode 100644 index 0000000000000..18e7d11fb08ae --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py @@ -0,0 +1,278 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import functools +import operator +from collections.abc import Iterable, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, ClassVar, Union + +import attrs + +from airflow.sdk.definitions._internal.mixins import ResolveMixin + +if TYPE_CHECKING: + from airflow.sdk.definitions.xcom_arg import XComArg + from airflow.sdk.types import Operator + from airflow.typing_compat import TypeGuard + +ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] + +# Each keyword argument to expand() can be an XComArg, sequence, or dict (not +# any mapping since we need the value to be ordered). +OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] + +# The single argument of expand_kwargs() can be an XComArg, or a list with each +# element being either an XComArg or a dict. +OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] + + +class NotFullyPopulated(RuntimeError): + """ + Raise when ``get_map_lengths`` cannot populate all mapping metadata. + + This is generally due to not all upstream tasks have finished when the + function is called. + """ + + def __init__(self, missing: set[str]) -> None: + self.missing = missing + + def __str__(self) -> str: + keys = ", ".join(repr(k) for k in sorted(self.missing)) + return f"Failed to populate all mapping metadata; missing: {keys}" + + +# To replace tedious isinstance() checks. +def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: + from airflow.sdk.definitions.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) + + +# To replace tedious isinstance() checks. +def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: + from airflow.models.xcom_arg import XComArg + + return not isinstance(v, (MappedArgument, XComArg)) + + +# To replace tedious isinstance() checks. +def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: + from airflow.models.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg)) + + +@attrs.define(kw_only=True) +class MappedArgument(ResolveMixin): + """ + Stand-in stub for task-group-mapping arguments. 
+ + This is very similar to an XComArg, but resolved differently. Declared here + (instead of in the task group module) to avoid import cycles. + """ + + _input: ExpandInput = attrs.field() + _key: str + + @_input.validator + def _validate_input(self, _, input): + if isinstance(input, DictOfListsExpandInput): + for value in input.value.values(): + if isinstance(value, MappedArgument): + raise ValueError("Nested Mapped TaskGroups are not yet supported") + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + yield from self._input.iter_references() + + def resolve(self, context: Mapping[str, Any]) -> Any: + data, _ = self._input.resolve(context) + return data[self._key] + + +@attrs.define() +class DictOfListsExpandInput(ResolveMixin): + """ + Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand(**kwargs)``. + """ + + value: dict[str, OperatorExpandArgument] + + EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists" + + def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: + """Generate kwargs with values available on parse-time.""" + return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v)) + + def get_parse_time_mapped_ti_count(self) -> int: + if not self.value: + return 0 + literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()] + if len(literal_values) != len(self.value): + literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs()) + raise NotFullyPopulated(set(self.value).difference(literal_keys)) + return functools.reduce(operator.mul, literal_values, 1) + + def _get_map_lengths( + self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int] + ) -> dict[str, int]: + """ + Return dict of argument name to map length. + + If any arguments are not known right now (upstream task not finished), + they will not be present in the dict. + """ + + # TODO: This initiates one API call for each XComArg. Would it be + # more efficient to do one single call and unpack the value here? + def _get_length(k: str, v: OperatorExpandArgument) -> int | None: + from airflow.sdk.definitions.xcom_arg import XComArg, get_task_map_length + + if isinstance(v, XComArg): + return get_task_map_length(v, resolved_vals[k], upstream_map_indexes) + + # Unfortunately a user-defined TypeGuard cannot apply negative type + # narrowing. https://github.com/python/typing/discussions/1013 + if TYPE_CHECKING: + assert isinstance(v, Sized) + return len(v) + + map_lengths = { + k: res for k, v in self.value.items() if v is not None if (res := _get_length(k, v)) is not None + } + if len(map_lengths) < len(self.value): + raise NotFullyPopulated(set(self.value).difference(map_lengths)) + return map_lengths + + def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any: + def _find_index_for_this_field(index: int) -> int: + # Need to use the original user input to retain argument order. 
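# --- Illustrative aside, not part of the patch --------------------------------
# The loop below decomposes the flat map_index of a cross-product expansion into
# one index per mapped kwarg, using repeated modulo/floor-division over the kwarg
# lengths in reverse declaration order. A minimal standalone sketch of the same
# arithmetic (``split_map_index`` and ``lengths`` are names made up for this
# example):
def split_map_index(map_index: int, lengths: dict[str, int]) -> dict[str, int]:
    """Return the per-argument index of each mapped kwarg for a flat map_index."""
    per_key: dict[str, int] = {}
    for key in reversed(list(lengths)):
        length = lengths[key]
        if length < 1:
            raise RuntimeError(f"cannot expand field mapped to length {length!r}")
        per_key[key] = map_index % length
        map_index //= length
    return per_key

# With lengths {"a": 2, "b": 3} (6 expanded tasks in total), flat index 4 gives
# split_map_index(4, {"a": 2, "b": 3}) == {"b": 1, "a": 1}, i.e. the last kwarg
# varies fastest.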
+ for mapped_key in reversed(self.value): + mapped_length = all_lengths[mapped_key] + if mapped_length < 1: + raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}") + if mapped_key == key: + return index % mapped_length + index //= mapped_length + return -1 + + found_index = _find_index_for_this_field(map_index) + if found_index < 0: + return value + if isinstance(value, Sequence): + return value[found_index] + if not isinstance(value, dict): + raise TypeError(f"can't map over value of type {type(value)}") + for i, (k, v) in enumerate(value.items()): + if i == found_index: + return k, v + raise IndexError(f"index {map_index} is over mapped length") + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.xcom_arg import XComArg + + for x in self.value.values(): + if isinstance(x, XComArg): + yield from x.iter_references() + + def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: + map_index: int | None = context["ti"].map_index + if map_index is None or map_index < 0: + raise RuntimeError("can't resolve task-mapping argument without expanding") + + upstream_map_indexes = getattr(context["ti"], "_upstream_map_indexes", {}) + + # TODO: This initiates one API call for each XComArg. Would it be + # more efficient to do one single call and unpack the value here? + resolved = { + k: v.resolve(context) if _needs_run_time_resolution(v) else v for k, v in self.value.items() + } + + all_lengths = self._get_map_lengths(resolved, upstream_map_indexes) + + data = {k: self._expand_mapped_field(k, v, map_index, all_lengths) for k, v in resolved.items()} + literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()} + resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys} + return data, resolved_oids + + +def _describe_type(value: Any) -> str: + if value is None: + return "None" + return type(value).__name__ + + +@attrs.define() +class ListOfDictsExpandInput(ResolveMixin): + """ + Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand_kwargs(xcom_arg)``. 
+ """ + + value: OperatorExpandKwargsArgument + + EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts" + + def get_parse_time_mapped_ti_count(self) -> int: + if isinstance(self.value, Sized): + return len(self.value) + raise NotFullyPopulated({"expand_kwargs() argument"}) + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.xcom_arg import XComArg + + if isinstance(self.value, XComArg): + yield from self.value.iter_references() + else: + for x in self.value: + if isinstance(x, XComArg): + yield from x.iter_references() + + def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: + map_index = context["ti"].map_index + if map_index < 0: + raise RuntimeError("can't resolve task-mapping argument without expanding") + + mapping: Any = None + if isinstance(self.value, Sized): + mapping = self.value[map_index] + if not isinstance(mapping, Mapping): + mapping = mapping.resolve(context) + else: + mappings = self.value.resolve(context) + if not isinstance(mappings, Sequence): + raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") + mapping = mappings[map_index] + + if not isinstance(mapping, Mapping): + raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]") + + for key in mapping: + if not isinstance(key, str): + raise ValueError( + f"expand_kwargs() input dict keys must all be str, " + f"but {key!r} is of type {_describe_type(key)}" + ) + # filter out parse time resolved values from the resolved_oids + resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)} + + return mapping, resolved_oids diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py index fcd68ba20b2c6..93fd9431cbe38 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py @@ -133,7 +133,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: """ raise NotImplementedError - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context) -> Any: """ Resolve this value for runtime. diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py index d7028d4d6bca7..b50c4dbb3cadb 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py @@ -50,7 +50,7 @@ class LiteralValue(ResolveMixin): def iter_references(self) -> Iterable[tuple[Operator, str]]: return () - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context) -> Any: return self.value @@ -179,7 +179,7 @@ def render_template( return self._render_object_storage_path(value, context, jinja_env) if resolve := getattr(value, "resolve", None): - return resolve(context, include_xcom=True) + return resolve(context) # Fast path for common built-in collections. 
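# --- Illustrative aside, not part of the patch --------------------------------
# With the ``include_xcom`` flag gone, anything exposing ``resolve`` is now called
# with the context alone. A rough standalone sketch of that duck-typed dispatch
# over common container types (``render_str`` here is a trivial stand-in, not the
# real Jinja-backed rendering):
from collections.abc import Mapping as _Mapping

def _render_value(value, context):
    def render_str(s: str) -> str:
        return s.format(**context)  # placeholder for Jinja rendering

    if isinstance(value, str):
        return render_str(value)
    if (resolve := getattr(value, "resolve", None)) is not None:
        return resolve(context)  # e.g. XComArg, LiteralValue, DagParam
    if isinstance(value, (list, tuple)):
        return type(value)(_render_value(v, context) for v in value)
    if isinstance(value, _Mapping):
        return {k: _render_value(v, context) for k, v in value.items()}
    return value

# _render_value(["echo {ds}", 42], {"ds": "2021-01-01"}) == ["echo 2021-01-01", 42]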
if value.__class__ is tuple: diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 5662d542859f7..3c8d99737ab8e 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -359,6 +359,13 @@ class DAG: # argument "description" in call to "DAG"`` etc), so for init=True args we use the `default=Factory()` # style + def __rich_repr__(self): + yield "dag_id", self.dag_id + yield "schedule", self.schedule + yield "#tasks", len(self.tasks) + + __rich_repr__.angular = True # type: ignore[attr-defined] + # NOTE: When updating arguments here, please also keep arguments in @dag() # below in sync. (Search for 'def dag(' in this file.) dag_id: str = attrs.field(kw_only=False, validator=attrs.validators.instance_of(str)) diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py index dddeaf21cf2bc..92f83792f3c79 100644 --- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -27,11 +27,6 @@ import methodtools from airflow.models.abstractoperator import NotMapped -from airflow.models.expandinput import ( - DictOfListsExpandInput, - ListOfDictsExpandInput, - is_mappable, -) from airflow.sdk.definitions._internal.abstractoperator import ( DEFAULT_EXECUTOR, DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -47,13 +42,16 @@ DEFAULT_WEIGHT_RULE, AbstractOperator, ) +from airflow.sdk.definitions._internal.expandinput import ( + DictOfListsExpandInput, + ListOfDictsExpandInput, + is_mappable, +) from airflow.sdk.definitions._internal.types import NOTSET from airflow.serialization.enums import DagAttributeTypes from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy -from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import Literal from airflow.utils.helpers import is_container, prevent_duplicates -from airflow.utils.task_instance_session import get_current_task_instance_session from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: @@ -61,7 +59,6 @@ import jinja2 # Slow import. import pendulum - from sqlalchemy.orm.session import Session from airflow.models.abstractoperator import ( TaskStateChangeCallback, @@ -78,6 +75,7 @@ from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.types import Operator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import TypeGuard from airflow.utils.context import Context from airflow.utils.operator_resources import Resources @@ -683,16 +681,14 @@ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Implement DAGNode.""" return DagAttributeTypes.OP, self.task_id - def _expand_mapped_kwargs( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool - ) -> tuple[Mapping[str, Any], set[int]]: + def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: """ Get the kwargs to create the unmapped operator. This exists because taskflow operators expand against op_kwargs, not the entire operator kwargs dict. 
""" - return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom) + return self._get_specified_expand_input().resolve(context) def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: """ @@ -726,70 +722,7 @@ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) - "params": params, } - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: - """ - Get the start_from_trigger value of the current abstract operator. - - MappedOperator uses this to unmap start_from_trigger to decide whether to start the task - execution directly from triggerer. - - :meta private: - """ - # start_from_trigger only makes sense when start_trigger_args exists. - if not self.start_trigger_args: - return False - - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) - if self._disallow_kwargs_override: - prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) - - # Ordering is significant; mapped kwargs should override partial ones. - return mapped_kwargs.get( - "start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger) - ) - - def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None: - """ - Get the kwargs to create the unmapped start_trigger_args. - - This method is for allowing mapped operator to start execution from triggerer. - """ - if not self.start_trigger_args: - return None - - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) - if self._disallow_kwargs_override: - prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) - - # Ordering is significant; mapped kwargs should override partial ones. - trigger_kwargs = mapped_kwargs.get( - "trigger_kwargs", - self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs), - ) - next_kwargs = mapped_kwargs.get( - "next_kwargs", - self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs), - ) - timeout = mapped_kwargs.get( - "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout) - ) - return StartTriggerArgs( - trigger_cls=self.start_trigger_args.trigger_cls, - trigger_kwargs=trigger_kwargs, - next_method=self.start_trigger_args.next_method, - next_kwargs=next_kwargs, - timeout=timeout, - ) - - def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator: + def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator: """ Get the "normal" Operator after applying the current mapping. @@ -798,30 +731,21 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> not a class (i.e. this DAG has been deserialized), this returns a SerializedBaseOperator that "looks like" the actual unmapping result. - If *resolve* is a two-tuple (context, session), the information is used - to resolve the mapped arguments into init arguments. If it is a mapping, - no resolving happens, the mapping directly provides those init arguments - resolved from mapped kwargs. 
- :meta private: """ if isinstance(self.operator_class, type): if isinstance(resolve, Mapping): kwargs = resolve elif resolve is not None: - kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True) + kwargs, _ = self._expand_mapped_kwargs(*resolve) else: raise RuntimeError("cannot unmap a non-serialized operator without context") kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) is_setup = kwargs.pop("is_setup", False) is_teardown = kwargs.pop("is_teardown", False) on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False) + kwargs["task_id"] = self.task_id op = self.operator_class(**kwargs, _airflow_from_mapped=True) - # We need to overwrite task_id here because BaseOperator further - # mangles the task_id based on the task hierarchy (namely, group_id - # is prepended, and '__N' appended to deduplicate). This is hacky, - # but better than duplicating the whole mangling logic. - op.task_id = self.task_id op.is_setup = is_setup op.is_teardown = is_teardown op.on_failure_fail_dagrun = on_failure_fail_dagrun @@ -856,7 +780,7 @@ def prepare_for_execution(self) -> MappedOperator: def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used by this task for task mapping.""" - from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.xcom_arg import XComArg for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): yield operator @@ -886,23 +810,14 @@ def render_template_fields( :param context: Context dict with values to apply on content. :param jinja_env: Jinja environment to use for rendering. """ - from airflow.utils.context import context_update_for_unmapped + from airflow.sdk.execution_time.context import context_update_for_unmapped if not jinja_env: jinja_env = self.get_template_env() - # We retrieve the session here, stored by _run_raw_task in set_current_task_session - # context manager - we cannot pass the session via @provide_session because the signature - # of render_template_fields is defined by BaseOperator and there are already many subclasses - # overriding it, so changing the signature is not an option. However render_template_fields is - # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the - # set_current_task_session context manager to store the session in the current task. - session = get_current_task_instance_session() - - mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True) + mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context) unmapped_task = self.unmap(mapped_kwargs) - # TODO: Task-SDK: remove arg-type ignore once Kaxil's PR lands - context_update_for_unmapped(context, unmapped_task) # type: ignore[arg-type] + context_update_for_unmapped(context, unmapped_task) # Since the operators that extend `BaseOperator` are not subclasses of # `MappedOperator`, we need to call `_do_render_template_fields` from diff --git a/task_sdk/src/airflow/sdk/definitions/param.py b/task_sdk/src/airflow/sdk/definitions/param.py index cd3ccec26a48a..d9eec82a147fc 100644 --- a/task_sdk/src/airflow/sdk/definitions/param.py +++ b/task_sdk/src/airflow/sdk/definitions/param.py @@ -67,8 +67,7 @@ def _check_json(value): json.dumps(value) except Exception: raise ParamValidationError( - "All provided parameters must be json-serializable. " - f"The value '{value}' is not serializable." + f"All provided parameters must be json-serializable. The value '{value}' is not serializable." 
) def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any: @@ -294,7 +293,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET): def iter_references(self) -> Iterable[tuple[Operator, str]]: return () - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context) -> Any: """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.""" with contextlib.suppress(KeyError): if context["dag_run"].conf: diff --git a/task_sdk/src/airflow/sdk/definitions/xcom_arg.py b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py index 436cd9d005012..42eef8321f4b7 100644 --- a/task_sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -20,30 +20,28 @@ import contextlib import inspect import itertools -from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, Union, overload +from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from functools import singledispatch +from typing import TYPE_CHECKING, Any, Callable, overload from airflow.exceptions import AirflowException, XComNotFound from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet -from airflow.utils.session import NEW_SESSION, provide_session +from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.dag import DAG from airflow.sdk.types import Operator from airflow.utils.edgemodifier import EdgeModifier # Callable objects contained by MapXComArg. We only accept callables from # the user, but deserialize them into strings in a serialized XComArg for # safety (those callables are arbitrary user code). -MapCallables = Sequence[Union[Callable[[Any], Any], str]] +MapCallables = Sequence[Callable[[Any], Any]] class XComArg(ResolveMixin, DependencyMixin): @@ -168,20 +166,6 @@ def _serialize(self) -> dict[str, Any]: """ raise NotImplementedError() - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - """ - Deserialize an XComArg. - - The implementation should be the inverse function to ``serialize``, - implementing given a data dict converted from this XComArg derivative, - how the original XComArg should be created. DAG serialization relies on - additional information added in ``serialize_xcom_arg`` to dispatch data - dicts to the correct ``_deserialize`` information, so this function does - not need to validate whether the incoming data contains correct keys. 
- """ - raise NotImplementedError() - def map(self, f: Callable[[Any], Any]) -> MapXComArg: return MapXComArg(self, [f]) @@ -191,9 +175,7 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: def concat(self, *others: XComArg) -> ConcatXComArg: return ConcatXComArg([self, *others]) - def resolve( - self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True - ) -> Any: + def resolve(self, context: Mapping[str, Any]) -> Any: raise NotImplementedError() def __enter__(self): @@ -266,7 +248,7 @@ def __str__(self) -> str: **Example**: to use XComArg at BashOperator:: - BashOperator(cmd=f"... { xcomarg } ...") + BashOperator(cmd=f"... {xcomarg} ...") :return: """ @@ -285,10 +267,6 @@ def __str__(self) -> str: def _serialize(self) -> dict[str, Any]: return {"task_id": self.operator.task_id, "key": self.key} - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls(dag.get_task(data["task_id"]), data["key"]) - @property def is_setup(self) -> bool: return self.operator.is_setup @@ -354,17 +332,15 @@ def concat(self, *others: XComArg) -> ConcatXComArg: raise ValueError("cannot concatenate non-return XCom") return super().concat(*others) - # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK - def resolve( - self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True - ) -> Any: + def resolve(self, context: Mapping[str, Any]) -> Any: ti = context["ti"] task_id = self.operator.task_id - map_indexes = context.get("_upstream_map_indexes", {}).get(task_id) + + if self.operator.is_mapped: + return LazyXComSequence[Any](xcom_arg=self, ti=ti) result = ti.xcom_pull( task_ids=task_id, - map_indexes=map_indexes, key=self.key, default=NOTSET, ) @@ -403,18 +379,8 @@ def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: def __getitem__(self, index: Any) -> Any: value = self.value[index] - # In the worker, we can access all actual callables. Call them. - callables = [f for f in self.callables if callable(f)] - if len(callables) == len(self.callables): - for f in callables: - value = f(value) - return value - - # In the scheduler, we don't have access to the actual callables, nor do - # we want to run it since it's arbitrary code. This builds a string to - # represent the call chain in the UI or logs instead. - for v in self.callables: - value = f"{_get_callable_name(v)}({value})" + for f in self.callables: + value = f(value) return value def __len__(self) -> int: @@ -448,12 +414,6 @@ def _serialize(self) -> dict[str, Any]: "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], } - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - # We are deliberately NOT deserializing the callables. These are shown - # in the UI, and displaying a function object is useless. - return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) - def iter_references(self) -> Iterator[tuple[Operator, str]]: yield from self.arg.iter_references() @@ -461,12 +421,8 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg: # Flatten arg.map(f1).map(f2) into one MapXComArg. 
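# --- Illustrative aside, not part of the patch --------------------------------
# _MapResult above now simply applies the stored callables on item access, since
# the worker always has the real functions. A stripped-down standalone sketch of
# that lazy behaviour (``LazyMapped`` is a made-up name for this example):
class LazyMapped:
    def __init__(self, value, callables):
        self.value = value          # the upstream sequence being mapped over
        self.callables = callables  # functions applied, in order, on access

    def __getitem__(self, index):
        item = self.value[index]
        for fn in self.callables:
            item = fn(item)
        return item

    def __len__(self):
        return len(self.value)

# Chained maps compose without materialising anything up front:
# LazyMapped([1, 2, 3], [lambda x: x + 1, str])[2] == "4"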
return MapXComArg(self.arg, [*self.callables, f]) - # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - value = self.arg.resolve(context, session=session, include_xcom=include_xcom) + def resolve(self, context: Mapping[str, Any]) -> Any: + value = self.arg.resolve(context) if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") return _MapResult(value, self.callables) @@ -525,22 +481,12 @@ def _serialize(self) -> dict[str, Any]: return {"args": args} return {"args": args, "fillvalue": self.fillvalue} - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls( - [deserialize_xcom_arg(arg, dag) for arg in data["args"]], - fillvalue=data.get("fillvalue", NOTSET), - ) - def iter_references(self) -> Iterator[tuple[Operator, str]]: for arg in self.args: yield from arg.iter_references() - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + def resolve(self, context: Mapping[str, Any]) -> Any: + values = [arg.resolve(context) for arg in self.args] for value in values: if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") @@ -594,10 +540,6 @@ def __repr__(self) -> str: def _serialize(self) -> dict[str, Any]: return {"args": [serialize_xcom_arg(arg) for arg in self.args]} - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) - def iter_references(self) -> Iterator[tuple[Operator, str]]: for arg in self.args: yield from arg.iter_references() @@ -606,11 +548,8 @@ def concat(self, *others: XComArg) -> ConcatXComArg: # Flatten foo.concat(x).concat(y) into one call. return ConcatXComArg([*self.args, *others]) - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + def resolve(self, context: Mapping[str, Any]) -> Any: + values = [arg.resolve(context) for arg in self.args] for value in values: if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}") @@ -633,7 +572,44 @@ def serialize_xcom_arg(value: XComArg) -> dict[str, Any]: return value._serialize() -def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg: - """DAG serialization interface.""" - klass = _XCOM_ARG_TYPES[data.get("type", "")] - return klass._deserialize(data, dag) +@singledispatch +def get_task_map_length( + xcom_arg: XComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int] +) -> int | None: + # The base implementation -- specific XComArg subclasses have specialised implementations + raise NotImplementedError() + + +@get_task_map_length.register +def _(xcom_arg: PlainXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + task_id = xcom_arg.operator.task_id + + if xcom_arg.operator.is_mapped: + # TODO: How to tell if all the upstream TIs finished? 
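# --- Illustrative aside, not part of the patch --------------------------------
# ``get_task_map_length`` above is built on functools.singledispatch: the base
# implementation raises, and each XComArg flavour registers its own rule further
# down (zip takes the min, or the max when a fillvalue pads the shorter inputs;
# concat sums). A toy standalone version of the same dispatch shape, with
# made-up classes:
from dataclasses import dataclass
from functools import singledispatch

@dataclass
class Plain:
    length: int

@dataclass
class Zipped:
    args: list
    fillvalue: object = None

@singledispatch
def map_length(arg) -> int:
    raise NotImplementedError(type(arg).__name__)

@map_length.register
def _(arg: Plain) -> int:
    return arg.length

@map_length.register
def _(arg: Zipped) -> int:
    lengths = [map_length(a) for a in arg.args]
    return min(lengths) if arg.fillvalue is None else max(lengths)

# map_length(Zipped([Plain(2), Plain(5)])) == 2
# map_length(Zipped([Plain(2), Plain(5)], fillvalue=0)) == 5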
+ pass + return (upstream_map_indexes.get(task_id) or 1) * len(resolved_val) + + +@get_task_map_length.register +def _(xcom_arg: MapXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + return get_task_map_length(xcom_arg.arg, resolved_val, upstream_map_indexes) + + +@get_task_map_length.register +def _(xcom_arg: ZipXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + all_lengths = (get_task_map_length(arg, resolved_val, upstream_map_indexes) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + if isinstance(xcom_arg.fillvalue, ArgNotSet): + return min(ready_lengths) + return max(ready_lengths) + + +@get_task_map_length.register +def _(xcom_arg: ConcatXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + all_lengths = (get_task_map_length(arg, resolved_val, upstream_map_indexes) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + return sum(ready_lengths) diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 1f398cb8b6009..a7749ed4ae27a 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -119,6 +119,11 @@ def from_xcom_response(cls, xcom_response: XComResponse) -> XComResult: return cls(**xcom_response.model_dump()) +class XComCountResponse(BaseModel): + len: int + type: Literal["XComLengthResponse"] = "XComLengthResponse" + + class ConnectionResult(ConnectionResponse): type: Literal["ConnectionResult"] = "ConnectionResult" @@ -184,6 +189,7 @@ class OKResponse(BaseModel): StartupDetails, VariableResult, XComResult, + XComCountResponse, OKResponse, ], Field(discriminator="type"), @@ -240,6 +246,16 @@ class GetXCom(BaseModel): type: Literal["GetXCom"] = "GetXCom" +class GetXComCount(BaseModel): + """Get the number of (mapped) XCom values available.""" + + key: str + dag_id: str + run_id: str + task_id: str + type: Literal["GetNumberXComs"] = "GetNumberXComs" + + class SetXCom(BaseModel): key: str value: Annotated[ @@ -324,6 +340,7 @@ class GetPrevSuccessfulDagRun(BaseModel): GetPrevSuccessfulDagRun, GetVariable, GetXCom, + GetXComCount, PutVariable, RescheduleTask, SetRenderedFields, diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index 984919ea1c86b..0674d27320a6c 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from uuid import UUID + from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.variable import Variable @@ -315,3 +316,21 @@ def set_current_context(context: Context) -> Generator[Context, None, None]: expected=context, got=expected_state, ) + + +def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: + """ + Update context after task unmapping. + + Since ``get_template_context()`` is called before unmapping, the context + contains information about the mapped task. 
We need to do some in-place + updates to ensure the template context reflects the unmapped task instead. + + :meta private: + """ + from airflow.sdk.definitions.param import process_params + + context["task"] = context["ti"].task = task + context["params"] = process_params( + context["dag"], task, context["dag_run"].conf, suppress_exception=False + ) diff --git a/task_sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task_sdk/src/airflow/sdk/execution_time/lazy_sequence.py new file mode 100644 index 0000000000000..acbca29c0b954 --- /dev/null +++ b/task_sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import itertools +from collections.abc import Iterator, Sequence +from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload + +import attrs +import structlog + +if TYPE_CHECKING: + from airflow.sdk.definitions.xcom_arg import PlainXComArg + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + +T = TypeVar("T") + +log = structlog.get_logger(logger_name=__name__) + + +@attrs.define +class LazyXComIterator(Iterator[T]): + seq: LazyXComSequence[T] + index: int = 0 + dir: Literal[1, -1] = 1 + + def __next__(self) -> T: + if self.index < 0: + # When iterating backwards, avoid extra HTTP request + raise StopIteration() + val = self.seq._get_item(self.index) + if val is None: + # None isn't the best signal (it's bad in fact) but it's the best we can do until https://github.com/apache/airflow/issues/46426 + raise StopIteration() + self.index += self.dir + return val + + def __iter__(self) -> Iterator[T]: + return self + + +@attrs.define +class LazyXComSequence(Sequence[T]): + _len: int | None = attrs.field(init=False, default=None) + _xcom_arg: PlainXComArg = attrs.field(alias="xcom_arg") + _ti: RuntimeTaskInstance = attrs.field(alias="ti") + + def __repr__(self) -> str: + if self._len is not None: + counter = "item" if (length := len(self)) == 1 else "items" + return f"LazyXComSequence([{length} {counter}])" + return "LazyXComSequence()" + + def __str__(self) -> str: + return repr(self) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Sequence): + return NotImplemented + z = itertools.zip_longest(iter(self), iter(other), fillvalue=object()) + return all(x == y for x, y in z) + + def __iter__(self) -> Iterator[T]: + return LazyXComIterator(seq=self) + + def __len__(self) -> int: + if self._len is None: + from airflow.sdk.execution_time.comms import ErrorResponse, GetXComCount, XComCountResponse + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + task = self._xcom_arg.operator + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXComCount( + key=self._xcom_arg.key, + 
dag_id=task.dag_id, + run_id=self._ti.run_id, + task_id=task.task_id, + ), + ) + msg = SUPERVISOR_COMMS.get_message() + if isinstance(msg, ErrorResponse): + raise RuntimeError(msg) + elif not isinstance(msg, XComCountResponse): + raise TypeError(f"Got unexpected response to GetXComCount: {msg}") + self._len = msg.len + return self._len + + @overload + def __getitem__(self, key: int) -> T: ... + + @overload + def __getitem__(self, key: slice) -> Sequence[T]: ... + + def __getitem__(self, key: int | slice) -> T | Sequence[T]: + if isinstance(key, int): + if key >= 0: + return self._get_item(key) + else: + # val[-1] etc. + return self._get_item(len(self) + key) + + elif isinstance(key, slice): + # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not + # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the + # start and stop have different signs (i.e. one is positive and another negative). + ... + """ + Todo? + elif isinstance(key, slice): + start, stop, reverse = _coerce_slice(key) + if start >= 0: + if stop is None: + stmt = self._select_asc.offset(start) + elif stop >= 0: + stmt = self._select_asc.slice(start, stop) + else: + stmt = self._select_asc.slice(start, len(self) + stop) + rows = [self._process_row(row) for row in self._session.execute(stmt)] + if reverse: + rows.reverse() + else: + if stop is None: + stmt = self._select_desc.limit(-start) + elif stop < 0: + stmt = self._select_desc.slice(-stop, -start) + else: + stmt = self._select_desc.slice(len(self) - stop, -start) + rows = [self._process_row(row) for row in self._session.execute(stmt)] + if not reverse: + rows.reverse() + return rows + """ + raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}") + + def _get_item(self, index: int) -> T: + # TODO: maybe we need to call SUPERVISOR_COMMS manually so we can handle not found here? + return self._ti.xcom_pull( + task_ids=self._xcom_arg.operator.task_id, + key=self._xcom_arg.key, + map_indexes=index, + ) + + +def _coerce_index(value: Any) -> int | None: + """ + Check slice attribute's type and convert it to int. + + See CPython documentation on this: + https://docs.python.org/3/reference/datamodel.html#object.__index__ + """ + if value is None or isinstance(value, int): + return value + if (index := getattr(value, "__index__", None)) is not None: + return index() + raise TypeError("slice indices must be integers or None or have an __index__ method") + + +def _coerce_slice(key: slice) -> tuple[int, int | None, bool]: + """ + Check slice content and convert it for SQL. 
+ + See CPython documentation on this: + https://docs.python.org/3/reference/datamodel.html#slice-objects + """ + if key.step is None or key.step == 1: + reverse = False + elif key.step == -1: + reverse = True + else: + raise ValueError("non-trivial slice step not supported") + return _coerce_index(key.start) or 0, _coerce_index(key.stop), reverse diff --git a/airflow/utils/log/secrets_masker.py b/task_sdk/src/airflow/sdk/execution_time/secrets_masker.py similarity index 100% rename from airflow/utils/log/secrets_masker.py rename to task_sdk/src/airflow/sdk/execution_time/secrets_masker.py diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 3a65247e0ee40..31ad8077aec4b 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -70,6 +70,7 @@ GetPrevSuccessfulDagRun, GetVariable, GetXCom, + GetXComCount, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, @@ -81,6 +82,7 @@ TaskState, ToSupervisor, VariableResult, + XComCountResponse, XComResult, ) @@ -797,6 +799,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) xcom_result = XComResult.from_xcom_response(xcom) resp = xcom_result.model_dump_json().encode() + elif isinstance(msg, GetXComCount): + len = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) + resp = XComCountResponse(len=len).model_dump_json().encode() elif isinstance(msg, DeferTask): self._terminal_state = IntermediateTIState.DEFERRED self.client.task_instances.defer(self.id, msg) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index a0f6189b89b2d..77cd29d090c2d 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -35,8 +35,10 @@ from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState, TIRunContext from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import process_params from airflow.sdk.execution_time.comms import ( DeferTask, @@ -67,6 +69,7 @@ import jinja2 from structlog.typing import FilteringBoundLogger as Logger + from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions.context import Context @@ -83,6 +86,16 @@ class RuntimeTaskInstance(TaskInstance): max_tries: int = 0 """The maximum number of retries for the task.""" + def __rich_repr__(self): + yield "id", self.id + yield "task_id", self.task_id + yield "dag_id", self.dag_id + yield "run_id", self.run_id + yield "max_tries", self.max_tries + yield "task", type(self.task) + + __rich_repr__.angular = True # type: ignore[attr-defined] + def get_template_context(self) -> Context: # TODO: Move this to `airflow.sdk.execution_time.context` # once we port the entire context logic from airflow/utils/context.py ? 
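# --- Illustrative aside, not part of the patch --------------------------------
# The GetXComCount branch above leans on XComOperations.head() (earlier in this
# patch), which reads the count of mapped XCom values from a
# ``Content-Range: map_indexes <n>`` response header. The header parsing on its
# own, with no HTTP involved (function name is illustrative only):
def parse_map_index_count(content_range: str | None) -> int:
    """Extract the mapped-XCom count from a ``map_indexes N`` Content-Range value."""
    prefix = "map_indexes "
    if not content_range or not content_range.startswith(prefix):
        raise RuntimeError(f"Unable to parse Content-Range header: {content_range!r}")
    return int(content_range[len(prefix):])

# parse_map_index_count("map_indexes 7") == 7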
@@ -123,28 +136,15 @@ def get_template_context(self) -> Context: }, "conn": ConnectionAccessor(), } - if self._ti_context_from_server: - dag_run = self._ti_context_from_server.dag_run - - logical_date = dag_run.logical_date - ds = logical_date.strftime("%Y-%m-%d") - ds_nodash = ds.replace("-", "") - ts = logical_date.isoformat() - ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") - ts_nodash_with_tz = ts.replace("-", "").replace(":", "") + if from_server := self._ti_context_from_server: + dag_run = from_server.dag_run context_from_server: Context = { # TODO: Assess if we need to pass these through timezone.coerce_datetime - "dag_run": dag_run, + "dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522 "data_interval_end": dag_run.data_interval_end, "data_interval_start": dag_run.data_interval_start, - "logical_date": logical_date, - "ds": ds, - "ds_nodash": ds_nodash, - "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}", - "ts": ts, - "ts_nodash": ts_nodash, - "ts_nodash_with_tz": ts_nodash_with_tz, + "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}", "prev_data_interval_start_success": lazy_object_proxy.Proxy( lambda: get_previous_dagrun_success(self.id).data_interval_start ), @@ -160,6 +160,28 @@ def get_template_context(self) -> Context: } context.update(context_from_server) + if logical_date := dag_run.logical_date: + ds = logical_date.strftime("%Y-%m-%d") + ds_nodash = ds.replace("-", "") + ts = logical_date.isoformat() + ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") + ts_nodash_with_tz = ts.replace("-", "").replace(":", "") + context.update( + { + "logical_date": logical_date, + "ds": ds, + "ds_nodash": ds_nodash, + "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}", + "ts": ts, + "ts_nodash": ts_nodash, + "ts_nodash_with_tz": ts_nodash_with_tz, + } + ) + if from_server.upstream_map_indexes is not None: + # We stash this in here for later use, but we purposefully don't want to document it's + # existence. Should this be a private attribute on RuntimeTI instead perhaps? + setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes) + return context def render_templates( @@ -190,10 +212,7 @@ def render_templates( # unmapped BaseOperator created by this function! This is because the # MappedOperator is useless for template rendering, and we need to be # able to access the unmapped task instead. 
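# --- Illustrative aside, not part of the patch --------------------------------
# The logical_date-derived keys added to the context earlier in this hunk (ds,
# ds_nodash, ts, ts_nodash, ts_nodash_with_tz) are plain string transformations
# of one datetime. A minimal standalone sketch of that derivation, assuming a
# timezone-aware logical_date (the function name is made up for this example):
from datetime import datetime, timezone

def derive_date_keys(logical_date: datetime) -> dict[str, str]:
    ds = logical_date.strftime("%Y-%m-%d")
    ts = logical_date.isoformat()
    return {
        "ds": ds,
        "ds_nodash": ds.replace("-", ""),
        "ts": ts,
        "ts_nodash": logical_date.strftime("%Y%m%dT%H%M%S"),
        "ts_nodash_with_tz": ts.replace("-", "").replace(":", ""),
    }

# derive_date_keys(datetime(2021, 1, 1, tzinfo=timezone.utc)) gives
# {"ds": "2021-01-01", "ds_nodash": "20210101", "ts": "2021-01-01T00:00:00+00:00",
#  "ts_nodash": "20210101T000000", "ts_nodash_with_tz": "20210101T000000+0000"}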
- original_task.render_template_fields(context, jinja_env) - # TODO: Add support for rendering templates in the MappedOperator - # if isinstance(self.task, MappedOperator): - # self.task = context["ti"].task + self.task.render_template_fields(context, jinja_env) return original_task @@ -204,7 +223,7 @@ def xcom_pull( key: str = "return_value", # TODO: Make this a constant (``XCOM_RETURN_KEY``) include_prior_dates: bool = False, # TODO: Add support for this *, - map_indexes: int | Iterable[int] | None = None, + map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET, default: Any = None, run_id: str | None = None, ) -> Any: @@ -251,7 +270,7 @@ def xcom_pull( task_ids = self.task_id elif isinstance(task_ids, str): task_ids = [task_ids] - if map_indexes is None: + if isinstance(map_indexes, ArgNotSet): map_indexes = self.map_index elif isinstance(map_indexes, Iterable): # TODO: Handle multiple map_indexes or remove support @@ -359,8 +378,10 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance: # TODO: Handle task not found task = dag.task_dict[what.ti.task_id] - if not isinstance(task, BaseOperator): - raise TypeError(f"task is of the wrong type, got {type(task)}, wanted {BaseOperator}") + if not isinstance(task, (BaseOperator, MappedOperator)): + raise TypeError( + f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}" + ) return RuntimeTaskInstance.model_construct( **what.ti.model_dump(exclude_unset=True), @@ -446,22 +467,10 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]: else: raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") - # TODO: Render fields here - # 1. Implementing the part where we pull in the logic to render fields and add that here - # for all operators, we should do setattr(task, templated_field, rendered_templated_field) - # task.templated_fields should give all the templated_fields and each of those fields should - # give the rendered values. task.templated_fields should already be in a JSONable format and - # we should not have to handle that here. - - # 2. 
Once rendered, we call the `set_rtif` API to store the rtif in the metadata DB - - # so that we do not call the API unnecessarily - if rendered_fields := _get_rendered_fields(ti.task): - SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) return ti, log -def _get_rendered_fields(task: BaseOperator) -> dict[str, JsonValue]: +def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]: # TODO: Port one of the following to Task SDK # airflow.serialization.helpers.serialize_template_field or # airflow.models.renderedtifields.get_serialized_template_fields @@ -500,7 +509,35 @@ def _process_outlets(context: Context, outlets: list[AssetProfile]): return task_outlets, outlet_events -def run(ti: RuntimeTaskInstance, log: Logger): +def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None: + ti.hostname = get_hostname() + ti.task = ti.task.prepare_for_execution() + if ti.task.inlets or ti.task.outlets: + inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)] + outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)] + SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log) # type: ignore + ok_response = SUPERVISOR_COMMS.get_message() # type: ignore + if not isinstance(ok_response, OKResponse) or not ok_response.ok: + log.info("Runtime checks failed for task, marking task as failed..") + return TaskState( + state=TerminalTIState.FAILED, + end_date=datetime.now(tz=timezone.utc), + ) + + jinja_env = ti.task.dag.get_template_env() + ti.render_templates(context=context, jinja_env=jinja_env) + + if rendered_fields := _serialize_rendered_fields(ti.task): + # so that we do not call the API unnecessarily + SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) + + # TODO: Call pre execute etc. + + # No error, carry on and execute the task + return None + + +def run(ti: RuntimeTaskInstance, log: Logger) -> ToSupervisor | None: """Run the task in this process.""" from airflow.exceptions import ( AirflowException, @@ -519,31 +556,17 @@ def run(ti: RuntimeTaskInstance, log: Logger): msg: ToSupervisor | None = None try: - # TODO: pre execute etc. 
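# --- Illustrative aside, not part of the patch --------------------------------
# The pre-execution steps removed here move into _prepare() above, which either
# returns None (keep going) or a terminal-state message that run() forwards and
# stops on, so a failed runtime check is reported like any other task failure.
# A stripped-down sketch of that "optional early exit" control flow, with
# stand-in names:
from __future__ import annotations

def prepare(checks_ok: bool) -> str | None:
    # None means all pre-execution checks passed; otherwise a failure message.
    return None if checks_ok else "failed: runtime checks did not pass"

def run_once(checks_ok: bool) -> str | None:
    if early_exit := prepare(checks_ok):
        return early_exit  # short-circuit before executing the task
    # ... execute the task here ...
    return None

# run_once(False) == "failed: runtime checks did not pass"; run_once(True) is None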
- # TODO: Get a real context object - ti.hostname = get_hostname() - ti.task = ti.task.prepare_for_execution() - if ti.task.inlets or ti.task.outlets: - inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)] - outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)] - SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log) # type: ignore - ok_response = SUPERVISOR_COMMS.get_message() # type: ignore - if not isinstance(ok_response, OKResponse) or not ok_response.ok: - log.info("Runtime checks failed for task, marking task as failed..") - msg = TaskState( - state=TerminalTIState.FAILED, - end_date=datetime.now(tz=timezone.utc), - ) - return context = ti.get_template_context() with set_current_context(context): - jinja_env = ti.task.dag.get_template_env() - ti.task = ti.render_templates(context=context, jinja_env=jinja_env) - # TODO: Get things from _execute_task_with_callbacks - # - Pre Execute - # etc - result = _execute_task(context, ti.task) + # This is the earliest that we can render templates -- as, if it raises an exception for any reason, we need to + # catch it and handle it like a normal task failure + if early_exit := _prepare(ti, log, context): + msg = early_exit + return msg + + result = _execute_task(context, ti) + log.info("Pushing xcom", ti=ti) _push_xcom_if_needed(result, ti) task_outlets, outlet_events = _process_outlets(context, ti.task.outlets) @@ -620,12 +643,15 @@ def run(ti: RuntimeTaskInstance, log: Logger): finally: if msg: SUPERVISOR_COMMS.send_request(msg=msg, log=log) + # Return the message to make unit tests easier too + return msg -def _execute_task(context: Context, task: BaseOperator): +def _execute_task(context: Context, ti: RuntimeTaskInstance): """Execute Task (optionally with a Timeout) and push Xcom results.""" from airflow.exceptions import AirflowTaskTimeout + task = ti.task if task.execution_timeout: # TODO: handle timeout in case of deferral from airflow.utils.timeout import timeout diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index fd02104fb2fc4..795130bf37462 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -17,8 +17,11 @@ from __future__ import annotations +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Protocol, Union +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet + if TYPE_CHECKING: from collections.abc import Iterator from datetime import datetime @@ -36,14 +39,17 @@ class DagRunProtocol(Protocol): dag_id: str run_id: str - logical_date: datetime + logical_date: datetime | None data_interval_start: datetime | None data_interval_end: datetime | None start_date: datetime end_date: datetime | None run_type: Any + run_after: datetime conf: dict[str, Any] | None - external_trigger: bool + # This shouldn't be "| None", but there's a bug in the datamodel generator, and None evaluates to Falsey + # too, so this is "okay" + external_trigger: bool | None = False class RuntimeTaskInstanceProtocol(Protocol): @@ -54,7 +60,7 @@ class RuntimeTaskInstanceProtocol(Protocol): dag_id: str run_id: str try_number: int - map_index: int + map_index: int | None max_tries: int hostname: str | None = None @@ -66,7 +72,7 @@ def xcom_pull( # TODO: `include_prior_dates` isn't yet supported in the SDK # include_prior_dates: bool = False, *, - map_indexes: int | list[int] | None = None, + map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET, default: Any =
None, run_id: str | None = None, ) -> Any: ... diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 8315a121fc4d7..43e35dfec9ec0 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -37,6 +37,11 @@ def make_client(transport: httpx.MockTransport) -> Client: return Client(base_url="test://server", token="", transport=transport) +def make_client_w_dry_run() -> Client: + """Get a client with dry_run enabled""" + return Client(base_url=None, dry_run=True, token="") + + def make_client_w_responses(responses: list[httpx.Response]) -> Client: """Helper fixture to create a mock client with custom responses.""" @@ -49,6 +54,34 @@ def handle_request(request: httpx.Request) -> httpx.Response: class TestClient: + @pytest.mark.parametrize( + ["path", "json_response"], + [ + ( + "/task-instances/1/run", + { + "dag_run": { + "dag_id": "test_dag", + "run_id": "test_run", + "logical_date": "2021-01-01T00:00:00Z", + "start_date": "2021-01-01T00:00:00Z", + "run_type": "manual", + "run_after": "2021-01-01T00:00:00Z", + }, + "max_tries": 0, + }, + ), + ], + ) + def test_dry_run(self, path, json_response): + client = make_client_w_dry_run() + assert client.base_url == "dry-run://server" + + resp = client.get(path) + + assert resp.status_code == 200 + assert resp.json() == json_response + def test_error_parsing(self): responses = [ httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]}) @@ -195,7 +228,9 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert resp == ti_context assert call_count == 4 - @pytest.mark.parametrize("state", [state for state in TerminalTIState]) + @pytest.mark.parametrize( + "state", [state for state in TerminalTIState if state != TerminalTIState.SUCCESS] + ) def test_task_instance_finish(self, state): # Simulate a successful response from the server that finishes (moved to terminal state) a task ti_id = uuid6.uuid7() diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index cc4bc4f96148a..bb40c641bf9b2 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -35,6 +35,9 @@ from structlog.typing import EventDict, WrappedLogger from airflow.sdk.api.datamodels._generated import TIRunContext + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.execution_time.comms import StartupDetails + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance @pytest.hookimpl() @@ -154,7 +157,9 @@ def __call__( data_interval_start: str | datetime = ..., data_interval_end: str | datetime = ..., start_date: str | datetime = ..., + run_after: str | datetime = ..., run_type: str = ..., + conf=None, ) -> TIRunContext: ... @@ -167,7 +172,9 @@ def __call__( data_interval_start: str | datetime = ..., data_interval_end: str | datetime = ..., start_date: str | datetime = ..., + run_after: str | datetime = ..., run_type: str = ..., + conf=None, ) -> dict[str, Any]: ... 
@@ -183,6 +190,7 @@ def _make_context( data_interval_start: str | datetime = "2024-12-01T00:00:00Z", data_interval_end: str | datetime = "2024-12-01T01:00:00Z", start_date: str | datetime = "2024-12-01T01:00:00Z", + run_after: str | datetime = "2024-12-01T01:00:00Z", run_type: str = "manual", conf=None, ) -> TIRunContext: @@ -195,6 +203,7 @@ def _make_context( data_interval_end=data_interval_end, # type: ignore start_date=start_date, # type: ignore run_type=run_type, # type: ignore + run_after=run_after, # type: ignore conf=conf, ), max_tries=0, @@ -214,7 +223,9 @@ def _make_context_dict( data_interval_start: str | datetime = "2024-12-01T00:00:00Z", data_interval_end: str | datetime = "2024-12-01T01:00:00Z", start_date: str | datetime = "2024-12-01T00:00:00Z", + run_after: str | datetime = "2024-12-01T00:00:00Z", run_type: str = "manual", + conf=None, ) -> dict[str, Any]: context = make_ti_context( dag_id=dag_id, @@ -223,7 +234,9 @@ def _make_context_dict( data_interval_start=data_interval_start, data_interval_end=data_interval_end, start_date=start_date, + run_after=run_after, run_type=run_type, + conf=conf, ) return context.model_dump(exclude_unset=True, mode="json") @@ -236,3 +249,135 @@ def mock_supervisor_comms(): "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as supervisor_comms: yield supervisor_comms + + +@pytest.fixture +def mocked_parse(spy_agency): + """ + Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you + want to isolate and test `parse` or `run` logic without having to define a DAG file. + + This fixture returns a helper function `set_dag` that: + 1. Creates an inline DAG with the given `dag_id` and `task` (limited to one task) + 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. + 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. + + After adding the fixture in your test function signature, you can use it like this :: + + mocked_parse( + StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), + file="", + requests_fd=0, + ), + "example_dag_id", + CustomOperator(task_id="hello"), + ) + """ + + def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse + from airflow.utils import timezone + + if not task.has_dag(): + dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) + task.dag = dag + task = dag.task_dict[task.task_id] + else: + dag = task.dag + if what.ti_context.dag_run.conf: + dag.params = what.ti_context.dag_run.conf # type: ignore[assignment] + ti = RuntimeTaskInstance.model_construct( + **what.ti.model_dump(exclude_unset=True), + task=task, + _ti_context_from_server=what.ti_context, + max_tries=what.ti_context.max_tries, + ) + if hasattr(parse, "spy"): + spy_agency.unspy(parse) + spy_agency.spy_on(parse, call_fake=lambda _: ti) + return ti + + return set_dag + + +@pytest.fixture +def create_runtime_ti(mocked_parse, make_ti_context): + """ + Fixture to create a Runtime TaskInstance for testing purposes without defining a dag file. + + It mimics the behavior of the `parse` function by creating a `RuntimeTaskInstance` based on the provided + `StartupDetails` (formed from arguments) and task.
This allows you to test the logic of a task without + having to define a DAG file, parse it, get context from the server, etc. + + Example usage: :: + + def test_custom_task_instance(create_runtime_ti): + class MyTaskOperator(BaseOperator): + def execute(self, context): + assert context["dag_run"].run_id == "test_run" + + task = MyTaskOperator(task_id="test_task") + ti = create_runtime_ti(task, context_from_server=make_ti_context(run_id="test_run")) + # Further test logic... + """ + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance + from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails + + def _create_task_instance( + task: BaseOperator, + dag_id: str = "test_dag", + run_id: str = "test_run", + logical_date: str | datetime = "2024-12-01T01:00:00Z", + data_interval_start: str | datetime = "2024-12-01T00:00:00Z", + data_interval_end: str | datetime = "2024-12-01T01:00:00Z", + start_date: str | datetime = "2024-12-01T01:00:00Z", + run_type: str = "manual", + try_number: int = 1, + map_index: int | None = -1, + upstream_map_indexes: dict[str, int] | None = None, + ti_id=None, + conf=None, + ) -> RuntimeTaskInstance: + if not ti_id: + ti_id = uuid7() + + if task.has_dag(): + dag_id = task.dag.dag_id + + ti_context = make_ti_context( + dag_id=dag_id, + run_id=run_id, + logical_date=logical_date, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + start_date=start_date, + run_type=run_type, + conf=conf, + ) + + if upstream_map_indexes is not None: + ti_context.upstream_map_indexes = upstream_map_indexes + + startup_details = StartupDetails( + ti=TaskInstance( + id=ti_id, + task_id=task.task_id, + dag_id=dag_id, + run_id=run_id, + try_number=try_number, + map_index=map_index, + ), + dag_rel_path="", + bundle_info=BundleInfo(name="anything", version="any"), + requests_fd=0, + ti_context=ti_context, + ) + + ti = mocked_parse(startup_details, dag_id, task) + return ti + + return _create_task_instance diff --git a/task_sdk/tests/definitions/conftest.py b/task_sdk/tests/definitions/conftest.py new file mode 100644 index 0000000000000..02ed599c5ec60 --- /dev/null +++ b/task_sdk/tests/definitions/conftest.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +import structlog + +from airflow.sdk.execution_time.comms import SucceedTask, TaskState + +if TYPE_CHECKING: + from airflow.sdk.definitions.dag import DAG + + +@pytest.fixture +def run_ti(create_runtime_ti, mock_supervisor_comms): + def run(dag: DAG, task_id: str, map_index: int): + """Run the task and return the state that the SDK sent as the result for easier asserts""" + from airflow.sdk.execution_time.task_runner import run + + log = structlog.get_logger() + + mock_supervisor_comms.send_request.reset_mock() + ti = create_runtime_ti(dag.task_dict[task_id], map_index=map_index) + run(ti, log) + + for call in mock_supervisor_comms.send_request.mock_calls: + msg = call.kwargs["msg"] + if isinstance(msg, (TaskState, SucceedTask)): + return msg.state + raise RuntimeError("Unable to find call to TaskState") + + return run diff --git a/task_sdk/tests/definitions/test_baseoperator.py b/task_sdk/tests/definitions/test_baseoperator.py index af6bf592f5373..35f33818dc198 100644 --- a/task_sdk/tests/definitions/test_baseoperator.py +++ b/task_sdk/tests/definitions/test_baseoperator.py @@ -621,26 +621,3 @@ def _do_render(): assert expected_log in caplog.text if not_expected_log: assert not_expected_log not in caplog.text - - -def test_find_mapped_dependants_in_another_group(): - from airflow.decorators import task as task_decorator - from airflow.sdk import TaskGroup - - @task_decorator - def gen(x): - return list(range(x)) - - @task_decorator - def add(x, y): - return x + y - - with DAG(dag_id="test"): - with TaskGroup(group_id="g1"): - gen_result = gen(3) - with TaskGroup(group_id="g2"): - add_result = add.partial(y=1).expand(x=gen_result) - - # breakpoint() - dependants = list(gen_result.operator.iter_mapped_dependants()) - assert dependants == [add_result.operator] diff --git a/task_sdk/tests/definitions/test_dag.py b/task_sdk/tests/definitions/test_dag.py index e6baeabe98dee..fa0cd65846ec5 100644 --- a/task_sdk/tests/definitions/test_dag.py +++ b/task_sdk/tests/definitions/test_dag.py @@ -25,7 +25,7 @@ from airflow.exceptions import DuplicateTaskIdFound from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG, dag as dag_decorator -from airflow.sdk.definitions.param import Param, ParamsDict +from airflow.sdk.definitions.param import DagParam, Param, ParamsDict DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) @@ -350,6 +350,13 @@ def test__tags_mutable(): assert test_dag.tags == expected_tags +def test_create_dag_while_active_context(): + """Test that we can safely create a DAG whilst a DAG is activated via ``with dag1:``.""" + with DAG(dag_id="simple_dag"): + DAG(dag_id="dag2") + # No asserts needed, it just needs to not fail + + class TestDagDecorator: DEFAULT_ARGS = { "owner": "test", @@ -418,8 +425,55 @@ def noop_pipeline(value): ... 
with pytest.raises(TypeError): noop_pipeline() - def test_create_dag_while_active_context(self): - """Test that we can safely create a DAG whilst a DAG is activated via ``with dag1:``.""" - with DAG(dag_id="simple_dag"): - DAG(dag_id="dag2") - # No asserts needed, it just needs to not fail + def test_documentation_template_rendered(self): + """Test that @dag uses function docs as doc_md for DAG object""" + + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) + def noop_pipeline(): + """ + {% if True %} + Regular DAG documentation + {% endif %} + """ + + dag = noop_pipeline() + assert dag.dag_id == "noop_pipeline" + assert "Regular DAG documentation" in dag.doc_md + + def test_resolve_documentation_template_file_not_rendered(self, tmp_path): + """Test that doc_md supplied via a file path is used as-is and not rendered as a template""" + + raw_content = """ + {% if True %} + External Markdown DAG documentation + {% endif %} + """ + + path = tmp_path / "testfile.md" + path.write_text(raw_content) + + @dag_decorator("test-dag", schedule=None, start_date=DEFAULT_DATE, doc_md=str(path)) + def markdown_docs(): ... + + dag = markdown_docs() + assert dag.dag_id == "test-dag" + assert dag.doc_md == raw_content + + def test_dag_param_resolves(self): + """Test that dag param is correctly resolved by operator""" + from airflow.decorators import task + + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) + def xcom_pass_to_op(value=self.VALUE): + @task + def return_num(num): + return num + + xcom_arg = return_num(value) + self.operator = xcom_arg.operator + + xcom_pass_to_op() + + assert isinstance(self.operator.op_args[0], DagParam) + self.operator.render_template_fields({}) + assert self.operator.op_args[0] == 42 diff --git a/task_sdk/tests/definitions/test_mappedoperator.py b/task_sdk/tests/definitions/test_mappedoperator.py index eeb79f31b4d47..64b0b8d7b8379 100644 --- a/task_sdk/tests/definitions/test_mappedoperator.py +++ b/task_sdk/tests/definitions/test_mappedoperator.py @@ -17,16 +17,20 @@ # under the License.
from __future__ import annotations +import json from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Callable +from unittest import mock import pendulum import pytest +from airflow.sdk.api.datamodels._generated import TerminalTIState from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator -from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.definitions.xcom_arg import XComArg +from airflow.sdk.execution_time.comms import GetXCom, SetXCom, XComResult from airflow.utils.trigger_rule import TriggerRule from tests_common.test_utils.mapping import expand_mapped_task # noqa: F401 @@ -120,9 +124,6 @@ def test_map_xcom_arg(): assert task1.downstream_list == [mapped] -# def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): - - def test_partial_on_instance() -> None: """`.partial` on an instance should fail -- it's only designed to be called on classes""" with pytest.raises(TypeError): @@ -154,15 +155,6 @@ def test_partial_on_invalid_pool_slots_raises() -> None: MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) -# def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): - - -# def test_expand_mapped_task_failed_state_in_db(dag_maker, session): - - -# def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): - - def test_mapped_task_applies_default_args_classic(): with DAG("test", default_args={"execution_timeout": timedelta(minutes=30)}) as dag: MockOperator(task_id="simple", arg1=None, arg2=0) @@ -191,43 +183,120 @@ def mapped(arg): @pytest.mark.parametrize( - "dag_params, task_params, expected_partial_params", + ("callable", "expected"), [ - pytest.param(None, None, ParamsDict(), id="none"), - pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"), - pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"), - pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}), id="merge"), + pytest.param( + lambda partial, output1: partial.expand( + map_template=output1, map_static=output1, file_template=["/path/to/file.ext"] + ), + # Note to the next person to come across this. In #32272 we changed expand_kwargs so that it + # resolves the mapped template when it's in `expand_kwargs()`, but we _didn't_ do the same for + # things in `expand()`. This feels like a bug to me (ashb) but I am not changing that now, I have + # just moved and parametrized this test. 
+ "{{ ds }}", + id="expand", + ), + pytest.param( + lambda partial, output1: partial.expand_kwargs( + [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] + ), + "2024-12-01", + id="expand_kwargs", + ), ], ) -def test_mapped_expand_against_params(dag_params, task_params, expected_partial_params): - with DAG("test", params=dag_params) as dag: - MockOperator.partial(task_id="t", params=task_params).expand(params=[{"c": "x"}, {"d": 1}]) - - t = dag.get_task("t") - assert isinstance(t, MappedOperator) - assert t.params == expected_partial_params - assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} - +def test_mapped_render_template_fields_validating_operator( + tmp_path, create_runtime_ti, mock_supervisor_comms, callable, expected: bool +): + file_template_dir = tmp_path / "path" / "to" + file_template_dir.mkdir(parents=True, exist_ok=True) + file_template = file_template_dir / "file.ext" + file_template.write_text("loaded data") + + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template + + def execute(self, context): + pass -# def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): + with DAG("test_dag", template_searchpath=tmp_path.__fspath__()): + task1 = BaseOperator(task_id="op1") + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ) + mapped = callable(mapped, task1.output) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["{{ ds }}"]') -# def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): + mapped_ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={task1.task_id: 1}) + assert isinstance(mapped_ti.task, MappedOperator) + mapped_ti.task.render_template_fields(context=mapped_ti.get_template_context()) + assert isinstance(mapped_ti.task, MyOperator) -# def test_mapped_render_nested_template_fields(dag_maker, session): + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == expected + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" 
-# def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected): +def test_mapped_render_nested_template_fields(create_runtime_ti, mock_supervisor_comms): + with DAG("test_dag"): + mapped = MockOperatorWithNestedFields.partial( + task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", field_2="value_2") + ).expand(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) + ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={}) + ti.task.render_template_fields(context=ti.get_template_context()) + assert ti.task.arg1 == "t" + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" -# def test_expand_mapped_task_instance_with_named_index( + ti = create_runtime_ti(task=mapped, map_index=1, upstream_map_indexes={}) + ti.task.render_template_fields(context=ti.get_template_context()) + assert ti.task.arg1 == ["s", "t"] + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" -# def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None: +@pytest.mark.parametrize( + ("map_index", "expected"), + [ + pytest.param(0, "2024-12-01", id="0"), + pytest.param(1, 2, id="1"), + ], +) +def test_expand_kwargs_render_template_fields_validating_operator( + map_index, expected, create_runtime_ti, mock_supervisor_comms +): + with DAG("test_dag"): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) + mock_supervisor_comms.get_message.return_value = XComResult( + key="return_value", value=json.dumps([{"arg1": "{{ ds }}"}, {"arg1": 2}]) + ) -# def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): + ti = create_runtime_ti(task=mapped, map_index=map_index, upstream_map_indexes={}) + assert isinstance(ti.task, MappedOperator) + ti.task.render_template_fields(context=ti.get_template_context()) + assert isinstance(ti.task, MockOperator) + assert ti.task.arg1 == expected + assert ti.task.arg2 == "a" def test_xcomarg_property_of_mapped_operator(): @@ -252,9 +321,6 @@ def test_set_xcomarg_dependencies_with_mapped_operator(): assert op2 in op5.upstream_list -# def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): - - def test_task_mapping_with_task_group_context(): from airflow.sdk.definitions.taskgroup import TaskGroup @@ -299,3 +365,331 @@ def test_task_mapping_with_explicit_task_group(): assert finish.upstream_list == [mapped] assert mapped.downstream_list == [finish] + + +def test_nested_mapped_task_groups(): + from airflow.decorators import task, task_group + + with DAG("test"): + + @task + def t(): + return [[1, 2], [3, 4]] + + @task + def m(x): + return x + + @task_group + def g1(x): + @task_group + def g2(y): + return m(y) + + return g2.expand(y=x) + + # Add a test once nested mapped task groups become supported + with pytest.raises(ValueError, match="Nested Mapped TaskGroups are not yet supported"): + g1.expand(x=t()) + + +RunTI = Callable[[DAG, str, int], TerminalTIState] + + +def test_map_cross_product(run_ti: RunTI, mock_supervisor_comms): + outputs = [] + + with DAG(dag_id="cross_product") as dag: + + @dag.task + def emit_numbers(): + return [1, 2] + + @dag.task + def emit_letters(): + return {"a": "x", "b": "y", "c": "z"} + + @dag.task + def show(number, letter): + outputs.append((number, letter)) + + show.expand(number=emit_numbers(), letter=emit_letters()) + + def xcom_get(): + # TODO: Tidy this after #45927 is 
reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + task = dag.get_task(last_request.task_id) + value = json.dumps(task.python_callable()) + return XComResult(key="return_value", value=value) + + mock_supervisor_comms.get_message.side_effect = xcom_get + + states = [run_ti(dag, "show", map_index) for map_index in range(6)] + assert states == [TerminalTIState.SUCCESS] * 6 + assert outputs == [ + (1, ("a", "x")), + (1, ("b", "y")), + (1, ("c", "z")), + (2, ("a", "x")), + (2, ("b", "y")), + (2, ("c", "z")), + ] + + +def test_map_product_same(run_ti: RunTI, mock_supervisor_comms): + """Test a mapped task can refer to the same source multiple times.""" + outputs = [] + + with DAG(dag_id="product_same") as dag: + + @dag.task + def emit_numbers(): + return [1, 2] + + @dag.task + def show(a, b): + outputs.append((a, b)) + + emit_task = emit_numbers() + show.expand(a=emit_task, b=emit_task) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + task = dag.get_task(last_request.task_id) + value = json.dumps(task.python_callable()) + return XComResult(key="return_value", value=value) + + mock_supervisor_comms.get_message.side_effect = xcom_get + + states = [run_ti(dag, "show", map_index) for map_index in range(4)] + assert states == [TerminalTIState.SUCCESS] * 4 + assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] + + +class NestedFields: + """Nested fields for testing purposes.""" + + def __init__(self, field_1, field_2): + self.field_1 = field_1 + self.field_2 = field_2 + + +class MockOperatorWithNestedFields(BaseOperator): + """Operator with nested fields for testing purposes.""" + + template_fields = ("arg1", "arg2") + + def __init__(self, arg1: str = "", arg2: NestedFields | None = None, **kwargs): + super().__init__(**kwargs) + self.arg1 = arg1 + self.arg2 = arg2 + + def _render_nested_template_fields(self, content, context, jinja_env, seen_oids) -> None: + if id(content) not in seen_oids: + template_fields: tuple | None = None + + if isinstance(content, NestedFields): + template_fields = ("field_1", "field_2") + + if template_fields: + seen_oids.add(id(content)) + self._do_render_template_fields(content, template_fields, context, jinja_env, seen_oids) + return + + super()._render_nested_template_fields(content, context, jinja_env, seen_oids) + + +def test_find_mapped_dependants_in_another_group(): + from airflow.decorators import task as task_decorator + from airflow.sdk import TaskGroup + + @task_decorator + def gen(x): + return list(range(x)) + + @task_decorator + def add(x, y): + return x + y + + with DAG(dag_id="test"): + with TaskGroup(group_id="g1"): + gen_result = gen(3) + with TaskGroup(group_id="g2"): + add_result = add.partial(y=1).expand(x=gen_result) + + dependants = list(gen_result.operator.iter_mapped_dependants()) + assert dependants == [add_result.operator] + + +@pytest.mark.parametrize( + "partial_params, mapped_params, expected", + [ + pytest.param(None, [{"a": 1}], [{"a": 1}], id="simple"), + pytest.param({"b": 2}, [{"a": 1}], [{"a": 1, "b": 2}], id="merge"), + pytest.param({"b": 2}, [{"a": 1, "b": 3}], [{"a": 1, "b": 3}], id="override"), + pytest.param({"b": 2}, [{"a": 1, "b": 3}, {"b": 1}], [{"a": 1, "b": 3}, {"b": 1}], id="multiple"), + ], +) +def 
test_mapped_expand_against_params(create_runtime_ti, partial_params, mapped_params, expected): + with DAG("test"): + task = BaseOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params) + + for map_index, expansion in enumerate(expected): + mapped_ti = create_runtime_ti(task=task, map_index=map_index) + mapped_ti.task.render_template_fields(context=mapped_ti.get_template_context()) + assert mapped_ti.task.params == expansion + + +def test_operator_mapped_task_group_receives_value(create_runtime_ti, mock_supervisor_comms): + # Test the runtime expansion behaviour of mapped task groups + mapped operators + results = {} + + from airflow.decorators import task_group + + with DAG("test") as dag: + + @dag.task + def t(value, *, ti=None): + results[(ti.task_id, ti.map_index)] = value + return value + + @task_group + def tg(va): + # Each expanded group has one t1 and t2 each. + t1 = t.override(task_id="t1")(va) + t2 = t.override(task_id="t2")(t1) + + with pytest.raises(NotImplementedError) as ctx: + t.override(task_id="t4").expand(value=va) + assert str(ctx.value) == "operator expansion in an expanded task group is not yet supported" + + return t2 + + # The group is mapped by 3. + t2 = tg.expand(va=[["a", "b"], [4], ["z"]]) + + # Aggregates results from task group. + t.override(task_id="t3")(t2) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + key = (last_request.task_id, last_request.map_index) + if key in expected_values: + value = expected_values[key] + return XComResult(key="return_value", value=json.dumps(value)) + elif last_request.map_index is None: + # Get all mapped XComValues for this ti + value = [v for k, v in expected_values.items() if k[0] == last_request.task_id] + return XComResult(key="return_value", value=json.dumps(value)) + return mock.DEFAULT + + mock_supervisor_comms.get_message.side_effect = xcom_get + + expected_values = { + ("tg.t1", 0): ["a", "b"], + ("tg.t1", 1): [4], + ("tg.t1", 2): ["z"], + ("tg.t2", 0): ["a", "b"], + ("tg.t2", 1): [4], + ("tg.t2", 2): ["z"], + ("t3", None): [["a", "b"], [4], ["z"]], + } + + # We hard-code the number of expansions here as the server is in charge of that. 
+ expansion_per_task_id = { + "tg.t1": range(3), + "tg.t2": range(3), + "t3": [None], + } + for task in dag.tasks: + for map_index in expansion_per_task_id[task.task_id]: + mapped_ti = create_runtime_ti(task=task.prepare_for_execution(), map_index=map_index) + context = mapped_ti.get_template_context() + mapped_ti.task.render_template_fields(context) + mapped_ti.task.execute(context) + assert results == expected_values + + +@pytest.mark.xfail(reason="SkipMixin hasn't been ported over to use the Task Execution API yet") +def test_mapped_xcom_push_skipped_tasks(create_runtime_ti, mock_supervisor_comms): + from airflow.decorators import task_group + from airflow.operators.empty import EmptyOperator + + if TYPE_CHECKING: + from airflow.providers.standard.operators.python import ShortCircuitOperator + else: + ShortCircuitOperator = pytest.importorskip( + "airflow.providers.standard.operators.python" + ).ShortCircuitOperator + + with DAG("test") as dag: + + @task_group + def group(x): + short_op_push_xcom = ShortCircuitOperator( + task_id="push_xcom_from_shortcircuit", + python_callable=lambda arg: arg % 2 == 0, + op_kwargs={"arg": x}, + ) + empty_task = EmptyOperator(task_id="empty_task") + short_op_push_xcom >> empty_task + + group.expand(x=[0, 1]) + + for task in dag.tasks: + for map_index in range(2): + ti = create_runtime_ti(task=task.prepare_for_execution(), map_index=map_index) + context = ti.get_template_context() + ti.task.render_template_fields(context) + ti.task.execute(context) + + assert ti + # TODO: these tests might not be right + mock_supervisor_comms.send_request.assert_has_calls( + [ + SetXCom( + key="skipmixin_key", + value=None, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id="group.push_xcom_from_shortcircuit", + map_index=0, + ), + SetXCom( + key="return_value", + value=True, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id="group.push_xcom_from_shortcircuit", + map_index=0, + ), + SetXCom( + key="skipmixin_key", + value={"skipped": ["group.empty_task"]}, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id="group.push_xcom_from_shortcircuit", + map_index=1, + ), + ] + ) + # + # assert ( + # tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="return_value", map_indexes=0) + # is True + # ) + # assert ( + # tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=0) + # is None + # ) + # assert tis[0].xcom_pull( + # task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=1 + # ) == {"skipped": ["group.empty_task"]} diff --git a/tests/utils/log/test_secrets_masker.py b/task_sdk/tests/definitions/test_secrets_masker.py similarity index 95% rename from tests/utils/log/test_secrets_masker.py rename to task_sdk/tests/definitions/test_secrets_masker.py index 1f4642fa85b9d..167f0c1a61d16 100644 --- a/tests/utils/log/test_secrets_masker.py +++ b/task_sdk/tests/definitions/test_secrets_masker.py @@ -30,7 +30,7 @@ import pytest from airflow.models import Connection -from airflow.utils.log.secrets_masker import ( +from airflow.sdk.execution_time.secrets_masker import ( RedactedIO, SecretsMasker, mask_secret, @@ -302,7 +302,7 @@ def test_redact_filehandles(self, caplog): def test_redact_max_depth(self, val, expected, max_depth): secrets_masker = SecretsMasker() secrets_masker.add_mask("abc") - with patch("airflow.utils.log.secrets_masker._secrets_masker", return_value=secrets_masker): + with patch("airflow.sdk.execution_time.secrets_masker._secrets_masker", return_value=secrets_masker): got = redact(val, 
max_depth=max_depth) assert got == expected @@ -343,7 +343,7 @@ def test_redact_state_enum(self, logger, caplog, state, expected): def test_masking_quoted_strings_in_connection(self, logger, caplog): secrets_masker = next(fltr for fltr in logger.filters if isinstance(fltr, SecretsMasker)) - with patch("airflow.utils.log.secrets_masker._secrets_masker", return_value=secrets_masker): + with patch("airflow.sdk.execution_time.secrets_masker._secrets_masker", return_value=secrets_masker): test_conn_attributes = dict( conn_type="scheme", host="host/location", @@ -388,7 +388,7 @@ def test_hiding_defaults(self, key, expected_result): ], ) def test_hiding_config(self, sensitive_variable_fields, key, expected_result): - from airflow.utils.log.secrets_masker import get_sensitive_variables_fields + from airflow.sdk.execution_time.secrets_masker import get_sensitive_variables_fields with conf_vars({("core", "sensitive_var_conn_names"): str(sensitive_variable_fields)}): get_sensitive_variables_fields.cache_clear() @@ -415,7 +415,9 @@ class TestRedactedIO: @pytest.fixture(scope="class", autouse=True) def reset_secrets_masker(self): self.secrets_masker = SecretsMasker() - with patch("airflow.utils.log.secrets_masker._secrets_masker", return_value=self.secrets_masker): + with patch( + "airflow.sdk.execution_time.secrets_masker._secrets_masker", return_value=self.secrets_masker + ): mask_secret(p) yield @@ -452,8 +454,10 @@ class TestMaskSecretAdapter: @pytest.fixture(autouse=True) def reset_secrets_masker_and_skip_escape(self): self.secrets_masker = SecretsMasker() - with patch("airflow.utils.log.secrets_masker._secrets_masker", return_value=self.secrets_masker): - with patch("airflow.utils.log.secrets_masker.re2.escape", lambda x: x): + with patch( + "airflow.sdk.execution_time.secrets_masker._secrets_masker", return_value=self.secrets_masker + ): + with patch("airflow.sdk.execution_time.secrets_masker.re2.escape", lambda x: x): yield def test_calling_mask_secret_adds_adaptations_for_returned_str(self): diff --git a/task_sdk/tests/definitions/test_xcom_arg.py b/task_sdk/tests/definitions/test_xcom_arg.py new file mode 100644 index 0000000000000..dc3c7f4916de0 --- /dev/null +++ b/task_sdk/tests/definitions/test_xcom_arg.py @@ -0,0 +1,360 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +from typing import Callable +from unittest import mock + +import pytest +import structlog +from pytest_unordered import unordered + +from airflow.exceptions import AirflowSkipException +from airflow.sdk.api.datamodels._generated import TerminalTIState +from airflow.sdk.definitions.dag import DAG +from airflow.sdk.execution_time.comms import GetXCom, XComResult + +log = structlog.get_logger() + +RunTI = Callable[[DAG, str, int], TerminalTIState] + + +def test_xcom_map(run_ti: RunTI, mock_supervisor_comms): + results = set() + with DAG("test") as dag: + + @dag.task + def push(): + return ["a", "b", "c"] + + @dag.task + def pull(value): + results.add(value) + + pull.expand_kwargs(push().map(lambda v: {"value": v * 2})) + + # The function passed to "map" is *NOT* a task. + assert set(dag.task_dict) == {"push", "pull"} + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"aa", "bb", "cc"} + + +def test_xcom_map_transform_to_none(run_ti: RunTI, mock_supervisor_comms): + results = set() + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + results.add(value) + + def c_to_none(v): + if v == "c": + return None + return v + + pull.expand(value=push().map(c_to_none)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # Run "pull". This should automatically convert "c" to None. + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"a", "b", None} + + +def test_xcom_convert_to_kwargs_fails_task(run_ti: RunTI, mock_supervisor_comms, captured_logs): + results = set() + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + results.add(value) + + def c_to_none(v): + if v == "c": + return None + return {"value": v} + + pull.expand_kwargs(push().map(c_to_none)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # The first two "pull" tis should succeed. + for map_index in range(2): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + # Clear captured logs from the above + captured_logs[:] = [] + + # But the third one fails because the map() result cannot be used as kwargs. 
+ assert run_ti(dag, "pull", 2) == TerminalTIState.FAILED + + assert captured_logs == unordered( + [ + { + "event": "Task failed with exception", + "level": "error", + "timestamp": mock.ANY, + "exception": [ + { + "exc_type": "ValueError", + "exc_value": "expand_kwargs() expects a list[dict], not list[None]", + "frames": mock.ANY, + "is_cause": False, + "syntax_error": None, + } + ], + }, + ] + ) + + +def test_xcom_map_error_fails_task(mock_supervisor_comms, run_ti, captured_logs): + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + print(value) + + def does_not_work_with_c(v): + if v == "c": + raise RuntimeError("nope") + return {"value": v * 2} + + pull.expand_kwargs(push().map(does_not_work_with_c)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + # The third one (for "c") will fail. + assert run_ti(dag, "pull", 2) == TerminalTIState.FAILED + + assert captured_logs == unordered( + [ + { + "event": "Task failed with exception", + "level": "error", + "timestamp": mock.ANY, + "exception": [ + { + "exc_type": "RuntimeError", + "exc_value": "nope", + "frames": mock.ANY, + "is_cause": False, + "syntax_error": None, + } + ], + }, + ] + ) + + +def test_xcom_map_nest(mock_supervisor_comms, run_ti): + results = set() + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + results.add(value) + + converted = push().map(lambda v: v * 2).map(lambda v: {"value": v}) + pull.expand_kwargs(converted) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # Now "pull" should apply the mapping functions in order. + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + assert results == {"aa", "bb", "cc"} + + +def test_xcom_map_zip_nest(mock_supervisor_comms, run_ti): + results = set() + + with DAG("test") as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c", "d"] + + @dag.task + def push_numbers(): + return [1, 2, 3, 4] + + @dag.task + def pull(value): + results.add(value) + + doubled = push_numbers().map(lambda v: v * 2) + combined = doubled.zip(push_letters()) + + def convert_zipped(zipped): + letter, number = zipped + return letter * number + + pull.expand(value=combined.map(convert_zipped)) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + if last_request.task_id == "push_letters": + value = json.dumps(push_letters.function()) + return XComResult(key="return_value", value=value) + if last_request.task_id == "push_numbers": + value = json.dumps(push_numbers.function()) + return XComResult(key="return_value", value=value) + return mock.DEFAULT + + mock_supervisor_comms.get_message.side_effect = xcom_get + + # Run "pull". 
+ for map_index in range(4): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"aa", "bbbb", "cccccc", "dddddddd"} + + +def test_xcom_map_raise_to_skip(run_ti, mock_supervisor_comms): + result = [] + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def forward(value): + result.append(value) + + def skip_c(v): + if v == "c": + raise AirflowSkipException() + return {"value": v} + + forward.expand_kwargs(push().map(skip_c)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # Run "forward". This should automatically skip "c". + states = [run_ti(dag, "forward", map_index) for map_index in range(3)] + + assert states == [TerminalTIState.SUCCESS, TerminalTIState.SUCCESS, TerminalTIState.SKIPPED] + + assert result == ["a", "b"] + + +def test_xcom_concat(run_ti, mock_supervisor_comms): + from airflow.sdk.definitions.xcom_arg import _ConcatResult + + agg_results = set() + all_results = None + + with DAG("test") as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c"] + + @dag.task + def push_numbers(): + return [1, 2] + + @dag.task + def pull_one(value): + agg_results.add(value) + + @dag.task + def pull_all(value): + assert isinstance(value, _ConcatResult) + assert value[0] == "a" + assert value[1] == "b" + assert value[2] == "c" + assert value[3] == 1 + assert value[4] == 2 + with pytest.raises(IndexError): + value[5] + assert value[-5] == "a" + assert value[-4] == "b" + assert value[-3] == "c" + assert value[-2] == 1 + assert value[-1] == 2 + with pytest.raises(IndexError): + value[-6] + nonlocal all_results + all_results = list(value) + + pushed_values = push_letters().concat(push_numbers()) + + pull_one.expand(value=pushed_values) + pull_all(pushed_values) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + if last_request.task_id == "push_letters": + value = json.dumps(push_letters.function()) + return XComResult(key="return_value", value=value) + if last_request.task_id == "push_numbers": + value = json.dumps(push_numbers.function()) + return XComResult(key="return_value", value=value) + return mock.DEFAULT + + mock_supervisor_comms.get_message.side_effect = xcom_get + + # Run "pull_one" and "pull_all". 
+ assert run_ti(dag, "pull_all", None) == TerminalTIState.SUCCESS + assert all_results == ["a", "b", "c", 1, 2] + + states = [run_ti(dag, "pull_one", map_index) for map_index in range(5)] + assert states == [TerminalTIState.SUCCESS] * 5 + assert agg_results == {"a", "b", "c", 1, 2} diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py index ac0c21246c1ce..4a537373363aa 100644 --- a/task_sdk/tests/execution_time/conftest.py +++ b/task_sdk/tests/execution_time/conftest.py @@ -18,14 +18,6 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from datetime import datetime - - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.execution_time.comms import StartupDetails - from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance import pytest @@ -39,117 +31,3 @@ def disable_capturing(): sys.stderr = sys.__stderr__ yield sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err - - -@pytest.fixture -def mocked_parse(spy_agency): - """ - Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you - want to isolate and test `parse` or `run` logic without having to define a DAG file. - - This fixture returns a helper function `set_dag` that: - 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) - 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. - 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. - - After adding the fixture in your test function signature, you can use it like this :: - - mocked_parse( - StartupDetails( - ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), - file="", - requests_fd=0, - ), - "example_dag_id", - CustomOperator(task_id="hello"), - ) - """ - - def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: - from airflow.sdk.definitions.dag import DAG - from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse - from airflow.utils import timezone - - dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) - if what.ti_context.dag_run.conf: - dag.params = what.ti_context.dag_run.conf # type: ignore[assignment] - task.dag = dag - t = dag.task_dict[task.task_id] - ti = RuntimeTaskInstance.model_construct( - **what.ti.model_dump(exclude_unset=True), - task=t, - _ti_context_from_server=what.ti_context, - max_tries=what.ti_context.max_tries, - ) - spy_agency.spy_on(parse, call_fake=lambda _: ti) - return ti - - return set_dag - - -@pytest.fixture -def create_runtime_ti(mocked_parse, make_ti_context): - """ - Fixture to create a Runtime TaskInstance for testing purposes without defining a dag file. - - It mimics the behavior of the `parse` function by creating a `RuntimeTaskInstance` based on the provided - `StartupDetails` (formed from arguments) and task. This allows you to test the logic of a task without - having to define a DAG file, parse it, get context from the server, etc. - - Example usage: :: - - def test_custom_task_instance(create_runtime_ti): - class MyTaskOperator(BaseOperator): - def execute(self, context): - assert context["dag_run"].run_id == "test_run" - - task = MyTaskOperator(task_id="test_task") - ti = create_runtime_ti(task, context_from_server=make_ti_context(run_id="test_run")) - # Further test logic... 
- """ - from uuid6 import uuid7 - - from airflow.sdk.api.datamodels._generated import TaskInstance - from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails - - def _create_task_instance( - task: BaseOperator, - dag_id: str = "test_dag", - run_id: str = "test_run", - logical_date: str | datetime = "2024-12-01T01:00:00Z", - data_interval_start: str | datetime = "2024-12-01T00:00:00Z", - data_interval_end: str | datetime = "2024-12-01T01:00:00Z", - start_date: str | datetime = "2024-12-01T01:00:00Z", - run_type: str = "manual", - try_number: int = 1, - conf=None, - ti_id=None, - ) -> RuntimeTaskInstance: - if not ti_id: - ti_id = uuid7() - - ti_context = make_ti_context( - dag_id=dag_id, - run_id=run_id, - logical_date=logical_date, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, - start_date=start_date, - run_type=run_type, - conf=conf, - ) - - startup_details = StartupDetails( - ti=TaskInstance( - id=ti_id, task_id=task.task_id, dag_id=dag_id, run_id=run_id, try_number=try_number - ), - dag_rel_path="", - bundle_info=BundleInfo(name="anything", version="any"), - requests_fd=0, - ti_context=ti_context, - ) - - ti = mocked_parse(startup_details, dag_id, task) - return ti - - return _create_task_instance diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 4bc8febc67c14..875ed3e00022f 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -238,7 +238,7 @@ def subprocess_main(): assert "Last chance exception handler" in captured.err assert "RuntimeError: Fake syntax error" in captured.err - def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch): + def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch, mocker, make_ti_context): """Test that the WatchedSubprocess class regularly sends heartbeat requests, up to a certain frequency""" import airflow.sdk.execution_time.supervisor @@ -252,6 +252,8 @@ def subprocess_main(): sleep(0.05) ti_id = uuid7() + _ = mocker.patch.object(sdk_client.TaskInstanceOperations, "start", return_value=make_ti_context()) + spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat) proc = ActivitySubprocess.start( dag_rel_path=os.devnull, @@ -271,7 +273,7 @@ def subprocess_main(): # The exact number we get will depend on timing behaviour, so be a little lenient assert 1 <= len(spy.calls) <= 4 - def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): + def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context): """Test running a simple DAG in a subprocess and capturing the output.""" instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901) @@ -285,6 +287,12 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): run_id="c", try_number=1, ) + + # Create a mock client to assert calls to the client + # We assume the implementation of the client is correct and only need to check the calls + mock_client = mocker.Mock(spec=sdk_client.Client) + mock_client.task_instances.start.return_value = make_ti_context() + bundle_info = BundleInfo(name="my-bundle", version=None) with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): exit_code = supervise( @@ -293,6 +301,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): token="", server="", dry_run=True, + client=mock_client, bundle_info=bundle_info, ) assert exit_code == 0, 
captured_logs diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index ffac89f51387f..85804836981d0 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -17,10 +17,11 @@ from __future__ import annotations +import contextlib import json import os import uuid -from datetime import timedelta +from datetime import datetime, timedelta from pathlib import Path from socket import socketpair from unittest import mock @@ -29,6 +30,7 @@ import pytest from uuid6 import uuid7 +from airflow.decorators import task as task_decorator from airflow.exceptions import ( AirflowException, AirflowFailException, @@ -37,9 +39,10 @@ AirflowTaskTerminated, ) from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk import DAG, BaseOperator, Connection, get_current_context +from airflow.sdk import DAG, BaseOperator, Connection, dag as dag_decorator, get_current_context from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState from airflow.sdk.definitions.asset import Asset, AssetAlias +from airflow.sdk.definitions.param import DagParam from airflow.sdk.definitions.variable import Variable from airflow.sdk.execution_time.comms import ( BundleInfo, @@ -106,7 +109,7 @@ def test_recv_StartupDetails(self): b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", "dag_id": "c" }, ' b'"ti_context":{"dag_run":{"dag_id":"c","run_id":"b","logical_date":"2024-12-01T01:00:00Z",' b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",' - b'"start_date":"2024-12-01T01:00:00Z","end_date":null,"run_type":"manual","conf":null},' + b'"start_date":"2024-12-01T01:00:00Z","run_after":"2024-12-01T01:00:00Z","end_date":null,"run_type":"manual","conf":null},' b'"max_tries":0,"variables":null,"connections":null},"file": "/dev/null", "dag_rel_path": "/dev/null", "bundle_info": {"name": ' b'"any-name", "version": "any-version"}, "requests_fd": ' + str(w2.fileno()).encode("ascii") @@ -161,28 +164,6 @@ def test_parse(test_dags_dir: Path, make_ti_context): assert isinstance(ti.task.dag, DAG) -def test_run_basic(time_machine, create_runtime_ti, spy_agency, mock_supervisor_comms): - """Test running a basic task.""" - instant = timezone.datetime(2024, 12, 3, 10, 0) - time_machine.move_to(instant, tick=False) - - ti = create_runtime_ti(dag_id="super_basic_run", task=CustomOperator(task_id="hello")) - - # Ensure that task is locked for execution - spy_agency.spy_on(ti.task.prepare_for_execution) - assert not ti.task._lock_for_execution - - run(ti, log=mock.MagicMock()) - - spy_agency.assert_spy_called(ti.task.prepare_for_execution) - assert ti.task._lock_for_execution - - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask(state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), - log=mock.ANY, - ) - - def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_comms): """Test that a task can transition to a deferred state.""" import datetime @@ -216,7 +197,7 @@ def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_com run(ti, log=mock.MagicMock()) # send_request will only be called when the TaskDeferred exception is raised - mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY) + 
mock_supervisor_comms.send_request.assert_any_call(msg=expected_defer_task, log=mock.ANY) def test_run_basic_skipped(time_machine, create_runtime_ti, mock_supervisor_comms): @@ -236,7 +217,7 @@ def test_run_basic_skipped(time_machine, create_runtime_ti, mock_supervisor_comm run(ti, log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant), log=mock.ANY ) @@ -256,7 +237,7 @@ def test_run_raises_base_exception(time_machine, create_runtime_ti, mock_supervi run(ti, log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( state=TerminalTIState.FAILED, end_date=instant, @@ -280,7 +261,7 @@ def test_run_raises_system_exit(time_machine, create_runtime_ti, mock_supervisor run(ti, log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( state=TerminalTIState.FAILED, end_date=instant, @@ -306,7 +287,7 @@ def test_run_raises_airflow_exception(time_machine, create_runtime_ti, mock_supe run(ti, log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( state=TerminalTIState.FAILED, end_date=instant, @@ -333,7 +314,7 @@ def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms run(ti, log=mock.MagicMock()) # this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( state=TerminalTIState.FAILED, end_date=instant, @@ -342,7 +323,7 @@ def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms ) -def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms): +def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms, spy_agency): """Test running a DAG with templated task.""" from airflow.providers.standard.operators.bash import BashOperator @@ -360,15 +341,23 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervi requests_fd=0, ti_context=make_ti_context(), ) - mocked_parse(what, "basic_templated_dag", task) + ti = mocked_parse(what, "basic_templated_dag", task) - mock_supervisor_comms.get_message.return_value = what - startup() + # Ensure that task is locked for execution + spy_agency.spy_on(task.prepare_for_execution) + assert not task._lock_for_execution - mock_supervisor_comms.send_request.assert_called_once_with( + # mock_supervisor_comms.get_message.return_value = what + run(ti, log=mock.Mock()) + + spy_agency.assert_spy_called(task.prepare_for_execution) + assert ti.task._lock_for_execution + assert ti.task is not task, "ti.task should be a copy of the original task" + + mock_supervisor_comms.send_request.assert_any_call( msg=SetRenderedFields( rendered_fields={ - "bash_command": "echo 'Logical date is {{ logical_date }}'", + "bash_command": "echo 'Logical date is 2024-12-01 01:00:00+00:00'", "cwd": None, "env": None, } @@ -459,15 +448,14 @@ def execute(self, context): requests_fd=0, ti_context=make_ti_context(), ) - ti = mocked_parse(what, "basic_dag", task) + mocked_parse(what, "basic_dag", task) instant = timezone.datetime(2024, 12, 3, 10, 0) 
time_machine.move_to(instant, tick=False) mock_supervisor_comms.get_message.return_value = what - startup() - run(ti, log=mock.MagicMock()) + run(*startup()) expected_calls = [ mock.call.send_request( msg=SetRenderedFields(rendered_fields=expected_rendered_fields), @@ -494,11 +482,11 @@ def execute(self, context): ("{{ logical_date }}", "2024-12-01 01:00:00+00:00"), ], ) +@pytest.mark.usefixtures("mock_supervisor_comms") def test_startup_and_run_dag_with_templated_fields( - command, rendered_command, create_runtime_ti, time_machine, mock_supervisor_comms + command, rendered_command, create_runtime_ti, time_machine ): """Test startup of a DAG with various templated fields.""" - from airflow.providers.standard.operators.bash import BashOperator task = BashOperator(task_id="templated_task", bash_command=command) @@ -509,7 +497,6 @@ def test_startup_and_run_dag_with_templated_fields( instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) - run(ti, log=mock.MagicMock()) assert ti.task.bash_command == rendered_command @@ -722,6 +709,7 @@ def test_run_with_inlets_and_outlets( create_runtime_ti, mock_supervisor_comms, time_machine, ok, last_expected_msg ): """Test running a basic tasks with inlets and outlets.""" + instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) @@ -1037,8 +1025,6 @@ def execute(self, context): mock_supervisor_comms.get_message.return_value = XComResult(key="key", value='"value"') - spy_agency.spy_on(runtime_ti.xcom_pull, call_original=True) - run(runtime_ti, log=mock.MagicMock()) if isinstance(task_ids, str): @@ -1086,6 +1072,41 @@ def execute(self, context): "a_simple_list": ["one", "two", "three", "actually one value is made per line"], } + @pytest.mark.parametrize( + ("logical_date", "check"), + ( + pytest.param(None, pytest.raises(KeyError), id="no-logical-date"), + pytest.param(timezone.datetime(2024, 12, 3), contextlib.nullcontext(), id="with-logical-date"), + ), + ) + def test_no_logical_date_key_error( + self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti, logical_date, check + ): + """Test that a params can be retrieved from context.""" + + class CustomOperator(BaseOperator): + def execute(self, context): + for key in ("ds", "ds_nodash", "ts", "ts_nodash", "ts_nodash_with_tz"): + with check: + context[key] + # We should always be able to get this + assert context["task_instance_key_str"] + + task = CustomOperator(task_id="print-params") + runtime_ti = create_runtime_ti( + dag_id="basic_param_dag", + logical_date=logical_date, + task=task, + conf={ + "x": 3, + "text": "Hello World!", + "flag": False, + "a_simple_list": ["one", "two", "three", "actually one value is made per line"], + }, + ) + msg = run(runtime_ti, log=mock.MagicMock()) + assert isinstance(msg, SucceedTask) + class TestXComAfterTaskExecution: @pytest.mark.parametrize( @@ -1198,6 +1219,15 @@ def execute(self, context): class TestDagParamRuntime: + DEFAULT_ARGS = { + "owner": "test", + "depends_on_past": True, + "start_date": datetime.now(tz=timezone.utc), + "retries": 1, + "retry_delay": timedelta(minutes=1), + } + VALUE = 42 + def test_dag_param_resolves_from_task(self, create_runtime_ti, mock_supervisor_comms, time_machine): """Test dagparam resolves on operator execution""" instant = timezone.datetime(2024, 12, 3, 10, 0) @@ -1252,7 +1282,7 @@ def execute(self, context): ) def test_dag_param_dag_default(self, create_runtime_ti, mock_supervisor_comms, time_machine): - """ "Test dag param is retrieved from 
default config""" + """Test that dag param is correctly resolved by operator""" instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) @@ -1277,3 +1307,108 @@ def execute(self, context): ), log=mock.ANY, ) + + def test_dag_param_resolves( + self, create_runtime_ti, mock_supervisor_comms, time_machine, make_ti_context + ): + """Test that dag param is correctly resolved by operator""" + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + @dag_decorator(schedule=None, start_date=timezone.datetime(2024, 12, 3)) + def dag_with_dag_params(value="NOTSET"): + @task_decorator + def dummy_task(val): + return val + + class CustomOperator(BaseOperator): + def execute(self, context): + assert self.dag.params["value"] == "NOTSET" + + _ = dummy_task(value) + custom_task = CustomOperator(task_id="task_with_dag_params") + self.operator = custom_task + + dag_with_dag_params() + + runtime_ti = create_runtime_ti(task=self.operator, dag_id="dag_with_dag_params") + + run(runtime_ti, log=mock.MagicMock()) + + mock_supervisor_comms.send_request.assert_called_once_with( + msg=SucceedTask( + state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + ), + log=mock.ANY, + ) + + def test_dag_param_dagrun_parameterized( + self, create_runtime_ti, mock_supervisor_comms, time_machine, make_ti_context + ): + """Test that dag param is correctly overwritten when set in dag run""" + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + @dag_decorator(schedule=None, start_date=timezone.datetime(2024, 12, 3)) + def dag_with_dag_params(value=self.VALUE): + @task_decorator + def dummy_task(val): + return val + + assert isinstance(value, DagParam) + + class CustomOperator(BaseOperator): + def execute(self, context): + assert self.dag.params["value"] == "new_value" + + _ = dummy_task(value) + custom_task = CustomOperator(task_id="task_with_dag_params") + self.operator = custom_task + + dag_with_dag_params() + + runtime_ti = create_runtime_ti( + task=self.operator, dag_id="dag_with_dag_params", conf={"value": "new_value"} + ) + + run(runtime_ti, log=mock.MagicMock()) + + mock_supervisor_comms.send_request.assert_called_once_with( + msg=SucceedTask( + state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + ), + log=mock.ANY, + ) + + @pytest.mark.parametrize("value", [VALUE, 0]) + def test_set_params_for_dag( + self, create_runtime_ti, mock_supervisor_comms, time_machine, make_ti_context, value + ): + """Test that dag param is correctly set when using dag decorator""" + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + @dag_decorator(schedule=None, start_date=timezone.datetime(2024, 12, 3)) + def dag_with_param(value=value): + @task_decorator + def return_num(num): + return num + + xcom_arg = return_num(value) + self.operator = xcom_arg.operator + + dag_with_param() + + runtime_ti = create_runtime_ti(task=self.operator, dag_id="dag_with_param", conf={"value": value}) + + run(runtime_ti, log=mock.MagicMock()) + + mock_supervisor_comms.send_request.assert_any_call( + msg=SucceedTask( + state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + ), + log=mock.ANY, + ) diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index ff3347e7fd3dd..4c1e403670231 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py 
@@ -86,20 +86,18 @@ def test_providers_modules_should_have_tests(self):
            "providers/tests/celery/executors/test_celery_executor_utils.py",
            "providers/tests/celery/executors/test_default_celery.py",
            "providers/tests/cloudant/test_cloudant_fake.py",
-            "providers/tests/cncf/kubernetes/executors/test_kubernetes_executor_types.py",
-            "providers/tests/cncf/kubernetes/executors/test_kubernetes_executor_utils.py",
-            "providers/tests/cncf/kubernetes/operators/test_kubernetes_pod.py",
-            "providers/tests/cncf/kubernetes/test_k8s_model.py",
-            "providers/tests/cncf/kubernetes/test_kube_client.py",
-            "providers/tests/cncf/kubernetes/test_kube_config.py",
-            "providers/tests/cncf/kubernetes/test_pod_generator_deprecated.py",
-            "providers/tests/cncf/kubernetes/test_pod_launcher_deprecated.py",
-            "providers/tests/cncf/kubernetes/test_python_kubernetes_script.py",
-            "providers/tests/cncf/kubernetes/test_secret.py",
-            "providers/tests/cncf/kubernetes/triggers/test_kubernetes_pod.py",
-            "providers/tests/cncf/kubernetes/utils/test_delete_from.py",
-            "providers/tests/cncf/kubernetes/utils/test_k8s_hashlib_wrapper.py",
-            "providers/tests/cncf/kubernetes/utils/test_xcom_sidecar.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/executors/test_kubernetes_executor_types.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/executors/test_kubernetes_executor_utils.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/operators/test_kubernetes_pod.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_k8s_model.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_kube_client.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_kube_config.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_python_kubernetes_script.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_secret.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/triggers/test_kubernetes_pod.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/test_delete_from.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/test_k8s_hashlib_wrapper.py",
+            "providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/utils/test_xcom_sidecar.py",
            "providers/google/tests/provider_tests/google/cloud/fs/test_gcs.py",
            "providers/google/tests/provider_tests/google/cloud/links/test_automl.py",
            "providers/google/tests/provider_tests/google/cloud/links/test_base.py",
@@ -266,7 +264,8 @@ def get_classes_from_file(
                continue
 
        if is_new:
-            results[f"{'.'.join(module.split('.')[2:])}.{current_node.name}"] = current_node
+            module_path = module[module.find("airflow.providers") :]
+            results[f"{module_path}.{current_node.name}"] = current_node
        else:
            results[f"{module}.{current_node.name}"] = current_node
        print(f"{results}")
@@ -568,7 +567,7 @@ class TestElasticsearchProviderProjectStructure(ExampleCoverageTest):
 
 
 class TestCncfProviderProjectStructure(ExampleCoverageTest):
-    PROVIDER = "cncf"
+    PROVIDER = "cncf/kubernetes"
    CLASS_DIRS = ProjectStructureTest.CLASS_DIRS
 
    BASE_CLASSES = {"airflow.providers.cncf.kubernetes.operators.resource.KubernetesResourceBaseOperator"}
diff --git a/tests/api_fastapi/conftest.py b/tests/api_fastapi/conftest.py
index 61a9b7e4b31a5..5c8b2cbe85a7d 100644
--- a/tests/api_fastapi/conftest.py
+++ b/tests/api_fastapi/conftest.py
@@ -16,6 +16,7 @@
 # under the License.
from __future__ import annotations +import datetime import os import pytest @@ -56,12 +57,13 @@ def make_dag_with_multiple_versions(dag_maker): for version_number in range(1, 4): with dag_maker(dag_id) as dag: - for i in range(version_number): - EmptyOperator(task_id=f"task{i+1}") + for task_number in range(version_number): + EmptyOperator(task_id=f"task{task_number + 1}") dag.sync_to_db() SerializedDagModel.write_dag(dag, bundle_name="dag_maker") dag_maker.create_dagrun( - run_id=f"run{i+1}", + run_id=f"run{version_number}", + logical_date=datetime.datetime(2020, 1, version_number, tzinfo=datetime.timezone.utc), dag_version=DagVersion.get_version(dag_id=dag_id, version_number=version_number), ) diff --git a/tests/api_fastapi/core_api/routes/public/test_assets.py b/tests/api_fastapi/core_api/routes/public/test_assets.py index a48c0da87fc7a..540747cff24a4 100644 --- a/tests/api_fastapi/core_api/routes/public/test_assets.py +++ b/tests/api_fastapi/core_api/routes/public/test_assets.py @@ -149,7 +149,7 @@ def _create_dag_run(session, num: int = 2): dag_id="source_dag_id", run_id=f"source_run_id_{i}", run_type=DagRunType.MANUAL, - logical_date=DEFAULT_DATE, + logical_date=DEFAULT_DATE + timedelta(days=i - 1), start_date=DEFAULT_DATE, data_interval=(DEFAULT_DATE, DEFAULT_DATE), external_trigger=True, @@ -579,7 +579,9 @@ def test_should_respond_200(self, test_client, session): { "run_id": "source_run_id_2", "dag_id": "source_dag_id", - "logical_date": from_datetime_to_zulu_without_ms(DEFAULT_DATE), + "logical_date": from_datetime_to_zulu_without_ms( + DEFAULT_DATE + timedelta(days=1), + ), "start_date": from_datetime_to_zulu_without_ms(DEFAULT_DATE), "end_date": from_datetime_to_zulu_without_ms(DEFAULT_DATE), "state": "success", @@ -747,7 +749,9 @@ def test_should_mask_sensitive_extra(self, test_client, session): { "run_id": "source_run_id_2", "dag_id": "source_dag_id", - "logical_date": from_datetime_to_zulu_without_ms(DEFAULT_DATE), + "logical_date": from_datetime_to_zulu_without_ms( + DEFAULT_DATE + timedelta(days=1), + ), "start_date": from_datetime_to_zulu_without_ms(DEFAULT_DATE), "end_date": from_datetime_to_zulu_without_ms(DEFAULT_DATE), "state": "success", diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index de21c23e8d0ae..92e19c839fcd5 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -1319,7 +1319,7 @@ def test_should_respond_400_if_a_dag_has_import_errors(self, test_client, sessio ) @time_machine.travel(timezone.utcnow(), tick=False) - def test_should_response_200_for_duplicate_logical_date(self, test_client): + def test_should_response_409_for_duplicate_logical_date(self, test_client): RUN_ID_1 = "random_1" RUN_ID_2 = "random_2" now = timezone.utcnow().isoformat().replace("+00:00", "Z") @@ -1333,28 +1333,26 @@ def test_should_response_200_for_duplicate_logical_date(self, test_client): json={"dag_run_id": RUN_ID_2, "note": note}, ) - assert response_1.status_code == response_2.status_code == 200 - body1 = response_1.json() - body2 = response_2.json() - - for each_run_id, each_body in [(RUN_ID_1, body1), (RUN_ID_2, body2)]: - assert each_body == { - "dag_run_id": each_run_id, - "dag_id": DAG1_ID, - "logical_date": now, - "queued_at": now, - "start_date": None, - "end_date": None, - "data_interval_start": now, - "data_interval_end": now, - "last_scheduling_decision": None, - "run_type": "manual", - "state": "queued", - 
"external_trigger": True, - "triggered_by": "rest_api", - "conf": {}, - "note": note, - } + assert response_1.status_code == 200 + assert response_1.json() == { + "dag_run_id": RUN_ID_1, + "dag_id": DAG1_ID, + "logical_date": now, + "queued_at": now, + "start_date": None, + "end_date": None, + "data_interval_start": now, + "data_interval_end": now, + "last_scheduling_decision": None, + "run_type": "manual", + "state": "queued", + "external_trigger": True, + "triggered_by": "rest_api", + "conf": {}, + "note": note, + } + + assert response_2.status_code == 409 @pytest.mark.parametrize( "data_interval_start, data_interval_end", diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index 7cf3f251a2d99..7f4d3a8a92c39 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -20,6 +20,7 @@ import datetime as dt import itertools import os +from datetime import timedelta from unittest import mock import pendulum @@ -225,7 +226,7 @@ def test_should_respond_200_with_versions(self, test_client, run_id, expected_ve "dag_id": "dag_with_multiple_versions", "dag_run_id": run_id, "map_index": -1, - "logical_date": "2016-01-01T00:00:00Z", + "logical_date": mock.ANY, "start_date": None, "end_date": mock.ANY, "duration": None, @@ -1109,9 +1110,15 @@ def test_should_respond_200_for_dag_id_filter(self, test_client, session): assert count == len(response.json()["task_instances"]) @pytest.mark.parametrize( - "order_by_field", ["start_date", "logical_date", "data_interval_start", "data_interval_end"] + "order_by_field, base_date", + [ + ("start_date", DEFAULT_DATETIME_1 + timedelta(days=20)), + ("logical_date", DEFAULT_DATETIME_2), + ("data_interval_start", DEFAULT_DATETIME_1 + timedelta(days=5)), + ("data_interval_end", DEFAULT_DATETIME_2 + timedelta(days=8)), + ], ) - def test_should_respond_200_for_order_by(self, order_by_field, test_client, session): + def test_should_respond_200_for_order_by(self, order_by_field, base_date, test_client, session): dag_id = "example_python_operator" dag_runs = [ @@ -1119,10 +1126,10 @@ def test_should_respond_200_for_order_by(self, order_by_field, test_client, sess dag_id=dag_id, run_id=f"run_{i}", run_type=DagRunType.MANUAL, - logical_date=DEFAULT_DATETIME_1 + dt.timedelta(days=i), + logical_date=base_date + dt.timedelta(days=i), data_interval=( - DEFAULT_DATETIME_1 + dt.timedelta(days=i), - DEFAULT_DATETIME_1 + dt.timedelta(days=i, hours=1), + base_date + dt.timedelta(days=i), + base_date + dt.timedelta(days=i, hours=1), ), ) for i in range(10) @@ -1133,7 +1140,7 @@ def test_should_respond_200_for_order_by(self, order_by_field, test_client, sess self.create_task_instances( session, task_instances=[ - {"run_id": f"run_{i}", "start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i + 1))} + {"run_id": f"run_{i}", "start_date": base_date + dt.timedelta(minutes=(i + 1))} for i in range(10) ], dag_id=dag_id, @@ -1604,6 +1611,7 @@ def test_should_respond_200(self, test_client, session): "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, } @pytest.mark.parametrize("try_number", [1, 2]) @@ -1638,6 +1646,7 @@ def test_should_respond_200_with_different_try_numbers(self, test_client, try_nu "try_number": try_number, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, } @pytest.mark.parametrize("try_number", [1, 2]) @@ -1701,6 +1710,7 @@ 
def test_should_respond_200_with_mapped_task_at_different_try_numbers( "try_number": try_number, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, } def test_should_respond_200_with_task_state_in_deferred(self, test_client, session): @@ -1762,6 +1772,7 @@ def test_should_respond_200_with_task_state_in_deferred(self, test_client, sessi "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, } def test_should_respond_200_with_task_state_in_removed(self, test_client, session): @@ -1797,6 +1808,7 @@ def test_should_respond_200_with_task_state_in_removed(self, test_client, sessio "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, } def test_raises_404_for_nonexistent_task_instance(self, test_client, session): @@ -1810,6 +1822,54 @@ def test_raises_404_for_nonexistent_task_instance(self, test_client, session): "detail": "The Task Instance with dag_id: `example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: `nonexistent_task`, try_number: `0` and map_index: `-1` was not found" } + @pytest.mark.parametrize( + "run_id, expected_version_number", + [ + ("run1", 1), + ("run2", 2), + ("run3", 3), + ], + ) + @pytest.mark.usefixtures("make_dag_with_multiple_versions") + def test_should_respond_200_with_versions(self, test_client, run_id, expected_version_number): + response = test_client.get( + f"/public/dags/dag_with_multiple_versions/dagRuns/{run_id}/taskInstances/task1/tries/0" + ) + assert response.status_code == 200 + assert response.json() == { + "task_id": "task1", + "dag_id": "dag_with_multiple_versions", + "dag_run_id": run_id, + "map_index": -1, + "start_date": None, + "end_date": mock.ANY, + "duration": None, + "state": None, + "try_number": 0, + "max_tries": 0, + "task_display_name": "task1", + "hostname": "", + "unixname": getuser(), + "pool": "default_pool", + "pool_slots": 1, + "queue": "default", + "priority_weight": 1, + "operator": "EmptyOperator", + "queued_when": None, + "scheduled_when": None, + "pid": None, + "executor": None, + "executor_config": "{}", + "dag_version": { + "id": mock.ANY, + "version_number": expected_version_number, + "dag_id": "dag_with_multiple_versions", + "bundle_name": "dag_maker", + "bundle_version": None, + "created_at": mock.ANY, + }, + } + class TestPostClearTaskInstances(TestTaskInstanceEndpoint): @pytest.mark.parametrize( @@ -2587,6 +2647,7 @@ def test_should_respond_200(self, test_client, session): "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, }, { "dag_id": "example_python_operator", @@ -2612,6 +2673,7 @@ def test_should_respond_200(self, test_client, session): "try_number": 2, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, }, ], "total_entries": 2, @@ -2658,6 +2720,7 @@ def test_ti_in_retry_state_not_returned(self, test_client, session): "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, }, ], "total_entries": 1, @@ -2725,6 +2788,7 @@ def test_mapped_task_should_respond_200(self, test_client, session): "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, }, { "dag_id": "example_python_operator", @@ -2750,6 +2814,7 @@ def test_mapped_task_should_respond_200(self, test_client, session): "try_number": 2, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "dag_version": None, }, ], "total_entries": 2, @@ -2766,6 +2831,55 @@ def 
test_raises_404_for_nonexistent_task_instance(self, test_client, session): "detail": "The Task Instance with dag_id: `example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: `non_existent_task` and map_index: `-1` was not found" } + @pytest.mark.parametrize( + "run_id, expected_version_number", + [ + ("run1", 1), + ("run2", 2), + ("run3", 3), + ], + ) + @pytest.mark.usefixtures("make_dag_with_multiple_versions") + def test_should_respond_200_with_versions(self, test_client, run_id, expected_version_number): + response = test_client.get( + f"/public/dags/dag_with_multiple_versions/dagRuns/{run_id}/taskInstances/task1/tries" + ) + assert response.status_code == 200 + + assert response.json()["task_instances"][0] == { + "task_id": "task1", + "dag_id": "dag_with_multiple_versions", + "dag_run_id": run_id, + "map_index": -1, + "start_date": None, + "end_date": mock.ANY, + "duration": None, + "state": mock.ANY, + "try_number": 0, + "max_tries": 0, + "task_display_name": "task1", + "hostname": "", + "unixname": getuser(), + "pool": "default_pool", + "pool_slots": 1, + "queue": "default", + "priority_weight": 1, + "operator": "EmptyOperator", + "queued_when": None, + "scheduled_when": None, + "pid": None, + "executor": None, + "executor_config": "{}", + "dag_version": { + "id": mock.ANY, + "version_number": expected_version_number, + "dag_id": "dag_with_multiple_versions", + "bundle_name": "dag_maker", + "bundle_version": None, + "created_at": mock.ANY, + }, + } + class TestPatchTaskInstance(TestTaskInstanceEndpoint): ENDPOINT_URL = ( diff --git a/tests/api_fastapi/core_api/routes/public/test_xcom.py b/tests/api_fastapi/core_api/routes/public/test_xcom.py index 7c6ce2c71e6ff..e3d9b3641a91b 100644 --- a/tests/api_fastapi/core_api/routes/public/test_xcom.py +++ b/tests/api_fastapi/core_api/routes/public/test_xcom.py @@ -579,3 +579,44 @@ def test_create_xcom_entry( assert current_data["task_id"] == task_id assert current_data["run_id"] == dag_run_id assert current_data["map_index"] == request_body.map_index + + +class TestPatchXComEntry(TestXComEndpoint): + @pytest.mark.parametrize( + "key, patch_body, expected_status, expected_detail", + [ + # Test case: Valid update, should return 200 OK + pytest.param( + TEST_XCOM_KEY, + {"value": "new_value"}, + 200, + None, + id="valid-xcom-update", + ), + # Test case: XCom entry does not exist, should return 404 + pytest.param( + TEST_XCOM_KEY, + {"value": "new_value", "map_index": -1}, + 404, + f"The XCom with key: `{TEST_XCOM_KEY}` with mentioned task instance doesn't exist.", + id="xcom-not-found", + ), + ], + ) + def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detail, test_client): + # Ensure the XCom entry exists before updating + if expected_status != 404: + self._create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE) + new_value = XCom.serialize_value(patch_body["value"]) + + response = test_client.patch( + f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{key}", + json=patch_body, + ) + + assert response.status_code == expected_status + + if expected_status == 200: + assert response.json()["value"] == XCom.serialize_value(new_value) + else: + assert response.json()["detail"] == expected_detail diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py b/tests/api_fastapi/execution_api/routes/test_xcoms.py index 3232622b3ac1f..bfb5f0ca73b6a 100644 --- a/tests/api_fastapi/execution_api/routes/test_xcoms.py +++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py @@ -70,7 +70,7 @@ def 
test_xcom_not_found(self, client, create_task_instance): assert response.status_code == 404 assert response.json() == { "detail": { - "message": "XCom with key 'xcom_non_existent' not found for task 'task' in DAG 'dag'", + "message": "XCom with key='xcom_non_existent' map_index=-1 not found for task 'task' in DAG run 'runid' of 'dag'", "reason": "not_found", } } diff --git a/tests/cli/commands/remote_commands/test_dag_command.py b/tests/cli/commands/remote_commands/test_dag_command.py index 2e6351e7bca91..28a4db231e4eb 100644 --- a/tests/cli/commands/remote_commands/test_dag_command.py +++ b/tests/cli/commands/remote_commands/test_dag_command.py @@ -693,6 +693,8 @@ def test_dag_test_conf(self, mock_get_dag): ) @mock.patch("airflow.cli.commands.remote_commands.dag_command.get_dag") def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag): + mock_get_dag.return_value.test.return_value.run_id = "__test_dag_test_show_dag_fake_dag_run_run_id__" + cli_args = self.parser.parse_args( ["dags", "test", "example_bash_operator", DEFAULT_DATE.isoformat(), "--show-dagrun"] ) diff --git a/tests/cli/commands/remote_commands/test_task_command.py b/tests/cli/commands/remote_commands/test_task_command.py index dc331abd238fa..6c1414d5fbf80 100644 --- a/tests/cli/commands/remote_commands/test_task_command.py +++ b/tests/cli/commands/remote_commands/test_task_command.py @@ -228,30 +228,22 @@ def test_cli_test_different_path(self, session, tmp_path): # verify that the file was in different location when run assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix() - @mock.patch("airflow.cli.commands.remote_commands.task_command.select") - @mock.patch("sqlalchemy.orm.session.Session.scalar") - def test_task_render_with_custom_timetable(self, mock_scalar, mock_select): + @mock.patch( + "airflow.cli.commands.remote_commands.task_command.fetch_dag_run_from_run_id_or_logical_date_string" + ) + def test_task_render_with_custom_timetable(self, mock_fetch_dag_run_from_run_id_or_logical_date_string): """ Test that the `tasks render` CLI command queries the database correctly for a DAG with a custom timetable. Verifies that a query is executed to fetch the appropriate DagRun and that the database interaction occurs as expected. 
""" - from sqlalchemy import select - - from airflow.models.dagrun import DagRun - - mock_query = ( - select(DagRun).where(DagRun.dag_id == "example_workday_timetable").order_by(DagRun.id.desc()) - ) - mock_select.return_value = mock_query - - mock_scalar.return_value = None + mock_fetch_dag_run_from_run_id_or_logical_date_string.return_value = (None, None) task_command.task_render( self.parser.parse_args(["tasks", "render", "example_workday_timetable", "run_this", "2022-01-01"]) ) - mock_select.assert_called_once() + mock_fetch_dag_run_from_run_id_or_logical_date_string.assert_called_once() @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") def test_test_with_existing_dag_run(self, caplog): diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py index b19d6ff0cf2cd..735a78c64b268 100644 --- a/tests/core/test_configuration.py +++ b/tests/core/test_configuration.py @@ -1732,11 +1732,19 @@ def test_config_paths_is_directory(self): "sensitive_config_values", new_callable=lambda: [("mysection1", "mykey1"), ("mysection2", "mykey2")], ) - @patch("airflow.utils.log.secrets_masker.mask_secret") - def test_mask_conf_values(self, mock_mask_secret, mock_sensitive_config_values): - conf.mask_secrets() + def test_mask_conf_values(self, mock_sensitive_config_values): + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - mock_mask_secret.assert_any_call("supersecret1") - mock_mask_secret.assert_any_call("supersecret2") + target = ( + "airflow.sdk.execution_time.secrets_masker.mask_secret" + if AIRFLOW_V_3_0_PLUS + else "airflow.utils.log.secrets_masker.mask_secret" + ) + + with patch(target) as mock_mask_secret: + conf.mask_secrets() + + mock_mask_secret.assert_any_call("supersecret1") + mock_mask_secret.assert_any_call("supersecret2") - assert mock_mask_secret.call_count == 2 + assert mock_mask_secret.call_count == 2 diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 6ab1f3ea68fb6..4aafaf8564ae5 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -17,6 +17,8 @@ from __future__ import annotations +import fcntl +import os import re import tempfile from pathlib import Path @@ -51,22 +53,82 @@ def test_default_dag_storage_path(): assert bundle._dag_bundle_root_storage_path == Path(tempfile.gettempdir(), "airflow", "dag_bundles") -def test_dag_bundle_root_storage_path(): - class BasicBundle(BaseDagBundle): - def refresh(self): - pass +class BasicBundle(BaseDagBundle): + def refresh(self): + pass + + def get_current_version(self): + pass - def get_current_version(self): - pass + def path(self): + pass - def path(self): - pass +def test_dag_bundle_root_storage_path(): with conf_vars({("dag_processor", "dag_bundle_storage_path"): None}): bundle = BasicBundle(name="test") assert bundle._dag_bundle_root_storage_path == Path(tempfile.gettempdir(), "airflow", "dag_bundles") +def test_lock_acquisition(): + """Test that the lock context manager sets _locked and locks a lock file.""" + bundle = BasicBundle(name="locktest") + lock_dir = bundle._dag_bundle_root_storage_path / "_locks" + lock_file = lock_dir / f"{bundle.name}.lock" + + assert not bundle._locked + + with bundle.lock(): + assert bundle._locked + assert lock_file.exists() + + # Check lock file is now locked + with open(lock_file, "w") as f: + try: + # Try to acquire an exclusive lock in non-blocking mode. 
+ fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + locked = False + except OSError: + locked = True + assert locked + + # After, _locked is False and file unlock has been called. + assert bundle._locked is False + with open(lock_file, "w") as f: + try: + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + unlocked = True + fcntl.flock(f, fcntl.LOCK_UN) # Release the lock immediately. + except OSError: + unlocked = False + assert unlocked + + +def test_lock_exception_handling(): + """Test that exceptions within the lock context manager still release the lock.""" + bundle = BasicBundle(name="locktest") + lock_dir = bundle._dag_bundle_root_storage_path / "_locks" + lock_file = lock_dir / f"{bundle.name}.lock" + + try: + with bundle.lock(): + assert bundle._locked + raise Exception("...") + except Exception: + pass + + # lock file should be unlocked + assert not bundle._locked + with open(lock_file, "w") as f: + try: + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + acquired = True + fcntl.flock(f, fcntl.LOCK_UN) + except OSError: + acquired = False + assert acquired + + class TestLocalDagBundle: def test_path(self): bundle = LocalDagBundle(name="test", path="/hello") @@ -108,6 +170,8 @@ def git_repo(tmp_path_factory): CONN_HTTPS = "my_git_conn" CONN_HTTPS_PASSWORD = "my_git_conn_https_password" CONN_ONLY_PATH = "my_git_conn_only_path" +CONN_ONLY_INLINE_KEY = "my_git_conn_only_inline_key" +CONN_BOTH_PATH_INLINE = "my_git_conn_both_path_inline" CONN_NO_REPO_URL = "my_git_conn_no_repo_url" @@ -146,6 +210,16 @@ def setup_class(cls) -> None: conn_type="git", ) ) + db.merge_conn( + Connection( + conn_id=CONN_ONLY_INLINE_KEY, + host="path/to/repo", + conn_type="git", + extra={ + "private_key": "inline_key", + }, + ) + ) @pytest.mark.parametrize( "conn_id, expected_repo_url", @@ -160,11 +234,12 @@ def test_correct_repo_urls(self, conn_id, expected_repo_url): hook = GitHook(git_conn_id=conn_id) assert hook.repo_url == expected_repo_url - def test_env_var(self, session): - hook = GitHook(git_conn_id=CONN_DEFAULT) - assert hook.env == { - "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" - } + def test_env_var_with_configure_hook_env(self, session): + default_hook = GitHook(git_conn_id=CONN_DEFAULT) + with default_hook.configure_hook_env(): + assert default_hook.env == { + "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" + } db.merge_conn( Connection( conn_id="my_git_conn_strict", @@ -174,11 +249,61 @@ def test_env_var(self, session): ) ) - hook = GitHook(git_conn_id="my_git_conn_strict") + strict_default_hook = GitHook(git_conn_id="my_git_conn_strict") + with strict_default_hook.configure_hook_env(): + assert strict_default_hook.env == { + "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=yes" + } + + def test_given_both_private_key_and_key_file(self): + db.merge_conn( + Connection( + conn_id=CONN_BOTH_PATH_INLINE, + host="path/to/repo", + conn_type="git", + extra={ + "key_file": "path/to/key", + "private_key": "inline_key", + }, + ) + ) + + with pytest.raises( + AirflowException, match="Both 'key_file' and 'private_key' cannot be provided at the same time" + ): + GitHook(git_conn_id=CONN_BOTH_PATH_INLINE) + + def test_key_file_git_hook_has_env_with_configure_hook_env(self): + hook = GitHook(git_conn_id=CONN_DEFAULT) + + assert hasattr(hook, "env") + with hook.configure_hook_env(): + assert hook.env == { + "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o 
StrictHostKeyChecking=no" + } + + def test_private_key_lazy_env_var(self): + hook = GitHook(git_conn_id=CONN_ONLY_INLINE_KEY) + assert hook.env == {} + + hook.set_git_env("dummy_inline_key") assert hook.env == { - "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=yes" + "GIT_SSH_COMMAND": "ssh -i dummy_inline_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" } + def test_configure_hook_env(self): + hook = GitHook(git_conn_id=CONN_ONLY_INLINE_KEY) + assert hasattr(hook, "private_key") + + hook.set_git_env("dummy_inline_key") + + with hook.configure_hook_env(): + command = hook.env.get("GIT_SSH_COMMAND") + temp_key_path = command.split()[2] + assert os.path.exists(temp_key_path) + + assert not os.path.exists(temp_key_path) + class TestGitDagBundle: @classmethod @@ -517,3 +642,12 @@ def test_repo_url_access_missing_connection_doesnt_error(self, mock_log): ) assert bundle.repo_url == "some_repo_url" assert "Could not create GitHook for connection" in mock_log.warning.call_args[0][0] + + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_lock_used(self, mock_githook, git_repo): + repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path + bundle = GitDagBundle(name="test", tracking_ref=GIT_DEFAULT_BRANCH) + with mock.patch("airflow.dag_processing.bundles.git.GitDagBundle.lock") as mock_lock: + bundle.initialize() + assert mock_lock.call_count == 2 # both initialize and refresh diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index d5fa174fa2511..9262ec0cffe45 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -215,6 +215,7 @@ def tg(a, b): @pytest.mark.db_test +@pytest.mark.need_serialized_dag def test_task_group_expand_kwargs_with_upstream(dag_maker, session, caplog): with dag_maker() as dag: @@ -239,6 +240,7 @@ def t2(): @pytest.mark.db_test +@pytest.mark.need_serialized_dag def test_task_group_expand_with_upstream(dag_maker, session, caplog): with dag_maker() as dag: diff --git a/tests/models/test_backfill.py b/tests/models/test_backfill.py index 7b1625e1043ad..0a1ad5e134921 100644 --- a/tests/models/test_backfill.py +++ b/tests/models/test_backfill.py @@ -152,6 +152,11 @@ def test_create_backfill_simple(reverse, existing, dag_maker, session): assert all(x.conf == expected_run_conf for x in dag_runs) +# Marking test xfail as backfill reprocess behaviour impacted by restoring logical date unique constraints in #46295 +# TODO: Fix backfill reprocess behaviour as per #46295 +@pytest.mark.xfail( + reason="Backfill reprocess behaviour impacted by restoring logical date unique constraints." 
+) @pytest.mark.parametrize( "reprocess_behavior, run_counts", [ diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 15a2ca96df97f..84801d6cfb5b7 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -59,7 +59,6 @@ DagOwnerAttributes, DagTag, ExecutorLoader, - dag as dag_decorator, get_asset_triggered_next_run_info, ) from airflow.models.dag_version import DagVersion @@ -73,7 +72,7 @@ from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext from airflow.sdk.definitions._internal.templater import NativeEnvironment, SandboxedEnvironment from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny -from airflow.sdk.definitions.param import DagParam, Param +from airflow.sdk.definitions.param import Param from airflow.security import permissions from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( @@ -150,6 +149,7 @@ def _create_dagrun( run_type: DagRunType, state: DagRunState = DagRunState.RUNNING, start_date: datetime.datetime | None = None, + **kwargs, ) -> DagRun: logical_date = timezone.coerce_datetime(logical_date) if not isinstance(data_interval, DataInterval): @@ -168,6 +168,7 @@ def _create_dagrun( state=state, start_date=start_date, triggered_by=DagRunTriggeredByType.TEST, + **kwargs, ) @@ -1536,16 +1537,14 @@ def consumer(value): PythonOperator.partial(task_id=task_id, python_callable=consumer).expand(op_args=make_arg_lists()) session = dag_maker.session - dagrun_1 = dag.create_dagrun( - run_id="backfill", + dagrun_1 = _create_dagrun( + dag, run_type=DagRunType.BACKFILL_JOB, - state=State.FAILED, + state=DagRunState.FAILED, start_date=DEFAULT_DATE, logical_date=DEFAULT_DATE, - session=session, data_interval=(DEFAULT_DATE, DEFAULT_DATE), - run_after=DEFAULT_DATE, - triggered_by=DagRunTriggeredByType.TEST, + session=session, ) # Get the (de)serialized MappedOperator mapped = dag.get_task(task_id) @@ -1663,26 +1662,6 @@ def check_task_2(my_input): mock_task_object_1.assert_called() mock_task_object_2.assert_not_called() - def test_dag_test_with_task_mapping(self): - dag = DAG(dag_id="test_local_testing_conn_file", schedule=None, start_date=DEFAULT_DATE) - mock_object = mock.MagicMock() - - @task_decorator() - def get_index(current_val, ti=None): - return ti.map_index - - @task_decorator - def check_task(my_input): - # we call a mock object with the combined map to ensure all expected indexes are called - mock_object(list(my_input)) - - with dag: - mapped_task = get_index.expand(current_val=[1, 1, 1, 1, 1]) - check_task(mapped_task) - - dag.test() - mock_object.assert_called_with([0, 1, 2, 3, 4]) - def test_dag_connection_file(self, tmp_path): test_connections_string = """ --- @@ -2519,135 +2498,6 @@ def test_count_number_queries(self, tasks_count): ) -class TestDagDecorator: - DEFAULT_ARGS = { - "owner": "test", - "depends_on_past": True, - "start_date": timezone.utcnow(), - "retries": 1, - "retry_delay": timedelta(minutes=1), - } - DEFAULT_DATE = timezone.datetime(2016, 1, 1) - VALUE = 42 - - def setup_method(self): - self.operator = None - - def teardown_method(self): - clear_db_runs() - - def test_documentation_template_rendered(self): - """Test that @dag uses function docs as doc_md for DAG object""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def noop_pipeline(): - """ - {% if True %} - Regular DAG documentation - {% endif %} - """ - - dag = noop_pipeline() - assert dag.dag_id == "noop_pipeline" - assert 
"Regular DAG documentation" in dag.doc_md - - def test_resolve_documentation_template_file_not_rendered(self, tmp_path): - """Test that @dag uses function docs as doc_md for DAG object""" - - raw_content = """ - {% if True %} - External Markdown DAG documentation - {% endif %} - """ - - path = tmp_path / "testfile.md" - path.write_text(raw_content) - - @dag_decorator("test-dag", schedule=None, start_date=DEFAULT_DATE, doc_md=str(path)) - def markdown_docs(): ... - - dag = markdown_docs() - assert dag.dag_id == "test-dag" - assert dag.doc_md == raw_content - - def test_dag_param_resolves(self): - """Test that dag param is correctly resolved by operator""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def xcom_pass_to_op(value=self.VALUE): - @task_decorator - def return_num(num): - return num - - xcom_arg = return_num(value) - self.operator = xcom_arg.operator - - dag = xcom_pass_to_op() - - dr = dag.create_dagrun( - run_id="test", - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - logical_date=self.DEFAULT_DATE, - data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE), - run_after=self.DEFAULT_DATE, - state=State.RUNNING, - triggered_by=DagRunTriggeredByType.TEST, - ) - - self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE) - ti = dr.get_task_instances()[0] - assert ti.xcom_pull() == self.VALUE - - def test_dag_param_dagrun_parameterized(self): - """Test that dag param is correctly overwritten when set in dag run""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def xcom_pass_to_op(value=self.VALUE): - @task_decorator - def return_num(num): - return num - - assert isinstance(value, DagParam) - - xcom_arg = return_num(value) - self.operator = xcom_arg.operator - - dag = xcom_pass_to_op() - new_value = 52 - dr = dag.create_dagrun( - run_id="test", - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - logical_date=self.DEFAULT_DATE, - data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE), - run_after=self.DEFAULT_DATE, - state=State.RUNNING, - conf={"value": new_value}, - triggered_by=DagRunTriggeredByType.TEST, - ) - - self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE) - ti = dr.get_task_instances()[0] - assert ti.xcom_pull() == new_value - - @pytest.mark.parametrize("value", [VALUE, 0]) - def test_set_params_for_dag(self, value): - """Test that dag param is correctly set when using dag decorator""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def xcom_pass_to_op(value=value): - @task_decorator - def return_num(num): - return num - - xcom_arg = return_num(value) - self.operator = xcom_arg.operator - - dag = xcom_pass_to_op() - assert dag.params["value"] == value - - @pytest.mark.parametrize( "run_id", ["test-run-id"], diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index a8972e88e27b0..b7c2f2282c8be 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -55,7 +55,7 @@ from tests_common.test_utils.config import conf_vars from tests_common.test_utils.mock_operators import MockOperator -pytestmark = pytest.mark.db_test +pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] if TYPE_CHECKING: @@ -329,7 +329,7 @@ def test_dagrun_deadlock(self, dag_maker, session): assert dr.state == DagRunState.RUNNING ti_op2.set_state(state=None, session=session) - op2.trigger_rule = "invalid" # type: ignore + ti_op2.task.trigger_rule = "invalid" # type: ignore dr.update_state(session=session) assert dr.state == 
DagRunState.FAILED @@ -353,6 +353,7 @@ def test_dagrun_no_deadlock_with_depends_on_past(self, dag_maker, session): dr = dag_maker.create_dagrun( run_id="test_dagrun_no_deadlock_1", + run_type=DagRunType.SCHEDULED, start_date=DEFAULT_DATE, ) dr2 = dag_maker.create_dagrun_after( @@ -759,7 +760,6 @@ def mutate_task_instance(task_instance): (None, False), ], ) - @pytest.mark.need_serialized_dag def test_depends_on_past(self, dag_maker, session, prev_ti_state, is_ti_schedulable): # DAG tests depends_on_past dependencies with dag_maker( @@ -802,7 +802,6 @@ def test_depends_on_past(self, dag_maker, session, prev_ti_state, is_ti_schedula (None, False), ], ) - @pytest.mark.need_serialized_dag def test_wait_for_downstream(self, dag_maker, session, prev_ti_state, is_ti_schedulable): dag_id = "test_wait_for_downstream" @@ -1090,7 +1089,6 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session): assert indices == [(0,), (1,), (2,), (3,)] -@pytest.mark.need_serialized_dag @pytest.mark.parametrize("is_noop", [True, False]) def test_expand_mapped_task_instance_task_decorator(is_noop, dag_maker, session): with mock.patch("airflow.settings.task_instance_mutation_hook") as mock_mut: @@ -1114,7 +1112,6 @@ def mynameis(arg): assert indices == [(0,), (1,), (2,), (3,)] -@pytest.mark.need_serialized_dag def test_mapped_literal_verify_integrity(dag_maker, session): """Test that when the length of a mapped literal changes we remove extra TIs""" @@ -1147,7 +1144,6 @@ def task_2(arg2): ... assert indices == [(0, None), (1, None), (2, TaskInstanceState.REMOVED), (3, TaskInstanceState.REMOVED)] -@pytest.mark.need_serialized_dag def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session): """Test that when we change from literal to a XComArg the TIs are removed""" @@ -1181,7 +1177,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_literal_length_increase_adds_additional_ti(dag_maker, session): """Test that when the length of mapped literal increases, additional ti is added""" @@ -1224,7 +1219,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session): """Test that when the length of mapped literal reduces, removed state is added""" @@ -1265,7 +1259,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_length_increase_at_runtime_adds_additional_tis(dag_maker, session): """Test that when the length of mapped literal increases at runtime, additional ti is added""" # Variable.set(key="arg1", value=[1, 2, 3]) @@ -1317,7 +1310,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker, session): """ Test that when the length of mapped literal reduces at runtime, the missing task instances @@ -1405,7 +1397,6 @@ def task_2(arg2): ... assert len(decision.schedulable_tis) == 2 -@pytest.mark.need_serialized_dag def test_calls_to_verify_integrity_with_mapped_task_zero_length_at_runtime(dag_maker, session, caplog): """ Test zero length reduction in mapped task at runtime with calls to dagrun.verify_integrity @@ -1468,7 +1459,6 @@ def task_2(arg2): ... 
 )
 
 
-@pytest.mark.need_serialized_dag
 def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session):
     literal = [1, 2, 3, 4]
     with dag_maker(session=session):
@@ -1645,8 +1635,6 @@ def consumer(*args):
 
 
 def test_mapped_task_all_finish_before_downstream(dag_maker, session):
-    result = None
-
     with dag_maker(session=session) as dag:
 
         @dag.task
@@ -1659,8 +1647,8 @@ def double(value):
 
         @dag.task
         def consumer(value):
-            nonlocal result
-            result = list(value)
+            ...
+            # result = list(value)
 
         consumer(value=double.expand(value=make_list()))
 
@@ -1674,26 +1662,29 @@ def _task_ids(tis):
     assert _task_ids(decision.schedulable_tis) == ["make_list"]
 
     # After make_list is run, double is expanded.
-    decision.schedulable_tis[0].run(verbose=False, session=session)
+    ti = decision.schedulable_tis[0]
+    ti.state = TaskInstanceState.SUCCESS
+    session.add(TaskMap.from_task_instance_xcom(ti, [1, 2]))
+    session.flush()
+
     decision = dr.task_instance_scheduling_decisions(session=session)
     assert _task_ids(decision.schedulable_tis) == ["double", "double"]
 
     # Running just one of the mapped tis does not make downstream schedulable.
-    decision.schedulable_tis[0].run(verbose=False, session=session)
+    ti = decision.schedulable_tis[0]
+    ti.state = TaskInstanceState.SUCCESS
+    session.flush()
+
     decision = dr.task_instance_scheduling_decisions(session=session)
     assert _task_ids(decision.schedulable_tis) == ["double"]
 
     # Downstream is schedulable after all mapped tis are run.
-    decision.schedulable_tis[0].run(verbose=False, session=session)
+    ti = decision.schedulable_tis[0]
+    ti.state = TaskInstanceState.SUCCESS
+    session.flush()
     decision = dr.task_instance_scheduling_decisions(session=session)
     assert _task_ids(decision.schedulable_tis) == ["consumer"]
 
-    # We should be able to get all values aggregated from mapped upstreams.
-    decision.schedulable_tis[0].run(verbose=False, session=session)
-    decision = dr.task_instance_scheduling_decisions(session=session)
-    assert decision.schedulable_tis == []
-    assert result == [2, 4]
-
 
 def test_schedule_tis_map_index(dag_maker, session):
     with dag_maker(session=session, dag_id="test"):
@@ -1771,6 +1762,7 @@ def test_schedule_tis_empty_operator_try_number(dag_maker, session: Session):
     assert empty_ti.try_number == 1
 
 
+@pytest.mark.xfail(reason="We can't keep this behaviour with remote workers where scheduler can't reach xcom")
 def test_schedule_tis_start_trigger_through_expand(dag_maker, session):
     """
     Test that an operator with start_trigger_args set can be directly deferred during scheduling.
@@ -1925,28 +1917,18 @@ def do_something_else(i): @pytest.mark.parametrize( "partial_params, mapped_params, expected", [ - pytest.param(None, [{"a": 1}], [[("a", 1)]], id="simple"), - pytest.param({"b": 2}, [{"a": 1}], [[("a", 1), ("b", 2)]], id="merge"), - pytest.param({"b": 2}, [{"a": 1, "b": 3}], [[("a", 1), ("b", 3)]], id="override"), + pytest.param(None, [{"a": 1}], 1, id="simple"), + pytest.param({"b": 2}, [{"a": 1}], 1, id="merge"), + pytest.param({"b": 2}, [{"a": 1, "b": 3}], 1, id="override"), ], ) def test_mapped_expand_against_params(dag_maker, partial_params, mapped_params, expected): - results = [] - - class PullOperator(BaseOperator): - def execute(self, context): - results.append(sorted(context["params"].items())) - with dag_maker(): - PullOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params) + BaseOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params) dr: DagRun = dag_maker.create_dagrun() decision = dr.task_instance_scheduling_decisions() - - for ti in decision.schedulable_tis: - ti.run() - - assert sorted(results) == expected + assert len(decision.schedulable_tis) == expected def test_mapped_task_group_expands(dag_maker, session): @@ -1991,9 +1973,7 @@ def test_operator_mapped_task_group_receives_value(dag_maker, session): with dag_maker(session=session): @task - def t(value, *, ti=None): - results[(ti.task_id, ti.map_index)] = value - return value + def t(value): ... @task_group def tg(va): @@ -2015,24 +1995,29 @@ def tg(va): dr: DagRun = dag_maker.create_dagrun() - results = {} + results = set() decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: - ti.run() - assert results == {("tg.t1", 0): ["a", "b"], ("tg.t1", 1): [4], ("tg.t1", 2): ["z"]} + results.add((ti.task_id, ti.map_index)) + ti.state = TaskInstanceState.SUCCESS + session.flush() + assert results == {("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)} - results = {} + results.clear() decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: - ti.run() - assert results == {("tg.t2", 0): ["a", "b"], ("tg.t2", 1): [4], ("tg.t2", 2): ["z"]} + results.add((ti.task_id, ti.map_index)) + ti.state = TaskInstanceState.SUCCESS + session.flush() + assert results == {("tg.t2", 0), ("tg.t2", 1), ("tg.t2", 2)} - results = {} + results.clear() decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: - ti.run() - assert len(results) == 1 - assert list(results[("t3", -1)]) == [["a", "b"], [4], ["z"]] + results.add((ti.task_id, ti.map_index)) + ti.state = TaskInstanceState.SUCCESS + session.flush() + assert results == {("t3", -1)} def test_mapping_against_empty_list(dag_maker, session): @@ -2100,13 +2085,15 @@ def print_value(value): decision = dr1.task_instance_scheduling_decisions(session=session) assert len(decision.schedulable_tis) == 2 for ti in decision.schedulable_tis: - ti.run(session=session) + ti.state = TaskInstanceState.SUCCESS + session.flush() # Now print_value in dr2 can run decision = dr2.task_instance_scheduling_decisions(session=session) assert len(decision.schedulable_tis) == 2 for ti in decision.schedulable_tis: - ti.run(session=session) + ti.state = TaskInstanceState.SUCCESS + session.flush() # Both runs are finished now. 
     decision = dr1.task_instance_scheduling_decisions(session=session)
@@ -2115,6 +2102,69 @@ def print_value(value):
     assert len(decision.unfinished_tis) == 0
 
 
+def test_xcom_map_skip_raised(dag_maker, session):
+    result = None
+
+    with dag_maker(session=session) as dag:
+        # Note: this doesn't actually run this dag, the callbacks are for reference only.
+
+        @dag.task()
+        def push():
+            return ["a", "b", "c"]
+
+        @dag.task()
+        def forward(value):
+            return value
+
+        @dag.task(trigger_rule=TriggerRule.ALL_DONE)
+        def collect(value):
+            nonlocal result
+            result = list(value)
+
+        def skip_c(v):
+            ...
+            # if v == "c":
+            #     raise AirflowSkipException
+            # return {"value": v}
+
+        collect(value=forward.expand_kwargs(push().map(skip_c)))
+
+    dr: DagRun = dag_maker.create_dagrun(session=session)
+
+    def _task_ids(tis):
+        return [(ti.task_id, ti.map_index) for ti in tis]
+
+    # Check that when forward w/ map_index=2 ends up skipping, the collect task can still be
+    # scheduled!
+
+    # Run "push".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("push", -1)]
+    ti = decision.schedulable_tis[0]
+    ti.state = TaskInstanceState.SUCCESS
+    session.add(TaskMap.from_task_instance_xcom(ti, push.function()))
+    session.flush()
+
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [
+        ("forward", 0),
+        ("forward", 1),
+        ("forward", 2),
+    ]
+    # Run "forward". "c"/index 2 is skipped. Runtime behaviour is checked in test_xcom_map_raise_to_skip in
+    # the TaskSDK tests.
+    for ti, state in zip(
+        decision.schedulable_tis,
+        [TaskInstanceState.SUCCESS, TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED],
+    ):
+        ti.state = state
+    session.flush()
+
+    # Now "collect" should only get "a" and "b".
+ decision = dr.task_instance_scheduling_decisions(session=session) + assert _task_ids(decision.schedulable_tis) == [("collect", -1)] + + def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session): """ Test that clearing a task and moving from non-mapped to mapped clears existing @@ -2386,7 +2436,7 @@ def make_task(task_id, dag): with dag_maker() as dag: for line in input: - tasks = [make_task(x, dag) for x in line.split(" >> ")] + tasks = [make_task(x, dag_maker.dag) for x in line.split(" >> ")] reduce(lambda x, y: x >> y, tasks) dr = dag_maker.create_dagrun() diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 2d06dc6216f84..256d74d88bc88 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -32,19 +32,13 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.execution_time.comms import XComCountResponse from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup -from airflow.utils.task_instance_session import set_current_task_instance_session -from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE from tests_common.test_utils.mapping import expand_mapped_task -from tests_common.test_utils.mock_operators import ( - MockOperator, - MockOperatorWithNestedFields, - NestedFields, -) +from tests_common.test_utils.mock_operators import MockOperator pytestmark = pytest.mark.db_test @@ -81,43 +75,6 @@ def execute(self, context: Context): mock_render_template.assert_called() -def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): - """Test that the correct number of downstream tasks are generated when mapping with an XComArg""" - - class PushExtraXComOperator(BaseOperator): - """Push an extra XCom value along with the default return value.""" - - def __init__(self, return_value, **kwargs): - super().__init__(**kwargs) - self.return_value = return_value - - def execute(self, context): - context["task_instance"].xcom_push(key="extra_key", value="extra_value") - return self.return_value - - with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE) as dag: - upstream_return = [1, 2, 3] - task1 = PushExtraXComOperator(return_value=upstream_return, task_id="task_1") - task2 = PushExtraXComOperator.partial(task_id="task_2").expand(return_value=task1.output) - task3 = PushExtraXComOperator.partial(task_id="task_3").expand(return_value=task2.output) - - dr = dag_maker.create_dagrun() - ti_1 = dr.get_task_instance("task_1", session) - ti_1.run() - - ti_2s, _ = TaskMap.expand_mapped_task(task2, dr.run_id, session=session) - for ti in ti_2s: - ti.refresh_from_task(dag.get_task("task_2")) - ti.run() - - ti_3s, _ = TaskMap.expand_mapped_task(task3, dr.run_id, session=session) - for ti in ti_3s: - ti.refresh_from_task(dag.get_task("task_3")) - ti.run() - - assert len(ti_3s) == len(ti_2s) == len(upstream_return) - - @pytest.mark.parametrize( ["num_existing_tis", "expected"], ( @@ -255,151 +212,6 @@ def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): assert indices == [(-1, TaskInstanceState.SKIPPED)] -class _RenderTemplateFieldsValidationOperator(BaseOperator): - template_fields = ( - "partial_template", - "map_template_xcom", - "map_template_literal", - "map_template_file", - ) - template_ext = (".ext",) 
- - fields_to_test = [ - "partial_template", - "partial_static", - "map_template_xcom", - "map_template_literal", - "map_static", - "map_template_file", - ] - - def __init__( - self, - partial_template, - partial_static, - map_template_xcom, - map_template_literal, - map_static, - map_template_file, - **kwargs, - ): - for field in self.fields_to_test: - setattr(self, field, value := locals()[field]) - assert isinstance(value, str), "value should have been resolved before unmapping" - super().__init__(**kwargs) - - def execute(self, context): - pass - - -def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): - file_template_dir = tmp_path / "path" / "to" - file_template_dir.mkdir(parents=True, exist_ok=True) - file_template = file_template_dir / "file.ext" - file_template.write_text("loaded data") - - with set_current_task_instance_session(session=session): - with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): - task1 = BaseOperator(task_id="op1") - output1 = task1.output - mapped = _RenderTemplateFieldsValidationOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand( - map_static=output1, - map_template_literal=["{{ ds }}"], - map_template_xcom=output1, - map_template_file=["/path/to/file.ext"], - ) - - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=1, - keys=None, - ) - ) - session.flush() - - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - mapped_ti.map_index = 0 - assert isinstance(mapped_ti.task, MappedOperator) - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) - - assert mapped_ti.task.partial_template == "a", "Should be rendered!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" - assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" - assert mapped_ti.task.map_template_xcom == "{{ ds }}", "XCom resolved but not double rendered!" - assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" - - -def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): - file_template_dir = tmp_path / "path" / "to" - file_template_dir.mkdir(parents=True, exist_ok=True) - file_template = file_template_dir / "file.ext" - file_template.write_text("loaded data") - - with set_current_task_instance_session(session=session): - with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): - mapped = _RenderTemplateFieldsValidationOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand_kwargs( - [ - { - "map_template_literal": "{{ ds }}", - "map_static": "{{ ds }}", - "map_template_file": "/path/to/file.ext", - # This field is not tested since XCom inside a literal list - # is not rendered (matching BaseOperator rendering behavior). 
- "map_template_xcom": "", - } - ] - ) - - dr = dag_maker.create_dagrun() - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) - assert isinstance(mapped_ti.task, MappedOperator) - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) - - assert mapped_ti.task.partial_template == "a", "Should be rendered!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" - assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" - assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" - - -def test_mapped_render_nested_template_fields(dag_maker, session): - with dag_maker(session=session): - MockOperatorWithNestedFields.partial( - task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", field_2="value_2") - ).expand(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) - - dr = dag_maker.create_dagrun() - decision = dr.task_instance_scheduling_decisions() - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert len(tis) == 2 - - ti = tis[("t", 0)] - ti.run(session=session) - assert ti.task.arg1 == "t" - assert ti.task.arg2.field_1 == "t" - assert ti.task.arg2.field_2 == "value_2" - - ti = tis[("t", 1)] - ti.run(session=session) - assert ti.task.arg1 == ["s", "t"] - assert ti.task.arg2.field_1 == "t" - assert ti.task.arg2.field_2 == "value_2" - - @pytest.mark.parametrize( ["num_existing_tis", "expected"], ( @@ -467,6 +279,45 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis assert indices == expected +def test_map_product_expansion(dag_maker, session): + """Test the cross-product effect of mapping two inputs""" + outputs = [] + + with dag_maker(dag_id="product", session=session) as dag: + + @dag.task + def emit_numbers(): + return [1, 2] + + @dag.task + def emit_letters(): + return {"a": "x", "b": "y", "c": "z"} + + @dag.task + def show(number, letter): + outputs.append((number, letter)) + + show.expand(number=emit_numbers(), letter=emit_letters()) + + dr = dag_maker.create_dagrun() + for fn in (emit_numbers, emit_letters): + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=fn.__name__, + run_id=dr.run_id, + map_index=-1, + length=len(fn.function()), + keys=None, + ) + ) + + session.flush() + show_task = dag.get_task("show") + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dr.run_id, session=session) + assert max_map_index + 1 == len(mapped_tis) == 6 + + def _create_mapped_with_name_template_classic(*, task_id, map_names, template): class HasMapName(BaseOperator): def __init__(self, *, map_name: str, **kwargs): @@ -591,68 +442,6 @@ def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, crea assert call.args[0].map_index == expected_map_index[index] -@pytest.mark.parametrize( - "map_index, expected", - [ - pytest.param(0, "2016-01-01", id="0"), - pytest.param(1, 2, id="1"), - ], -) -def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): - with set_current_task_instance_session(session=session): - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) - - dr = dag_maker.create_dagrun() - ti: TaskInstance = 
dr.get_task_instance(task1.task_id, session=session) - - ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}], session=session) - - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=2, - keys=None, - ) - ) - session.flush() - - ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - ti.refresh_from_task(mapped) - ti.map_index = map_index - assert isinstance(ti.task, MappedOperator) - mapped.render_template_fields(context=ti.get_template_context(session=session)) - assert isinstance(ti.task, MockOperator) - assert ti.task.arg1 == expected - assert ti.task.arg2 == "a" - - -def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): - class PushXcomOperator(MockOperator): - def __init__(self, arg1, **kwargs): - super().__init__(arg1=arg1, **kwargs) - - def execute(self, context): - return self.arg1 - - class ConsumeXcomOperator(PushXcomOperator): - def execute(self, context): - assert set(self.arg1) == {1, 2, 3} - - with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"): - op1 = PushXcomOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) - ConsumeXcomOperator(task_id="op2", arg1=op1.output) - - dr = dag_maker.create_dagrun() - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.run() - - class TestMappedSetupTeardown: @staticmethod def get_states(dr): @@ -1474,7 +1263,14 @@ def my_teardown(val): my_work(s) tg1, tg2 = dag.task_group.children.values() tg1 >> tg2 - dr = dag.test() + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as supervisor_comms: + # TODO: TaskSDK: this is a bit of a hack that we need to stub this at all. `dag.test()` should + # really work without this! + supervisor_comms.get_message.return_value = XComCountResponse(len=3) + dr = dag.test() states = self.get_states(dr) expected = { "tg_1.my_pre_setup": "success", diff --git a/tests/models/test_param.py b/tests/models/test_param.py deleted file mode 100644 index 77cf96eda2226..0000000000000 --- a/tests/models/test_param.py +++ /dev/null @@ -1,132 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -from contextlib import nullcontext - -import pytest - -from airflow.decorators import task -from airflow.exceptions import ParamValidationError -from airflow.sdk.definitions.param import Param -from airflow.utils import timezone -from airflow.utils.types import DagRunType - -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom - - -class TestDagParamRuntime: - VALUE = 42 - DEFAULT_DATE = timezone.datetime(2016, 1, 1) - - @staticmethod - def clean_db(): - clear_db_runs() - clear_db_dags() - clear_db_xcom() - - def setup_class(self): - self.clean_db() - - def teardown_method(self): - self.clean_db() - - @pytest.mark.db_test - def test_dag_param_resolves(self, dag_maker): - """Test dagparam resolves on operator execution""" - with dag_maker(dag_id="test_xcom_pass_to_op") as dag: - value = dag.param("value", default=self.VALUE) - - @task - def return_num(num): - return num - - xcom_arg = return_num(value) - - dr = dag_maker.create_dagrun( - run_id=DagRunType.MANUAL.value, - start_date=timezone.utcnow(), - ) - - xcom_arg.operator.run(dr.logical_date, dr.logical_date) - - ti = dr.get_task_instances()[0] - assert ti.xcom_pull() == self.VALUE - - @pytest.mark.db_test - def test_dag_param_overwrite(self, dag_maker): - """Test dag param is overwritten from dagrun config""" - with dag_maker(dag_id="test_xcom_pass_to_op") as dag: - value = dag.param("value", default=self.VALUE) - - @task - def return_num(num): - return num - - xcom_arg = return_num(value) - - assert dag.params["value"] == self.VALUE - new_value = 2 - dr = dag_maker.create_dagrun( - run_id=DagRunType.MANUAL.value, - start_date=timezone.utcnow(), - conf={"value": new_value}, - ) - - xcom_arg.operator.run(dr.logical_date, dr.logical_date) - - ti = dr.get_task_instances()[0] - assert ti.xcom_pull() == new_value - - @pytest.mark.db_test - def test_dag_param_default(self, dag_maker): - """Test dag param is retrieved from default config""" - with dag_maker(dag_id="test_xcom_pass_to_op", params={"value": "test"}) as dag: - value = dag.param("value") - - @task - def return_num(num): - return num - - xcom_arg = return_num(value) - - dr = dag_maker.create_dagrun(run_id=DagRunType.MANUAL.value, start_date=timezone.utcnow()) - - xcom_arg.operator.run(dr.logical_date, dr.logical_date) - - ti = dr.get_task_instances()[0] - assert ti.xcom_pull() == "test" - - @pytest.mark.db_test - @pytest.mark.parametrize( - "default, should_raise", - [ - pytest.param({0, 1, 2}, True, id="default-non-JSON-serializable"), - pytest.param(None, False, id="default-None"), # Param init should not warn - pytest.param({"b": 1}, False, id="default-JSON-serializable"), # Param init should not warn - ], - ) - def test_param_json_validation(self, default, should_raise): - exception_msg = "All provided parameters must be json-serializable" - cm = pytest.raises(ParamValidationError, match=exception_msg) if should_raise else nullcontext() - with cm: - p = Param(default=default) - if not should_raise: - p.resolve() # when resolved with NOTSET, should not warn. - p.resolve(value={"a": 1}) # when resolved with JSON-serializable, should not warn. - with pytest.raises(ParamValidationError, match=exception_msg): - p.resolve(value={1, 2, 3}) # when resolved with not JSON-serializable, should warn. 
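The test_renderedtifields.py hunk below selects the redact patch target by Airflow version before building the RTIF row. A minimal standalone sketch of that pattern, assuming only the AIRFLOW_V_3_0_PLUS flag and the module paths already shown in this diff (the DAG/run setup is elided and purely illustrative, not the patch itself):

from unittest import mock

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

# Pick the patch target the same way the hunk below does: the Task SDK masker on
# Airflow 3+, the legacy utils path otherwise.
target = (
    "airflow.sdk.execution_time.secrets_masker.redact"
    if AIRFLOW_V_3_0_PLUS
    else "airflow.utils.log.secrets_masker.mask_secret.redact"
)

with mock.patch(target, autospec=True) as redact:
    # One queued value per templated field, matching the side_effect ordering used in
    # the real test (bash_command, env, cwd).
    redact.side_effect = ["val 1", "val 2", "val 3"]
    ...  # build the dag run and construct RTIF(ti=ti) here, as in the hunk below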
diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py index dfb7b652252b4..5a9d5705e4eca 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -368,30 +368,38 @@ def test_write(self, dag_maker): ) @mock.patch.dict(os.environ, {"AIRFLOW_VAR_API_KEY": "secret"}) - @mock.patch("airflow.utils.log.secrets_masker.redact", autospec=True) - def test_redact(self, redact, dag_maker): - with dag_maker("test_ritf_redact", serialized=True): - task = BashOperator( - task_id="test", - bash_command="echo {{ var.value.api_key }}", - env={"foo": "secret", "other_api_key": "masked based on key name"}, - ) - dr = dag_maker.create_dagrun() - redact.side_effect = [ - # Order depends on order in Operator template_fields - "val 1", # bash_command - "val 2", # env - "val 3", # cwd - ] + def test_redact(self, dag_maker): + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - ti = dr.task_instances[0] - ti.task = task - rtif = RTIF(ti=ti) - assert rtif.rendered_fields == { - "bash_command": "val 1", - "env": "val 2", - "cwd": "val 3", - } + target = ( + "airflow.sdk.execution_time.secrets_masker.redact" + if AIRFLOW_V_3_0_PLUS + else "airflow.utils.log.secrets_masker.mask_secret.redact" + ) + + with mock.patch(target, autospec=True) as redact: + with dag_maker("test_ritf_redact", serialized=True): + task = BashOperator( + task_id="test", + bash_command="echo {{ var.value.api_key }}", + env={"foo": "secret", "other_api_key": "masked based on key name"}, + ) + dr = dag_maker.create_dagrun() + redact.side_effect = [ + # Order depends on order in Operator template_fields + "val 1", # bash_command + "val 2", # env + "val 3", # cwd + ] + + ti = dr.task_instances[0] + ti.task = task + rtif = RTIF(ti=ti) + assert rtif.rendered_fields == { + "bash_command": "val 1", + "env": "val 2", + "cwd": "val 3", + } def test_rtif_deletion_stale_data_error(self, dag_maker, session): """ diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index 5c3654d7e69c5..8f3b7e65dc386 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -120,6 +120,7 @@ def get_state(ti): assert executed_states == expected_states + @pytest.mark.need_serialized_dag def test_mapped_tasks_skip_all_except(self, dag_maker): with dag_maker("dag_test_skip_all_except") as dag: diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 2e486a92d5fd6..3d985394646d0 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -22,7 +22,6 @@ import operator import os import pathlib -import pickle import signal import sys import urllib @@ -59,7 +58,6 @@ from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun -from airflow.models.expandinput import EXPAND_INPUT_EMPTY, NotFullyPopulated from airflow.models.pool import Pool from airflow.models.renderedtifields import RenderedTaskInstanceFields from airflow.models.serialized_dag import SerializedDagModel @@ -73,7 +71,7 @@ from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule from airflow.models.variable import Variable -from airflow.models.xcom import LazyXComSelectSequence, XCom +from airflow.models.xcom import XCom from airflow.notifications.basenotifier import BaseNotifier from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator @@ 
-95,7 +93,6 @@ from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.task_group import TaskGroup -from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.types import DagRunTriggeredByType, DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY @@ -1607,192 +1604,6 @@ def test_are_dependents_done( session.flush() assert ti.are_dependents_done(session) == expected_are_dependents_done - def test_xcom_pull(self, dag_maker): - """Test xcom_pull, using different filtering methods.""" - with dag_maker(dag_id="test_xcom") as dag: - task_1 = EmptyOperator(task_id="test_xcom_1") - task_2 = EmptyOperator(task_id="test_xcom_2") - - dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) - ti1 = dagrun.get_task_instance(task_1.task_id) - - # Push a value - ti1.xcom_push(key="foo", value="bar") - - # Push another value with the same key (but by a different task) - XCom.set(key="foo", value="baz", task_id=task_2.task_id, dag_id=dag.dag_id, run_id=dagrun.run_id) - - # Pull with no arguments - result = ti1.xcom_pull() - assert result is None - # Pull the value pushed most recently by any task. - result = ti1.xcom_pull(key="foo") - assert result in "baz" - # Pull the value pushed by the first task - result = ti1.xcom_pull(task_ids="test_xcom_1", key="foo") - assert result == "bar" - # Pull the value pushed by the second task - result = ti1.xcom_pull(task_ids="test_xcom_2", key="foo") - assert result == "baz" - # Pull the values pushed by both tasks & Verify Order of task_ids pass & values returned - result = ti1.xcom_pull(task_ids=["test_xcom_1", "test_xcom_2"], key="foo") - assert result == ["bar", "baz"] - - def test_xcom_pull_mapped(self, dag_maker, session): - with dag_maker(dag_id="test_xcom", session=session): - # Use the private _expand() method to avoid the empty kwargs check. - # We don't care about how the operator runs here, only its presence. - task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False) - EmptyOperator(task_id="task_2") - - dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) - - ti_1_0 = dagrun.get_task_instance("task_1", session=session) - ti_1_0.map_index = 0 - ti_1_1 = session.merge(TI(task_1, run_id=dagrun.run_id, map_index=1, state=ti_1_0.state)) - session.flush() - - ti_1_0.xcom_push(key=XCOM_RETURN_KEY, value="a", session=session) - ti_1_1.xcom_push(key=XCOM_RETURN_KEY, value="b", session=session) - - ti_2 = dagrun.get_task_instance("task_2", session=session) - - assert set(ti_2.xcom_pull(["task_1"], session=session)) == {"a", "b"} # Ordering not guaranteed. 
- assert ti_2.xcom_pull(["task_1"], map_indexes=0, session=session) == ["a"] - - assert ti_2.xcom_pull(map_indexes=[0, 1], session=session) == ["a", "b"] - assert ti_2.xcom_pull("task_1", map_indexes=[1, 0], session=session) == ["b", "a"] - assert ti_2.xcom_pull(["task_1"], map_indexes=[0, 1], session=session) == ["a", "b"] - - assert ti_2.xcom_pull("task_1", map_indexes=1, session=session) == "b" - assert list(ti_2.xcom_pull("task_1", session=session)) == ["a", "b"] - - def test_xcom_pull_after_success(self, create_task_instance): - """ - tests xcom set/clear relative to a task in a 'success' rerun scenario - """ - key = "xcom_key" - value = "xcom_value" - - ti = create_task_instance( - dag_id="test_xcom", - schedule="@monthly", - task_id="test_xcom", - pool="test_xcom", - serialized=True, - ) - - ti.run(mark_success=True) - ti.xcom_push(key=key, value=value) - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - ti.run() - # Check that we do not clear Xcom until the task is certain to execute - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - # Xcom shouldn't be cleared if the task doesn't execute, even if dependencies are ignored - ti.run(ignore_all_deps=True, mark_success=True) - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - # Xcom IS finally cleared once task has executed - ti.run(ignore_all_deps=True) - assert ti.xcom_pull(task_ids="test_xcom", key=key) is None - - def test_xcom_pull_after_deferral(self, create_task_instance, session): - """ - tests xcom will not clear before a task runs its next method after deferral. - """ - - key = "xcom_key" - value = "xcom_value" - - ti = create_task_instance( - dag_id="test_xcom", - schedule="@monthly", - task_id="test_xcom", - pool="test_xcom", - ) - - ti.run(mark_success=True) - ti.xcom_push(key=key, value=value) - - ti.next_method = "execute" - session.merge(ti) - session.commit() - - ti.run(ignore_all_deps=True) - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - - def test_xcom_pull_different_logical_date(self, create_task_instance): - """ - tests xcom fetch behavior with different logical dates, using - both xcom_pull with "include_prior_dates" and without - """ - key = "xcom_key" - value = "xcom_value" - - ti = create_task_instance( - dag_id="test_xcom", - schedule="@monthly", - task_id="test_xcom", - pool="test_xcom", - ) - exec_date = ti.dag_run.logical_date - - ti.run(mark_success=True) - ti.xcom_push(key=key, value=value) - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - ti.run() - exec_date += datetime.timedelta(days=1) - dr = ti.task.dag.create_dagrun( - run_id="test2", - run_type=DagRunType.MANUAL, - logical_date=exec_date, - data_interval=(exec_date, exec_date), - run_after=exec_date, - state=None, - triggered_by=DagRunTriggeredByType.TEST, - ) - ti = TI(task=ti.task, run_id=dr.run_id) - ti.run() - # We have set a new logical date (and did not pass in - # 'include_prior_dates'which means this task should now have a cleared - # xcom value - assert ti.xcom_pull(task_ids="test_xcom", key=key) is None - # We *should* get a value using 'include_prior_dates' - assert ti.xcom_pull(task_ids="test_xcom", key=key, include_prior_dates=True) == value - - def test_xcom_pull_different_run_ids(self, create_task_instance): - """ - tests xcom fetch behavior w/different run ids - """ - key = "xcom_key" - task_id = "test_xcom" - diff_run_id = "diff_run_id" - same_run_id_value = "xcom_value_same_run_id" - diff_run_id_value = "xcom_value_different_run_id" - - ti_same_run_id = 
create_task_instance( - dag_id="test_xcom", - task_id=task_id, - ) - ti_same_run_id.run(mark_success=True) - ti_same_run_id.xcom_push(key=key, value=same_run_id_value) - - ti_diff_run_id = create_task_instance( - dag_id="test_xcom", - task_id=task_id, - run_id=diff_run_id, - ) - ti_diff_run_id.run(mark_success=True) - ti_diff_run_id.xcom_push(key=key, value=diff_run_id_value) - - assert ( - ti_same_run_id.xcom_pull(run_id=ti_same_run_id.dag_run.run_id, task_ids=task_id, key=key) - == same_run_id_value - ) - assert ( - ti_same_run_id.xcom_pull(run_id=ti_diff_run_id.dag_run.run_id, task_ids=task_id, key=key) - == diff_run_id_value - ) - def test_xcom_push_flag(self, dag_maker): """ Tests the option for Operators to push XComs @@ -2303,7 +2114,13 @@ def test_outlet_assets(self, create_task_instance, testing_dag_bundle): session.flush() run_id = str(uuid4()) - dr = DagRun(dag1.dag_id, run_id=run_id, run_type="manual", state=DagRunState.RUNNING) + dr = DagRun( + dag1.dag_id, + run_id=run_id, + run_type="manual", + state=DagRunState.RUNNING, + logical_date=timezone.utcnow(), + ) session.merge(dr) task = dag1.get_task("producing_task_1") task.bash_command = "echo 1" # make it go faster @@ -2362,7 +2179,13 @@ def test_outlet_assets_failed(self, create_task_instance, testing_dag_bundle): dagbag.collect_dags(only_if_updated=False, safe_mode=False) dagbag.sync_to_db("testing", None, session=session) run_id = str(uuid4()) - dr = DagRun(dag_with_fail_task.dag_id, run_id=run_id, run_type="manual", state=DagRunState.RUNNING) + dr = DagRun( + dag_with_fail_task.dag_id, + run_id=run_id, + run_type="manual", + state=DagRunState.RUNNING, + logical_date=timezone.utcnow(), + ) session.merge(dr) task = dag_with_fail_task.get_task("fail_task") ti = TaskInstance(task, run_id=run_id) @@ -2421,7 +2244,13 @@ def test_outlet_assets_skipped(self, testing_dag_bundle): session.flush() run_id = str(uuid4()) - dr = DagRun(dag_with_skip_task.dag_id, run_id=run_id, run_type="manual", state=DagRunState.RUNNING) + dr = DagRun( + dag_with_skip_task.dag_id, + run_id=run_id, + run_type="manual", + state=DagRunState.RUNNING, + logical_date=timezone.utcnow(), + ) session.merge(dr) task = dag_with_skip_task.get_task("skip_task") ti = TaskInstance(task, run_id=run_id) @@ -2808,6 +2637,7 @@ def producer_with_inactive(*, outlet_events): assert asset_alias_obj.assets[0].uri == asset_name @pytest.mark.want_activate_assets(True) + @pytest.mark.need_serialized_dag def test_inlet_asset_extra(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -2852,11 +2682,15 @@ def read(*, inlet_events): # Run "write1", "write2", and "write3" (in this order). decision = dr.task_instance_scheduling_decisions(session=session) for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")): + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Run "read". decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Should be done. @@ -2864,6 +2698,7 @@ def read(*, inlet_events): assert read_task_evaluated @pytest.mark.want_activate_assets(True) + @pytest.mark.need_serialized_dag def test_inlet_asset_alias_extra(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset, AssetAlias @@ -2915,17 +2750,22 @@ def read(*, inlet_events): # Run "write1", "write2", and "write3" (in this order). 
decision = dr.task_instance_scheduling_decisions(session=session) for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")): + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Run "read". decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Should be done. assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis assert read_task_evaluated + @pytest.mark.need_serialized_dag def test_inlet_unresolved_asset_alias(self, dag_maker, session): asset_alias_name = "test_inlet_asset_extra_asset_alias" @@ -2946,6 +2786,8 @@ def read(*, inlet_events): dr: DagRun = dag_maker.create_dagrun() for ti in dr.get_task_instances(session=session): + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Should be done. @@ -4916,80 +4758,6 @@ def show(value): ti.run() assert outputs == expected_outputs - def test_map_product(self, dag_maker, session): - outputs = [] - - with dag_maker(dag_id="product", session=session) as dag: - - @dag.task - def emit_numbers(): - return [1, 2] - - @dag.task - def emit_letters(): - return {"a": "x", "b": "y", "c": "z"} - - @dag.task - def show(number, letter): - outputs.append((number, letter)) - - show.expand(number=emit_numbers(), letter=emit_letters()) - - dag_run = dag_maker.create_dagrun() - for task_id in ["emit_numbers", "emit_letters"]: - ti = dag_run.get_task_instance(task_id, session=session) - ti.refresh_from_task(dag.get_task(task_id)) - ti.run() - - show_task = dag.get_task("show") - mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) - assert max_map_index + 1 == len(mapped_tis) == 6 - - for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): - ti.refresh_from_task(show_task) - ti.run() - assert outputs == [ - (1, ("a", "x")), - (1, ("b", "y")), - (1, ("c", "z")), - (2, ("a", "x")), - (2, ("b", "y")), - (2, ("c", "z")), - ] - - def test_map_product_same(self, dag_maker, session): - """Test a mapped task can refer to the same source multiple times.""" - outputs = [] - - with dag_maker(dag_id="product_same", session=session) as dag: - - @dag.task - def emit_numbers(): - return [1, 2] - - @dag.task - def show(a, b): - outputs.append((a, b)) - - emit_task = emit_numbers() - show.expand(a=emit_task, b=emit_task) - - dag_run = dag_maker.create_dagrun() - ti = dag_run.get_task_instance("emit_numbers", session=session) - ti.refresh_from_task(dag.get_task("emit_numbers")) - ti.run() - - show_task = dag.get_task("show") - with pytest.raises(NotFullyPopulated): - assert show_task.get_parse_time_mapped_ti_count() - mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) - assert max_map_index + 1 == len(mapped_tis) == 4 - - for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): - ti.refresh_from_task(show_task) - ti.run() - assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] - def test_map_literal_cross_product(self, dag_maker, session): """Test a mapped task with literal cross product args expand properly.""" outputs = [] @@ -5094,132 +4862,6 @@ def _get_lazy_xcom_access_expected_sql_lines() -> list[str]: raise RuntimeError(f"unknown backend {backend!r}") -def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session): - with dag_maker(session=session): - 
EmptyOperator(task_id="t") - - run: DagRun = dag_maker.create_dagrun() - run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session) - - with set_current_task_instance_session(session=session): - original = LazyXComSelectSequence.from_select( - select(XCom.value).filter_by( - dag_id=run.dag_id, - run_id=run.run_id, - task_id="t", - map_index=-1, - key="xxx", - ), - order_by=(), - ) - processed = pickle.loads(pickle.dumps(original)) - - # After the object went through pickling, the underlying ORM query should be - # replaced by one backed by a literal SQL string with all variables binded. - sql_lines = [line.strip() for line in str(processed._select_asc.compile(None)).splitlines()] - assert sql_lines == _get_lazy_xcom_access_expected_sql_lines() - - assert len(processed) == 1 - assert list(processed) == [123] - - -@mock.patch("airflow.models.taskinstance.XCom.deserialize_value", side_effect=XCom.deserialize_value) -def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session): - """Ensure we access XCom lazily when pulling from a mapped operator.""" - with dag_maker(dag_id="test_xcom", session=session): - # Use the private _expand() method to avoid the empty kwargs check. - # We don't care about how the operator runs here, only its presence. - task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False) - EmptyOperator(task_id="task_2") - - dagrun = dag_maker.create_dagrun() - - ti_1_0 = dagrun.get_task_instance("task_1", session=session) - ti_1_0.map_index = 0 - ti_1_1 = session.merge(TaskInstance(task_1, run_id=dagrun.run_id, map_index=1, state=ti_1_0.state)) - session.flush() - - ti_1_0.xcom_push(key=XCOM_RETURN_KEY, value="a", session=session) - ti_1_1.xcom_push(key=XCOM_RETURN_KEY, value="b", session=session) - - ti_2 = dagrun.get_task_instance("task_2", session=session) - - # Simply pulling the joined XCom value should not deserialize. - joined = ti_2.xcom_pull("task_1", session=session) - assert isinstance(joined, LazyXComSelectSequence) - assert mock_deserialize_value.call_count == 0 - - # Only when we go through the iterable does deserialization happen. - it = iter(joined) - assert next(it) == "a" - assert mock_deserialize_value.call_count == 1 - assert next(it) == "b" - assert mock_deserialize_value.call_count == 2 - with pytest.raises(StopIteration): - next(it) - - -def test_ti_mapped_depends_on_mapped_xcom_arg(dag_maker, session): - with dag_maker(session=session) as dag: - - @dag.task - def add_one(x): - return x + 1 - - two_three_four = add_one.expand(x=[1, 2, 3]) - add_one.expand(x=two_three_four) - - dagrun = dag_maker.create_dagrun() - for map_index in range(3): - ti = dagrun.get_task_instance("add_one", map_index=map_index, session=session) - ti.refresh_from_task(dag.get_task("add_one")) - ti.run() - - task_345 = dag.get_task("add_one__1") - for ti in TaskMap.expand_mapped_task(task_345, dagrun.run_id, session=session)[0]: - ti.refresh_from_task(task_345) - ti.run() - - query = XCom.get_many(run_id=dagrun.run_id, task_ids=["add_one__1"], session=session) - assert [x.value for x in query.order_by(None).order_by(XCom.map_index)] == [3, 4, 5] - - -def test_mapped_upstream_return_none_should_skip(dag_maker, session): - results = set() - - with dag_maker(dag_id="test_mapped_upstream_return_none_should_skip", session=session) as dag: - - @dag.task() - def transform(value): - if value == "b": # Now downstream doesn't map against this! 
- return None - return value - - @dag.task() - def pull(value): - results.add(value) - - original = ["a", "b", "c"] - transformed = transform.expand(value=original) # ["a", None, "c"] - pull.expand(value=transformed) # ["a", "c"] - - dr = dag_maker.create_dagrun() - - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert sorted(tis) == [("transform", 0), ("transform", 1), ("transform", 2)] - for ti in tis.values(): - ti.run() - - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert sorted(tis) == [("pull", 0), ("pull", 1)] - for ti in tis.values(): - ti.run() - - assert results == {"a", "c"} - - def test_expand_non_templated_field(dag_maker, session): """Test expand on non-templated fields sets upstream deps properly.""" diff --git a/tests/models/test_taskmap.py b/tests/models/test_taskmap.py new file mode 100644 index 0000000000000..10fb4d99bba09 --- /dev/null +++ b/tests/models/test_taskmap.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest + +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskmap import TaskMap, TaskMapVariant +from airflow.providers.standard.operators.empty import EmptyOperator + +pytestmark = pytest.mark.db_test + + +def test_task_map_from_task_instance_xcom(): + task = EmptyOperator(task_id="test_task") + ti = TaskInstance(task=task, run_id="test_run", map_index=0) + ti.dag_id = "test_dag" + value = {"key1": "value1", "key2": "value2"} + + # Test case where run_id is not None + task_map = TaskMap.from_task_instance_xcom(ti, value) + assert task_map.dag_id == ti.dag_id + assert task_map.task_id == ti.task_id + assert task_map.run_id == ti.run_id + assert task_map.map_index == ti.map_index + assert task_map.length == len(value) + assert task_map.keys == list(value) + + # Test case where run_id is None + ti.run_id = None + with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): + TaskMap.from_task_instance_xcom(ti, value) + + +def test_task_map_with_invalid_task_instance(): + task = EmptyOperator(task_id="test_task") + ti = TaskInstance(task=task, run_id=None, map_index=0) + ti.dag_id = "test_dag" + + # Define some arbitrary XCom-like value data + value = {"example_key": "example_value"} + + with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): + TaskMap.from_task_instance_xcom(ti, value) + + +def test_task_map_variant(): + # Test case where keys is None + task_map = TaskMap( + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + map_index=0, + length=3, + keys=None, + ) + assert task_map.variant == TaskMapVariant.LIST + + # Test case where keys is not None + task_map.keys = ["key1", "key2"] + assert task_map.variant == TaskMapVariant.DICT diff --git a/tests/models/test_variable.py b/tests/models/test_variable.py index 67a1079d0198f..2c7e9d2763403 100644 --- a/tests/models/test_variable.py +++ b/tests/models/test_variable.py @@ -302,7 +302,12 @@ def test_cache_invalidation_on_set(self, session): ], ) def test_masking_only_secret_values(variable_value, deserialize_json, expected_masked_values, session): - from airflow.utils.log.secrets_masker import _secrets_masker + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.secrets_masker import _secrets_masker + else: + from airflow.utils.log.secrets_masker import _secrets_masker SecretCache.reset() diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py deleted file mode 100644 index 6d1ea09f3b217..0000000000000 --- a/tests/models/test_xcom_arg_map.py +++ /dev/null @@ -1,433 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import pytest - -from airflow.exceptions import AirflowSkipException -from airflow.models.taskinstance import TaskInstance -from airflow.models.taskmap import TaskMap, TaskMapVariant -from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.utils.state import TaskInstanceState -from airflow.utils.trigger_rule import TriggerRule - -pytestmark = pytest.mark.db_test - - -def test_xcom_map(dag_maker, session): - results = set() - with dag_maker(session=session) as dag: - - @dag.task - def push(): - return ["a", "b", "c"] - - @dag.task - def pull(value): - results.add(value) - - pull.expand_kwargs(push().map(lambda v: {"value": v * 2})) - - # The function passed to "map" is *NOT* a task. - assert set(dag.task_dict) == {"push", "pull"} - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - session.commit() - - # Run "pull". - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert sorted(tis) == [("pull", 0), ("pull", 1), ("pull", 2)] - for ti in tis.values(): - ti.run(session=session) - - assert results == {"aa", "bb", "cc"} - - -def test_xcom_map_transform_to_none(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - results.add(value) - - def c_to_none(v): - if v == "c": - return None - return v - - pull.expand(value=push().map(c_to_none)) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Run "pull". This should automatically convert "c" to None. - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert results == {"a", "b", None} - - -def test_xcom_convert_to_kwargs_fails_task(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - results.add(value) - - def c_to_none(v): - if v == "c": - return None - return {"value": v} - - pull.expand_kwargs(push().map(c_to_none)) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Prepare to run "pull"... - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - - # The first two "pull" tis should also succeed. - tis[("pull", 0)].run(session=session) - tis[("pull", 1)].run(session=session) - - # But the third one fails because the map() result cannot be used as kwargs. 
- with pytest.raises(ValueError) as ctx: - tis[("pull", 2)].run(session=session) - assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[None]" - - assert [tis[("pull", i)].state for i in range(3)] == [ - TaskInstanceState.SUCCESS, - TaskInstanceState.SUCCESS, - TaskInstanceState.FAILED, - ] - - -def test_xcom_map_error_fails_task(dag_maker, session): - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - print(value) - - def does_not_work_with_c(v): - if v == "c": - raise ValueError("nope") - return {"value": v * 2} - - pull.expand_kwargs(push().map(does_not_work_with_c)) - - dr = dag_maker.create_dagrun(session=session) - - # The "push" task should not fail. - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert [ti.state for ti in decision.schedulable_tis] == [TaskInstanceState.SUCCESS] - - # Prepare to run "pull"... - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - - # The first two "pull" tis should also succeed. - tis[("pull", 0)].run(session=session) - tis[("pull", 1)].run(session=session) - - # But the third one (for "c") will fail. - with pytest.raises(ValueError) as ctx: - tis[("pull", 2)].run(session=session) - assert str(ctx.value) == "nope" - - assert [tis[("pull", i)].state for i in range(3)] == [ - TaskInstanceState.SUCCESS, - TaskInstanceState.SUCCESS, - TaskInstanceState.FAILED, - ] - - -def test_task_map_from_task_instance_xcom(): - task = EmptyOperator(task_id="test_task") - ti = TaskInstance(task=task, run_id="test_run", map_index=0) - ti.dag_id = "test_dag" - value = {"key1": "value1", "key2": "value2"} - - # Test case where run_id is not None - task_map = TaskMap.from_task_instance_xcom(ti, value) - assert task_map.dag_id == ti.dag_id - assert task_map.task_id == ti.task_id - assert task_map.run_id == ti.run_id - assert task_map.map_index == ti.map_index - assert task_map.length == len(value) - assert task_map.keys == list(value) - - # Test case where run_id is None - ti.run_id = None - with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): - TaskMap.from_task_instance_xcom(ti, value) - - -def test_task_map_with_invalid_task_instance(): - task = EmptyOperator(task_id="test_task") - ti = TaskInstance(task=task, run_id=None, map_index=0) - ti.dag_id = "test_dag" - - # Define some arbitrary XCom-like value data - value = {"example_key": "example_value"} - - with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): - TaskMap.from_task_instance_xcom(ti, value) - - -def test_task_map_variant(): - # Test case where keys is None - task_map = TaskMap( - dag_id="test_dag", - task_id="test_task", - run_id="test_run", - map_index=0, - length=3, - keys=None, - ) - assert task_map.variant == TaskMapVariant.LIST - - # Test case where keys is not None - task_map.keys = ["key1", "key2"] - assert task_map.variant == TaskMapVariant.DICT - - -def test_xcom_map_raise_to_skip(dag_maker, session): - result = None - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def forward(value): - return value - - @dag.task(trigger_rule=TriggerRule.ALL_DONE) - def collect(value): - nonlocal result - result = list(value) - - def skip_c(v): - if v == "c": - raise AirflowSkipException - return {"value": 
v} - - collect(value=forward.expand_kwargs(push().map(skip_c))) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Run "forward". This should automatically skip "c". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Now "collect" should only get "a" and "b". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert result == ["a", "b"] - - -def test_xcom_map_nest(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - results.add(value) - - converted = push().map(lambda v: v * 2).map(lambda v: {"value": v}) - pull.expand_kwargs(converted) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - session.flush() - session.commit() - - # Now "pull" should apply the mapping functions in order. - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert results == {"aa", "bb", "cc"} - - -def test_xcom_map_zip_nest(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task - def push_letters(): - return ["a", "b", "c", "d"] - - @dag.task - def push_numbers(): - return [1, 2, 3, 4] - - @dag.task - def pull(value): - results.add(value) - - doubled = push_numbers().map(lambda v: v * 2) - combined = doubled.zip(push_letters()) - - def convert_zipped(zipped): - letter, number = zipped - return letter * number - - pull.expand(value=combined.map(convert_zipped)) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push_letters" and "push_numbers". - decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - assert all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - session.commit() - - # Run "pull". 
- decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - assert all(ti.task_id == "pull" for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - - assert results == {"aa", "bbbb", "cccccc", "dddddddd"} - - -def test_xcom_concat(dag_maker, session): - from airflow.sdk.definitions.xcom_arg import _ConcatResult - - agg_results = set() - all_results = None - - with dag_maker(session=session) as dag: - - @dag.task - def push_letters(): - return ["a", "b", "c"] - - @dag.task - def push_numbers(): - return [1, 2] - - @dag.task - def pull_one(value): - agg_results.add(value) - - @dag.task - def pull_all(value): - assert isinstance(value, _ConcatResult) - assert value[0] == "a" - assert value[1] == "b" - assert value[2] == "c" - assert value[3] == 1 - assert value[4] == 2 - with pytest.raises(IndexError): - value[5] - assert value[-5] == "a" - assert value[-4] == "b" - assert value[-3] == "c" - assert value[-2] == 1 - assert value[-1] == 2 - with pytest.raises(IndexError): - value[-6] - nonlocal all_results - all_results = list(value) - - pushed_values = push_letters().concat(push_numbers()) - - pull_one.expand(value=pushed_values) - pull_all(pushed_values) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push_letters" and "push_numbers". - decision = dr.task_instance_scheduling_decisions(session=session) - assert len(decision.schedulable_tis) == 2 - assert all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - session.commit() - - # Run "pull_one" and "pull_all". - decision = dr.task_instance_scheduling_decisions(session=session) - assert len(decision.schedulable_tis) == 6 - assert all(ti.task_id.startswith("pull_") for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - - assert agg_results == {"a", "b", "c", 1, 2} - assert all_results == ["a", "b", "c", 1, 2] - - decision = dr.task_instance_scheduling_decisions(session=session) - assert not decision.schedulable_tis diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 37364e2fcbbd1..281bcc73c346b 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -40,7 +40,7 @@ from typing import TYPE_CHECKING from unittest import mock -import attr +import attrs import pendulum import pytest from dateutil.relativedelta import FR, relativedelta @@ -694,7 +694,7 @@ def validate_deserialized_task( } else: # Promised to be mapped by the assert above. assert isinstance(serialized_task, MappedOperator) - fields_to_check = {f.name for f in attr.fields(MappedOperator)} + fields_to_check = {f.name for f in attrs.fields(MappedOperator)} fields_to_check -= { "map_index_template", # Matching logic in BaseOperator.get_serialized_fields(). @@ -707,6 +707,7 @@ def validate_deserialized_task( # Checked separately. 
"operator_class", "partial_kwargs", + "expand_input", } fields_to_check |= {"deps"} @@ -753,6 +754,9 @@ def validate_deserialized_task( original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs} assert serialized_partial_kwargs == original_partial_kwargs + # ExpandInputs have different classes between scheduler and definition + assert attrs.asdict(serialized_task.expand_input) == attrs.asdict(task.expand_input) + @pytest.mark.parametrize( "dag_start_date, task_start_date, expected_task_start_date", [ @@ -2452,7 +2456,8 @@ def test_operator_expand_serde(): def test_operator_expand_xcomarg_serde(): - from airflow.models.xcom_arg import PlainXComArg, XComArg + from airflow.models.xcom_arg import SchedulerPlainXComArg + from airflow.sdk.definitions.xcom_arg import XComArg from airflow.serialization.serialized_objects import _XComRef with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: @@ -2501,13 +2506,13 @@ def test_operator_expand_xcomarg_serde(): serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) xcom_arg = serialized_dag.task_dict["task_2"].expand_input.value["arg2"] - assert isinstance(xcom_arg, PlainXComArg) + assert isinstance(xcom_arg, SchedulerPlainXComArg) assert xcom_arg.operator is serialized_dag.task_dict["op1"] @pytest.mark.parametrize("strict", [True, False]) def test_operator_expand_kwargs_literal_serde(strict): - from airflow.models.xcom_arg import PlainXComArg, XComArg + from airflow.sdk.definitions.xcom_arg import XComArg from airflow.serialization.serialized_objects import _XComRef with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: @@ -2567,15 +2572,16 @@ def test_operator_expand_kwargs_literal_serde(strict): serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) resolved_expand_value = serialized_dag.task_dict["task_2"].expand_input.value - resolved_expand_value == [ + assert resolved_expand_value == [ {"a": "x"}, - {"a": PlainXComArg(serialized_dag.task_dict["op1"])}, + {"a": _XComRef({"task_id": "op1", "key": "return_value"})}, ] @pytest.mark.parametrize("strict", [True, False]) def test_operator_expand_kwargs_xcomarg_serde(strict): - from airflow.models.xcom_arg import PlainXComArg, XComArg + from airflow.models.xcom_arg import SchedulerPlainXComArg + from airflow.sdk.definitions.xcom_arg import XComArg from airflow.serialization.serialized_objects import _XComRef with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: @@ -2620,7 +2626,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) xcom_arg = serialized_dag.task_dict["task_2"].expand_input.value - assert isinstance(xcom_arg, PlainXComArg) + assert isinstance(xcom_arg, SchedulerPlainXComArg) assert xcom_arg.operator is serialized_dag.task_dict["op1"] @@ -2901,7 +2907,7 @@ def x(arg1, arg2, arg3): def test_mapped_task_group_serde(): from airflow.decorators.task_group import task_group - from airflow.models.expandinput import DictOfListsExpandInput + from airflow.models.expandinput import SchedulerDictOfListsExpandInput from airflow.sdk.definitions.taskgroup import MappedTaskGroup with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: @@ -2944,7 +2950,7 @@ def tg(a: str) -> None: serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR]) serde_tg = serde_dag.task_group.children["tg"] assert isinstance(serde_tg, MappedTaskGroup) - assert serde_tg._expand_input == 
diff --git a/tests/ti_deps/deps/test_mapped_task_upstream_dep.py b/tests/ti_deps/deps/test_mapped_task_upstream_dep.py
index 3f1abdc3136ca..b47da12e672a1 100644
--- a/tests/ti_deps/deps/test_mapped_task_upstream_dep.py
+++ b/tests/ti_deps/deps/test_mapped_task_upstream_dep.py
@@ -27,8 +27,9 @@
 from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
 from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep
 from airflow.utils.state import TaskInstanceState
+from airflow.utils.xcom import XCOM_RETURN_KEY

-pytestmark = pytest.mark.db_test
+pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag]

 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
@@ -274,35 +275,6 @@ def tg(x, y):
     assert not schedulable_tis


-def test_nested_mapped_task_groups(dag_maker, session: Session):
-    from airflow.decorators import task, task_group
-
-    with dag_maker(session=session):
-
-        @task
-        def t():
-            return [[1, 2], [3, 4]]
-
-        @task
-        def m(x):
-            return x
-
-        @task_group
-        def g1(x):
-            @task_group
-            def g2(y):
-                return m(y)
-
-            return g2.expand(y=x)
-
-        g1.expand(x=t())
-
-    # Add a test once nested mapped task groups become supported
-    with pytest.raises(NotImplementedError) as ctx:
-        dag_maker.create_dagrun()
-    assert str(ctx.value) == ""
-
-
 def test_mapped_in_mapped_task_group(dag_maker, session: Session):
     from airflow.decorators import task, task_group

@@ -391,7 +363,7 @@ def tg():
     assert not get_dep_statuses(dr, "tg.op1", session)


-@pytest.mark.parametrize("upstream_instance_state", [None, SKIPPED, FAILED])
+@pytest.mark.parametrize("upstream_instance_state", [SUCCESS, SKIPPED, FAILED])
 @pytest.mark.parametrize("testcase", ["task", "group"])
 def test_upstream_mapped_expanded(
     dag_maker, session: Session, upstream_instance_state: TaskInstanceState | None, testcase: str
@@ -432,35 +404,46 @@ def tg(x):
     assert sorted(schedulable_tis) == [f"{mapped_task_1}_0", f"{mapped_task_1}_1", f"{mapped_task_1}_2"]
     assert not finished_tis_states

-    # Run expanded m1 tasks
-    schedulable_tis[f"{mapped_task_1}_1"].run()
-    schedulable_tis[f"{mapped_task_1}_2"].run()
-    if upstream_instance_state != FAILED:
-        schedulable_tis[f"{mapped_task_1}_0"].run()
-    else:
-        with pytest.raises(AirflowFailException):
-            schedulable_tis[f"{mapped_task_1}_0"].run()
+    # "Run" expanded m1 tasks
+    for ti, state in (
+        (schedulable_tis[f"{mapped_task_1}_1"], SUCCESS),
+        (schedulable_tis[f"{mapped_task_1}_2"], SUCCESS),
+        (schedulable_tis[f"{mapped_task_1}_0"], upstream_instance_state),
+    ):
+        ti.state = state
+        if state == SUCCESS:
+            ti.xcom_push(XCOM_RETURN_KEY, "doesn't matter", session=session)
+    session.flush()

     schedulable_tis, finished_tis_states = _one_scheduling_decision_iteration(dr, session)
     # Expect that m2 can still be expanded since the dependency check does not fail. If one of the expanded
     # m1 tasks fails or is skipped, there is one fewer m2 expanded tasks
     expected_schedulable = [f"{mapped_task_2}_0", f"{mapped_task_2}_1"]
-    if upstream_instance_state is None:
+    if upstream_instance_state is SUCCESS:
         expected_schedulable.append(f"{mapped_task_2}_2")
     assert list(schedulable_tis.keys()) == expected_schedulable

     # Run the expanded m2 tasks
-    schedulable_tis[f"{mapped_task_2}_0"].run()
-    schedulable_tis[f"{mapped_task_2}_1"].run()
-    if upstream_instance_state is None:
-        schedulable_tis[f"{mapped_task_2}_2"].run()
+
+    to_run: tuple[tuple[TaskInstance, TaskInstanceState], ...] = (
+        (schedulable_tis[f"{mapped_task_2}_0"], SUCCESS),
+        (schedulable_tis[f"{mapped_task_2}_1"], SUCCESS),
+    )
+    if upstream_instance_state == SUCCESS:
+        to_run += ((schedulable_tis[f"{mapped_task_2}_2"], upstream_instance_state),)
+    for ti, state in to_run:
+        ti.state = state
+        if state is SUCCESS:
+            ti.xcom_push(XCOM_RETURN_KEY, "doesn't matter", session=session)
+    session.flush()

     schedulable_tis, finished_tis_states = _one_scheduling_decision_iteration(dr, session)
     assert not schedulable_tis
+
     expected_finished_tis_states = {
         ti: "success"
         for ti in (f"{mapped_task_1}_1", f"{mapped_task_1}_2", f"{mapped_task_2}_0", f"{mapped_task_2}_1")
     }
-    if upstream_instance_state is None:
+    if upstream_instance_state is SUCCESS:
         expected_finished_tis_states[f"{mapped_task_1}_0"] = "success"
         expected_finished_tis_states[f"{mapped_task_2}_2"] = "success"
     else:
diff --git a/tests/timetables/test_trigger_timetable.py b/tests/timetables/test_trigger_timetable.py
index 4b4deb6eb6b04..0cca3f1ca81e4 100644
--- a/tests/timetables/test_trigger_timetable.py
+++ b/tests/timetables/test_trigger_timetable.py
@@ -26,7 +26,7 @@

 from airflow.exceptions import AirflowTimetableInvalid
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction
-from airflow.timetables.trigger import CronTriggerTimetable
+from airflow.timetables.trigger import CronTriggerTimetable, MultipleCronTriggerTimetable
 from airflow.utils.timezone import utc

 START_DATE = pendulum.DateTime(2021, 9, 4, tzinfo=utc)
@@ -285,3 +285,36 @@ def test_run_immediately_fast_dag(catchup):
         restriction=TimeRestriction(earliest=None, latest=None, catchup=catchup),
     )
     assert next_info == PREVIOUS
+
+
+@pytest.mark.parametrize(
+    "start_date, expected",
+    [
+        (pendulum.datetime(2025, 1, 1), pendulum.datetime(2025, 1, 1)),
+        (pendulum.datetime(2025, 1, 1, minute=5), pendulum.datetime(2025, 1, 1, minute=30)),
+        (pendulum.datetime(2025, 1, 1, minute=35), pendulum.datetime(2025, 1, 1, hour=1)),
+    ],
+)
+def test_multi_run_first(start_date, expected):
+    timetable = MultipleCronTriggerTimetable("@hourly", "30 * * * *", timezone=utc)
+    next_info = timetable.next_dagrun_info(
+        last_automated_data_interval=None,
+        restriction=TimeRestriction(earliest=start_date, latest=None, catchup=True),
+    )
+    assert next_info == DagRunInfo.exact(expected)
+
+
+@pytest.mark.parametrize(
+    "last, expected",
+    [
+        (pendulum.datetime(2025, 1, 1), pendulum.datetime(2025, 1, 1, minute=30)),
+        (pendulum.datetime(2025, 1, 1, minute=30), pendulum.datetime(2025, 1, 1, hour=1)),
+    ],
+)
+def test_multi_run_next(last, expected):
+    timetable = MultipleCronTriggerTimetable("@hourly", "30 * * * *", timezone=utc)
+    next_info = timetable.next_dagrun_info(
+        last_automated_data_interval=DataInterval.exact(last),
+        restriction=TimeRestriction(earliest=None, latest=None, catchup=True),
+    )
+    assert next_info == DagRunInfo.exact(expected)
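The `MultipleCronTriggerTimetable` tests added above expect the earliest upcoming fire time across all of the given cron expressions to win ("@hourly" vs. "30 * * * *"). A rough sketch of that "earliest next time" idea using croniter; this is not Airflow's implementation and ignores the catchup and earliest-boundary handling the real timetable performs:

```python
from datetime import datetime, timezone

from croniter import croniter


def earliest_next_fire(exprs: list[str], after: datetime) -> datetime:
    # Each cron expression proposes its next fire time; the earliest one wins.
    return min(croniter(expr, after).get_next(datetime) for expr in exprs)


start = datetime(2025, 1, 1, 0, 5, tzinfo=timezone.utc)
assert earliest_next_fire(["0 * * * *", "30 * * * *"], start) == datetime(
    2025, 1, 1, 0, 30, tzinfo=timezone.utc
)
```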
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index a4d6ef6afc23b..04160a83bf6f3 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -395,6 +395,7 @@ def test_rendered_k8s_without_k8s(admin_client):


 def test_tree_trigger_origin_tree_view(app, admin_client):
+    clear_db_runs()
     app.dag_bag.get_dag("example_bash_operator").create_dagrun(
         run_id="test",
         run_type=DagRunType.SCHEDULED,
@@ -414,6 +415,7 @@ def test_tree_trigger_origin_tree_view(app, admin_client):


 def test_graph_trigger_origin_grid_view(app, admin_client):
+    clear_db_runs()
     app.dag_bag.get_dag("example_bash_operator").create_dagrun(
         run_id="test",
         run_type=DagRunType.SCHEDULED,
@@ -433,6 +435,7 @@ def test_graph_trigger_origin_grid_view(app, admin_client):


 def test_gantt_trigger_origin_grid_view(app, admin_client):
+    clear_db_runs()
     app.dag_bag.get_dag("example_bash_operator").create_dagrun(
         run_id="test",
         run_type=DagRunType.SCHEDULED,
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 28da87def3c9f..8b2ee9e6646a3 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -544,8 +544,7 @@ def skip_db_test(item):
         # also automatically skip tests marked with `backend` marker as they are implicitly
         # db tests
         pytest.skip(
-            f"The test is skipped as it is DB test "
-            f"and --skip-db-tests is flag is passed to pytest. {item}"
+            f"The test is skipped as it is DB test and --skip-db-tests is flag is passed to pytest. {item}"
         )


@@ -788,7 +787,12 @@ def __enter__(self):
         self.dag.__enter__()
         if self.want_serialized:
-            return lazy_object_proxy.Proxy(self._serialized_dag)
+
+            class DAGProxy(lazy_object_proxy.Proxy):
+                # Make `@dag.task` decorator work when need_serialized_dag marker is set
+                task = self.dag.task
+
+            return DAGProxy(self._serialized_dag)
         return self.dag

     def _serialized_dag(self):
@@ -868,6 +872,10 @@ def __exit__(self, type, value, traceback):
             if self.want_activate_assets:
                 self._activate_assets()
             if sdm:
+                sdm._SerializedDagModel__data_cache = (
+                    self.serialized_model._SerializedDagModel__data_cache
+                )
+                sdm._data = self.serialized_model._data
                 self.serialized_model = sdm
             else:
                 self.session.merge(self.serialized_model)
@@ -937,9 +945,13 @@ def create_dagrun(self, *, logical_date=None, **kwargs):
             kwargs.pop("dag_version", None)
             kwargs.pop("triggered_by", None)
             kwargs["execution_date"] = logical_date
+
+        if self.want_serialized:
+            dag = self.serialized_model.dag
         self.dag_run = dag.create_dagrun(**kwargs)
         for ti in self.dag_run.task_instances:
-            ti.refresh_from_task(dag.get_task(ti.task_id))
+            # This need to always operate on the _real_ dag
+            ti.refresh_from_task(self.dag.get_task(ti.task_id))
         if self.want_serialized:
             self.session.commit()
         return self.dag_run
@@ -1542,13 +1554,21 @@ def _disable_redact(request: pytest.FixtureRequest, mocker):
     """Disable redacted text in tests, except specific."""
     from airflow import settings

+    from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
     if next(request.node.iter_markers("enable_redact"), None):
         with pytest.MonkeyPatch.context() as mp_ctx:
             mp_ctx.setattr(settings, "MASK_SECRETS_IN_LOGS", True)
             yield
         return

-    mocked_redact = mocker.patch("airflow.utils.log.secrets_masker.SecretsMasker.redact")
+    target = (
+        "airflow.sdk.execution_time.secrets_masker.SecretsMasker.redact"
+        if AIRFLOW_V_3_0_PLUS
+        else "airflow.utils.log.secrets_masker.SecretsMasker.redact"
+    )
+
+    mocked_redact = mocker.patch(target)
     mocked_redact.side_effect = lambda item, name=None, max_depth=None: item
     with pytest.MonkeyPatch.context() as mp_ctx:
         mp_ctx.setattr(settings, "MASK_SECRETS_IN_LOGS", False)
@@ -1627,3 +1647,25 @@ def url_safe_serializer(secret_key) -> URLSafeSerializer:
     from itsdangerous import URLSafeSerializer

     return URLSafeSerializer(secret_key)
+
+
+@pytest.fixture
+def create_db_api_hook(request):
+    from unittest.mock import MagicMock
+
+    from sqlalchemy.engine import Inspector
+
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+    columns, primary_keys, reserved_words, escape_column_names = request.param
+
+    inspector = MagicMock(spec=Inspector)
+    inspector.get_columns.side_effect = lambda table_name, schema: columns
+
+    test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook)
+    test_db_hook.run.side_effect = lambda *args: primary_keys
+    test_db_hook.reserved_words = reserved_words
+    test_db_hook.escape_word_format = "[{}]"
+    test_db_hook.escape_column_names = escape_column_names or False
+
+    return test_db_hook
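The `DAGProxy` subclass added in `dag_maker.__enter__` works because a class attribute on a `lazy_object_proxy.Proxy` subclass shadows the attribute of the lazily built wrapped object, so `@dag.task` keeps resolving against the real DAG while other attribute access still goes through the deferred factory. A self-contained sketch of that shadowing behaviour, with made-up names:

```python
import lazy_object_proxy


class Real:
    def task(self):
        return "decorated"


real = Real()


def build_serialized():
    # Stand-in for the expensive serialization step the proxy defers.
    return {"serialized": True}


class RealAwareProxy(lazy_object_proxy.Proxy):
    # The class attribute wins over the wrapped object's attributes,
    # so callers still reach the real object's bound method.
    task = real.task


proxy = RealAwareProxy(build_serialized)
assert proxy.task() == "decorated"  # served by the class attribute
assert proxy["serialized"] is True  # other access resolves the factory lazily
```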
diff --git a/tests_common/test_utils/mock_context.py b/tests_common/test_utils/mock_context.py
index 3c8a101ce2058..4391490dfb4a2 100644
--- a/tests_common/test_utils/mock_context.py
+++ b/tests_common/test_utils/mock_context.py
@@ -56,14 +56,16 @@ def xcom_pull(
             default: Any = None,
             run_id: str | None = None,
         ) -> Any:
-            if map_indexes:
-                return values.get(
-                    f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default
-                )
-            return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default)
+            key = f"{self.task_id}_{self.dag_id}_{key}"
+            if map_indexes is not None and (not isinstance(map_indexes, int) or map_indexes >= 0):
+                key += f"_{map_indexes}"
+            return values.get(key, default)

         def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None:
-            values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value
+            key = f"{self.task_id}_{self.dag_id}_{key}"
+            if self.map_index is not None and self.map_index >= 0:
+                key += f"_{self.map_index}"
+            values[key] = value

     values["ti"] = MockedTaskInstance(task=task)
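The mocked `xcom_pull`/`xcom_push` above now build their lookup keys the same way: `<task_id>_<dag_id>_<key>`, with the map index appended only when it is an actual non-negative index. A tiny illustration of that keying scheme with a hypothetical helper:

```python
from __future__ import annotations


def xcom_key(task_id: str, dag_id: str, key: str, map_index: int | None = None) -> str:
    # Hypothetical helper mirroring the key scheme used by the mocked task instance.
    full_key = f"{task_id}_{dag_id}_{key}"
    if map_index is not None and map_index >= 0:
        full_key += f"_{map_index}"
    return full_key


assert xcom_key("push", "demo_dag", "return_value") == "push_demo_dag_return_value"
assert xcom_key("push", "demo_dag", "return_value", -1) == "push_demo_dag_return_value"
assert xcom_key("push", "demo_dag", "return_value", 2) == "push_demo_dag_return_value_2"
```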
diff --git a/tests_common/test_utils/mock_operators.py b/tests_common/test_utils/mock_operators.py
index 9c9e3cfb5488f..81f53abf648ce 100644
--- a/tests_common/test_utils/mock_operators.py
+++ b/tests_common/test_utils/mock_operators.py
@@ -17,7 +17,7 @@
 from __future__ import annotations

 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

 import attr

@@ -27,8 +27,6 @@
 from tests_common.test_utils.compat import BaseOperatorLink

 if TYPE_CHECKING:
-    import jinja2
-
     from airflow.sdk.definitions.context import Context


@@ -46,48 +44,6 @@ def execute(self, context: Context):
         pass


-class NestedFields:
-    """Nested fields for testing purposes."""
-
-    def __init__(self, field_1, field_2):
-        self.field_1 = field_1
-        self.field_2 = field_2
-
-
-class MockOperatorWithNestedFields(BaseOperator):
-    """Operator with nested fields for testing purposes."""
-
-    template_fields: Sequence[str] = ("arg1", "arg2")
-
-    def __init__(self, arg1: str = "", arg2: NestedFields | None = None, **kwargs):
-        super().__init__(**kwargs)
-        self.arg1 = arg1
-        self.arg2 = arg2
-
-    def _render_nested_template_fields(
-        self,
-        content: Any,
-        context: Context,
-        jinja_env: jinja2.Environment,
-        seen_oids: set,
-    ) -> None:
-        if id(content) not in seen_oids:
-            template_fields: tuple | None = None
-
-            if isinstance(content, NestedFields):
-                template_fields = ("field_1", "field_2")
-
-            if template_fields:
-                seen_oids.add(id(content))
-                self._do_render_template_fields(content, template_fields, context, jinja_env, seen_oids)
-                return
-
-        super()._render_nested_template_fields(content, context, jinja_env, seen_oids)
-
-    def execute(self, context: Context):
-        pass
-
-
 class AirflowLink(BaseOperatorLink):
     """Operator Link for Apache Airflow Website."""
diff --git a/tests_common/test_utils/version_compat.py b/tests_common/test_utils/version_compat.py
index 21e7170194e36..7227de2d85962 100644
--- a/tests_common/test_utils/version_compat.py
+++ b/tests_common/test_utils/version_compat.py
@@ -34,3 +34,4 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:

 AIRFLOW_V_2_10_PLUS = get_base_airflow_version_tuple() >= (2, 10, 0)
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
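The `AIRFLOW_V_*` flags above are produced by comparing the installed Airflow version, reduced to a `(major, minor, micro)` tuple, against a literal tuple, which is roughly what a helper like `get_base_airflow_version_tuple` does. A rough sketch of that comparison, assuming `packaging` is available:

```python
from packaging.version import Version


def base_version_tuple(version_str: str) -> tuple[int, int, int]:
    # Drop pre-release/dev suffixes by going through the parsed base version.
    v = Version(Version(version_str).base_version)
    return v.major, v.minor, v.micro


assert base_version_tuple("2.10.4") >= (2, 10, 0)
assert base_version_tuple("3.0.0.dev0") >= (3, 0, 0)
assert not base_version_tuple("2.9.3") >= (3, 0, 0)
```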